diff --git a/applications/flash_attention_v2/collective/fmha_fusion.hpp b/applications/flash_attention_v2/collective/fmha_fusion.hpp index a87752588f..d943228538 100644 --- a/applications/flash_attention_v2/collective/fmha_fusion.hpp +++ b/applications/flash_attention_v2/collective/fmha_fusion.hpp @@ -1,5 +1,6 @@ /*************************************************************************************************** -* Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved. + * Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -39,6 +40,7 @@ using namespace cute; struct VariableLength { int max_length; + int total_length = 0; int* cumulative_length = nullptr; CUTE_HOST_DEVICE operator int() const { diff --git a/applications/flash_attention_v2/collective/xe_flash_attn_chunk_prefill_epilogue.hpp b/applications/flash_attention_v2/collective/xe_flash_attn_chunk_prefill_epilogue.hpp new file mode 100644 index 0000000000..dfaf883274 --- /dev/null +++ b/applications/flash_attention_v2/collective/xe_flash_attn_chunk_prefill_epilogue.hpp @@ -0,0 +1,255 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/detail/layout.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +namespace flash_attention { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template class FlashChunkPrefillEpilogue { + static_assert(cutlass::detail::dependent_false, "Could not find an epilogue specialization."); +}; + +template +class FlashChunkPrefillEpilogue { +public: + // + // Type Aliases + // + using DispatchPolicy = epilogue::IntelXeXMX16; + using ElementO = ElementO_; + using StrideO = StrideO_; + using ElementLSE = ElementLSE_; + using CopyOpO = CopyOpO_; + using SubgroupLayout = SubgroupLayout_; + using TileShapeOutput = TileShapeOutput_; + using TiledMmaOutput = typename TiledMMAHelper, Layout, SubgroupLayout>::TiledMMA; + using GmemTiledCopyO = CopyOpO; + using ElementOutput = ElementO_; + using ElementCompute = ElementCompute_; + using ElementAccumulator = ElementCompute_; + using SubgroupTileShape = decltype(cute::shape_div(TileShapeOutput{}, (SubgroupLayout{}.shape()))); + + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + + static_assert(cute::rank(TileShapeOutput{}) == 3, "TileShapeOutput must be rank-3: [CTA_M_QO, CTA_N_VO, CTA_K_PV]"); + static_assert(cute::rank(StrideO{}) == 3, "StrideO must be rank-3: [seq_len_qo, head_size_vo, batch * num_heads]"); + + using CopyThreadShape = Shape<_1, Int>; + + using traits_store_O = Copy_Traits; + using atom_load_O = Copy_Atom; + using val_layout_load_O = decltype(make_layout(shape_div(typename traits_store_O::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_O = decltype(make_tiled_copy(atom_load_O{}, Layout{}, val_layout_load_O{})); + +private: + constexpr static bool is_destination_supported = not cute::is_void_v; + +public: + using EmptyType = cute::tuple<>; + + struct TensorStorageImpl : cute::tuple {}; + + struct SharedStorage { + using TensorStorage = TensorStorageImpl; + + TensorStorage tensors; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + + // Host side epilogue arguments + struct Arguments { + ElementO const *ptr_O; + StrideO dO; + }; + + // Device side epilogue params + struct Params { + XE_Copy_O xe_store_o; + }; + + // + // Methods + // + template + CUTLASS_DEVICE auto convert_type(Tensor const &tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + auto frag = + convert_op(*reinterpret_cast *>( + tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); + } + + template + static constexpr Params to_underlying_arguments(ProblemShape const &problem_shape, Arguments const &args, + [[maybe_unused]] void *workspace) { + auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape; + auto tensorO = make_tensor(make_gmem_ptr(static_cast(args.ptr_O)), + make_layout(make_shape(seq_len_qo, num_heads_q * head_size_vo, batch), + args.dO)); + XE_Copy_O xe_store_o{XE_Copy_O{}.with(tensorO)}; + return { + xe_store_o, + }; + } + + template + static size_t get_workspace_size(ProblemShape const &problem_shape, Arguments const &args) { + return 0; + } + + template + static cutlass::Status initialize_workspace(ProblemShape const &problem_shape, Arguments const &args, void *workspace, + cudaStream_t stream, CudaHostAdapter *cuda_adapter = nullptr) { + return Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE static bool can_implement(ProblemShape const &problem_shape, + [[maybe_unused]] Arguments const &args) { + return true; + } + + CUTLASS_HOST_DEVICE + FlashChunkPrefillEpilogue(Params const ¶ms_, TensorStorage const &) : params(params_) {} + + template + CUTLASS_DEVICE void operator()(ProblemShape problem_shape, SequenceLengthShape sequence_length_shape, TileCoord tile_coord, FragOut &out, + FragMax const &max, FragSum &sum) { + + using namespace cute; + + static constexpr bool is_var_len = cutlass::fmha::collective::is_variable_length_v>; + + using FragOutLayout = typename FragOut::layout_type; + + constexpr int Vec = shape<0>(FragOutLayout{}); + constexpr int FragsM = shape<1>(FragOutLayout{}); + constexpr int FragsN = size(select<2,3>(shape(FragOutLayout{}))); + + auto sg = compat::get_nd_item<1>().get_sub_group(); + auto out_reg = make_tensor(static_cast(out).data() , Shape, Int, Int>{}); + + CUTLASS_PRAGMA_UNROLL + for (int y = 0; y < FragsM; y++) { + CUTLASS_PRAGMA_UNROLL + for (int x = 0; x < Vec; x++) { + int indx = y * Vec + x; + auto cur_sum = reduce_over_group(sg, sum(indx), sycl::plus<>()); + auto cur_scale = (cur_sum == 0.f || cur_sum != cur_sum) ? 1.0f : sycl::native::recip(cur_sum); + CUTLASS_PRAGMA_UNROLL + for (int z = 0; z < FragsN; z++) { + out_reg(x, y, z) *= cur_scale; + } + } + } + + // Indexing variables + auto [batch, num_heads_q, num_heads_kv, head_size_vo] = select<0, 1, 2, 7>(problem_shape); + auto [seq_len_qo] = select<0>(sequence_length_shape); + // Represent the full output tensor + Tensor mO_mnl = cute::get_xe_tensor(make_shape(seq_len_qo, head_size_vo, 1)); + + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord; + // Tile the output tensor per WG + Tensor g_wg_O = local_tile(mO_mnl, select<0,1>(TileShapeOutput{}), make_coord(m_coord,n_coord,0)); // (BLK_M,BLK_N,m,n,l) + static constexpr auto ATOM_N = get<2>(typename TiledMmaOutput::ThrLayoutVMNK{}.shape()); + auto m_sg = get_sub_group_id() / ATOM_N; + auto n_sg = get_sub_group_id() % ATOM_N; + // Tile the output tensor per SG + Tensor gO = local_tile(g_wg_O, SubgroupTileShape{}, make_coord(m_sg,n_sg,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + auto thread_xe_store_o = params.xe_store_o.get_thread_slice(ThreadIdxX()); + Tensor tOgO = thread_xe_store_o.partition_D(gO); + + Tensor final_out_reg = make_fragment_like(out_reg); + // iff ElementOutput == ElementAccumulator, then convert_type doesn't do the right conversion + // so we call copy() which internally performs a static_cast op on the data. + // for ElementOutput == bf16 | fp16, convert_type calls relevant NumericConverter specialization. + if constexpr (cute::is_same_v) { + copy(out_reg, final_out_reg); + } else { + Tensor temp = convert_type(out_reg); + copy(temp, final_out_reg); + } + copy(params.xe_store_o, final_out_reg, tOgO); + } + + // SequenceLengthShapeType = Shape + // For Fixed Sequence Length, ProblemShapeType = Shape + // For Variable Sequence Length, ProblemShapeType = Shape + template + CUTLASS_DEVICE static constexpr Params get_updated_copies(Params const& params, ProblemShapeType const& problem_shape, + SequenceLengthShapeType const& sequence_length_shape, int const& l_coord, int const& q_head_coord) { + auto [num_heads_q, num_heads_kv, head_size_vo] = select<1, 2, 7>(problem_shape); + auto [seq_len_qo] = select<0>(sequence_length_shape); + int offset_o = 0; + if constexpr (VarLen) { + auto qo_cumulative_length = get<3>(problem_shape).cumulative_length; + offset_o = num_heads_q * head_size_vo * qo_cumulative_length[l_coord] + q_head_coord * head_size_vo; + } else { + offset_o = num_heads_q * head_size_vo * seq_len_qo * l_coord + q_head_coord * head_size_vo; + } + auto store_traits = static_cast(params.xe_store_o); + ElementO* base_ptr = (ElementO*)store_traits.base_ptr; + auto shape_o = make_shape(static_cast(seq_len_qo), num_heads_q * head_size_vo, 1); + StrideO stride_o = cutlass::make_cute_packed_stride(StrideO{}, shape_o); + auto tensorO = make_tensor(make_gmem_ptr(base_ptr + offset_o), make_layout(shape_o, stride_o)); + XE_Copy_O xe_store_o{XE_Copy_O{}.with(tensorO)}; + return Params{xe_store_o}; + } + +private: + Params const ¶ms; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace flash_attention +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/applications/flash_attention_v2/collective/xe_flash_attn_chunk_prefill_mma.hpp b/applications/flash_attention_v2/collective/xe_flash_attn_chunk_prefill_mma.hpp new file mode 100644 index 0000000000..ded01bae3b --- /dev/null +++ b/applications/flash_attention_v2/collective/xe_flash_attn_chunk_prefill_mma.hpp @@ -0,0 +1,601 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/fp8_to_fp16.h" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "fmha_fusion.hpp" + + +//////////////////////////////////////////////////////////// +namespace { + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::flash_attention::collective { +using namespace cute; +//////////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct FlashChunkPrefillMma { + static_assert(cutlass::detail::dependent_false, + "Could not find a mainloop specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct FlashChunkPrefillMma< + gemm::MainloopIntelXeXMX16, ProblemShapeType_, ElementQ_, StrideQ_, + ElementK_, StrideK_, ElementV_, StrideV_, MMAOperation_, TileShapeQK_, + TileShapePV_, SubgroupLayout_, GmemTiledCopyQ_, GmemTiledCopyK_, + GmemTiledCopyV_, CausalMask_, LocalMask_, PagedKV_> { + // + // Type Aliases + // + using DispatchPolicy = gemm::MainloopIntelXeXMX16; + using TileShapeQK = TileShapeQK_; + using TileShapePV = TileShapePV_; + using SubgroupLayout = SubgroupLayout_; + using ProblemShapeType = ProblemShapeType_; + using ElementQ = ElementQ_; + using StrideQ = StrideQ_; + using ElementK = ElementK_; + using StrideK = StrideK_; + using ElementV = ElementV_; + using StrideV = StrideV_; + using GmemTiledCopyQ = GmemTiledCopyQ_; + using GmemTiledCopyK = GmemTiledCopyK_; + using GmemTiledCopyV = GmemTiledCopyV_; + using ArchTag = typename DispatchPolicy::ArchTag; + using MmaAtom = MMA_Atom; + + using TiledMmaQK = typename TiledMMAHelper, + SubgroupLayout>::TiledMMA; + + using TiledMmaPV = typename TiledMMAHelper, + SubgroupLayout>::TiledMMA; + using ElementAccumulator = typename TiledMmaQK::ValTypeC; + static constexpr bool CausalMask = CausalMask_; + static constexpr bool LocalMask = LocalMask_; + static constexpr bool PagedKV = PagedKV_; + + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + + using MmaAtomShape = typename MmaAtom::Shape_MNK; + + static constexpr auto PV_ATOM_M = + decltype(get<0>(SubgroupLayout{}.shape()))::value; + static constexpr auto PV_ATOM_N = + decltype(get<1>(SubgroupLayout{}.shape()))::value; + static constexpr auto PV_ATOM_K = + decltype(get<2>(SubgroupLayout{}.shape()))::value; + + using SubgroupTileShapePV = + decltype(cute::shape_div(TileShapePV{}, (SubgroupLayout{}.shape()))); + static constexpr auto QK_BLK_M = get<0>(TileShapeQK{}); + static constexpr auto QK_BLK_N = get<1>(TileShapeQK{}); + static constexpr auto QK_BLK_K = get<2>(TileShapeQK{}); + + // This TiledMma is only required to serve the specific tiling requirements + // for matrix K. This is due to the consumption of matrix K by all subgroups + // within a workgroup. + static constexpr auto QK_ATOM_M = PV_ATOM_M; // 8 + static constexpr auto QK_ATOM_N = PV_ATOM_N; // 1 + static constexpr auto QK_ATOM_K = PV_ATOM_K; // 1 + + using SubgroupTileShapeQK = decltype(cute::shape_div( + TileShapeQK{}, + SubgroupLayout{}.shape())); // 128, 64, 32 / 16, 1, 1 = (8, 64, 32 ) + + static constexpr auto QK_SG_M = get<0>(SubgroupTileShapeQK{}); + static constexpr auto QK_SG_N = get<1>(SubgroupTileShapeQK{}); + static constexpr auto QK_SG_K = get<2>(SubgroupTileShapeQK{}); + + static constexpr bool is_var_len = + cutlass::fmha::collective::is_variable_length_v< + tuple_element_t<3, ProblemShapeType>>; + + using FragsShapeS = decltype(cute::shape_div( + take<0, 2>(SubgroupTileShapeQK{}), + take<0, 2>(MmaAtomShape()))); // 8, 64, 32 / 8, 16, 16 (1, 4) + static constexpr int Vec = + (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; // 8 + static constexpr int FragsM = get<0>(FragsShapeS{}); + static constexpr int FragsNS = get<1>(FragsShapeS{}); // 4 + + static constexpr uint32_t MaxThreadsPerBlock = + size(SubgroupLayout{}) * SubgroupSize; + using CopyThreadShape = Shape<_1, Int>; + + using traits_load_Q = Copy_Traits; + using atom_load_Q = Copy_Atom; + using val_layout_load_Q = decltype(make_layout( + shape_div(typename traits_load_Q::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_Q = decltype(make_tiled_copy( + atom_load_Q{}, Layout{}, val_layout_load_Q{})); + + using traits_load_K = Copy_Traits; + using atom_load_K = Copy_Atom; + using val_layout_load_K = decltype(make_layout( + shape_div(typename traits_load_K::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_K = decltype(make_tiled_copy( + atom_load_K{}, Layout{}, val_layout_load_K{})); + + using traits_load_V = Copy_Traits; + using atom_load_V = Copy_Atom; + using val_layout_load_V = decltype(make_layout( + shape_div(typename traits_load_V::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_V = decltype(make_tiled_copy( + atom_load_V{}, Layout{}, val_layout_load_V{})); + + template + static constexpr bool is_fp8_v = cute::is_same_v || cute::is_same_v; + + // Host side kernel arguments + struct Arguments { + ElementQ const *ptr_Q; + StrideQ dQ; + ElementK const *ptr_K; + StrideK dK; + ElementV const *ptr_V; + StrideV dV; + float const *ptr_q_scale; + float const *ptr_k_scale; + float const *ptr_v_scale; + ElementK const *ptr_K_cache; + StrideK dK_cache; + ElementV const *ptr_V_cache; + StrideV dV_cache; + // Paged KV Cache + int const *ptr_page_table; + int page_size; + int const *num_pages_per_seq; + int window_left; + int window_right; + }; + + struct Params { + XE_Copy_Q gmem_tiled_copy_q; + XE_Copy_K gmem_tiled_copy_k; + XE_Copy_V gmem_tiled_copy_v; + float const *ptr_q_scale; + float const *ptr_k_scale; + float const *ptr_v_scale; + XE_Copy_K gmem_tiled_copy_k_cache; + XE_Copy_V gmem_tiled_copy_v_cache; + // Paged KV Cache + int const *ptr_page_table; + int page_size; + int const *num_pages_per_seq; + int window_left; + int window_right; + }; + + // + // Methods + // + + FlashChunkPrefillMma() = default; + + static constexpr Params + to_underlying_arguments(ProblemShapeType const &problem_shape, + Arguments const &args, void *workspace) { + (void)workspace; + + auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, + seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape; + + auto tensorQ = make_tensor( + make_gmem_ptr(args.ptr_Q), + make_layout(make_shape(seq_len_qo, num_heads_q * head_size_qk, batch), + args.dQ)); + auto tensorK = make_tensor( + make_gmem_ptr(args.ptr_K), + make_layout(make_shape(seq_len_kv, num_heads_kv * head_size_qk, batch), + args.dK)); + auto tensorV = make_tensor( + make_gmem_ptr(args.ptr_V), + make_layout(make_shape(num_heads_kv * head_size_vo, seq_len_kv, batch), + args.dV)); + auto tensorK_cache = + make_tensor(make_gmem_ptr(args.ptr_K_cache), + make_layout(make_shape(seq_len_kv_cache, + num_heads_kv * head_size_qk, batch), + args.dK_cache)); + auto tensorV_cache = make_tensor( + make_gmem_ptr(args.ptr_V_cache), + make_layout( + make_shape(num_heads_kv * head_size_vo, seq_len_kv_cache, batch), + args.dV_cache)); + + XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)}; + XE_Copy_K copyK{XE_Copy_K{}.with(tensorK)}; + XE_Copy_V copyV{XE_Copy_V{}.with(tensorV)}; + XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)}; + XE_Copy_V copyV_cache{XE_Copy_V{}.with(tensorV_cache)}; + + return Params{copyQ, + copyK, + copyV, + args.ptr_q_scale, + args.ptr_k_scale, + args.ptr_v_scale, + copyK_cache, + copyV_cache, + args.ptr_page_table, + args.page_size, + args.num_pages_per_seq, + args.window_left, + args.window_right}; + } + + // FP8 Q and FP8 K tensors are converted to BF16 tensors using descale factors + // GEMM is computed in BF16 precision (FP8 not supported in BMG) + template + CUTLASS_DEVICE void mmaQK(FragQccum &accum, TensorQ gQ, TensorK gK, + FragSrc const &frag_src, int const &k_tile_count, + Params const ¶ms, bool is_KV_cache, + float q_scale, float k_scale) { + + auto &gmem_tiled_copy_k = + is_KV_cache ? params.gmem_tiled_copy_k_cache : params.gmem_tiled_copy_k; + + int thread_idx = static_cast(ThreadIdxX()); + auto thr_copy_Q = params.gmem_tiled_copy_q.get_slice(thread_idx); + auto thr_copy_K = gmem_tiled_copy_k.get_slice(thread_idx); + // Instantiate the MMA object + TiledMmaQK tiled_mma; + // To make all threads in a warp have the same global tensors pass in the + // index of thread 0 in each warp + auto sg = compat::get_nd_item<1>().get_sub_group(); + auto first_thread_in_sg_idx = + sg.get_group_id()[0] * DispatchPolicy::SubgroupSize; + auto thread_mma_q = tiled_mma.get_slice(first_thread_in_sg_idx); + auto thread_mma_k = tiled_mma.get_slice(0); + + Tensor tCgQ = thread_mma_q.partition_A(gQ); + Tensor tCgK = thread_mma_k.partition_B(gK); + + // Create fragments + // TODO(Codeplay): fix this, this is probably not general + using TCrQ_Type = cute::conditional_t, uint8_t, ElementQ>; + using TCrK_Type = cute::conditional_t, uint8_t, ElementK>; + Tensor tCrQ = make_tensor(make_fragment_layout(params.gmem_tiled_copy_q, take<0,3>(tCgQ.shape()))); + Tensor tCrK = make_tensor(make_fragment_layout(gmem_tiled_copy_k, take<0,3>(tCgK.shape()))); + + // Retile registers for copies + Tensor tQrQ = thr_copy_Q.retile_D(tCrQ); + Tensor tKrK = thr_copy_K.retile_D(tCrK); + + // Retile global tile for copies + Tensor tQgQ = thr_copy_Q.retile_S(tCgQ); + Tensor tKgK = thr_copy_K.retile_S(tCgK); + + // + // Mainloop + // + for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) { + copy(params.gmem_tiled_copy_q, tQgQ(_, _, _, k_tile), tQrQ); + copy(gmem_tiled_copy_k, tKgK(_, _, _, k_tile), tKrK); + + // FP8 path: Convert FP8 fragments to BF16 + if constexpr (is_fp8_v || is_fp8_v) { + auto tCrQ_fp16 = make_fragment_like(tCrQ); + auto tCrK_fp16 = make_fragment_like(tCrK); + + if constexpr (is_fp8_v) { + convert_and_descale(tCrQ, tCrQ_fp16, q_scale); + } else { + // If Q is already FP16, copy it. + copy(tCrQ, tCrQ_fp16); + } + + if constexpr (is_fp8_v) { + convert_and_descale(tCrK, tCrK_fp16, k_scale); + } else { + copy(tCrK, tCrK_fp16); + } + + // GEMM is computed on the BF16 tensors + cute::gemm(tiled_mma, accum, tCrQ_fp16, tCrK_fp16, frag_src); + } else { + // BF16 path + cute::gemm(tiled_mma, accum, tCrQ, tCrK, frag_src); + } + +#if 0 +#define PRINT(x) \ + print(#x ": "); \ + print(x); \ + print("\n"); + if (cute::thread(0, 0)) { + print("======================= Q: \n"); + PRINT(gQ); + PRINT(tCrQ); + PRINT(tCgQ); + PRINT(tQrQ); + PRINT(tQgQ); + + print("===================== K :\n"); + PRINT(gK); + PRINT(tCrK); + PRINT(tCgK); + PRINT(tKrK); + PRINT(tKgK); + + print("===================== Config: \n"); + PRINT(MaxThreadsPerBlock); + PRINT(SubgroupTileShapeQK{}); + } +#undef PRINT +#endif + } + } + + template + CUTLASS_DEVICE auto convert_type(Tensor const &tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + auto frag = + convert_op(*reinterpret_cast *>( + tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); + } + + // FP8 V tensor is converted to BF16 tensor using descale factor + // P tensor (softmax output) is in FP32 precision (converted to BF16) + // GEMM is computed in BF16 precision (FP8 not supported in BMG) + template + CUTLASS_DEVICE void mmaPV(FragQccum &accum, FragS const &tSr, TensorV gV, + FragSrc const &frag_src, Params const ¶ms, + bool is_KV_cache, float v_scale) { + + auto &gmem_tiled_copy_v = + is_KV_cache ? params.gmem_tiled_copy_v_cache : params.gmem_tiled_copy_v; + + int thread_idx = static_cast(ThreadIdxX()); + // Instantiate the MMA object + TiledMmaPV tiled_mma; + // Tile GV to the shape of <64,64> and loop over the HeadSize/64 to avoid + // Register spill + Tensor gV_ = take<0, 3>( + local_tile(gV, select<1, 2>(TileShapePV{}), make_coord(_, _))); + auto sg = compat::get_nd_item<1>().get_sub_group(); + auto first_thread_in_sg_idx = + sg.get_group_id()[0] * DispatchPolicy::SubgroupSize; + auto thread_mma = tiled_mma.get_slice(first_thread_in_sg_idx); + Tensor tCgV = thread_mma.partition_B(gV_); + using TCrV_Type = cute::conditional_t, uint8_t, ElementV>; + Tensor tCrV = make_tensor(make_fragment_layout(gmem_tiled_copy_v, take<0,3>(tCgV.shape()))); + + // Partition the copying of A and B tiles across the threads + auto gmem_thr_copy_V = gmem_tiled_copy_v.get_slice(thread_idx); + Tensor tVrV = gmem_thr_copy_V.retile_D(tCrV); + Tensor tVgV = gmem_thr_copy_V.retile_S(tCgV); + +#if CUTLASS_ENABLE_DEBUG_PRINTS +#define PRINT(x) \ + print(#x ": "); \ + print(x); \ + print("\n"); + if (cute::thread(LOG_THREAD, LOG_GROUP)) { + print("===================== V :\n"); + PRINT(gV); + PRINT(tCrV); + PRINT(tCgV); + PRINT(tVrV); + PRINT(tVgV); + + print("===================== Config: \n"); + PRINT(MaxThreadsPerBlock); + PRINT(SubgroupTileShapePV{}); + } +#undef PRINT +#endif + + // 7) Convert S to P (FP32 -> BF16) + Tensor tPr = convert_type(tSr); + // + // Mainloop + // + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tile_count; i++) { + copy(gmem_tiled_copy_v, tVgV(_, _, _, i), tVrV); + + if constexpr (is_fp8_v) { + auto tCrV_fp16 = make_fragment_like(tCrV); + convert_and_descale(tCrV, tCrV_fp16, v_scale); + + cute::gemm(tiled_mma, accum(_,_,_,i), tPr, tCrV_fp16, frag_src(_,_,_,i)); + } else { + cute::gemm(tiled_mma, accum(_,_,_,i), tPr, tCrV, frag_src(_,_,_,i)); + } + } + } + + // SequenceLengthShape = Shape + // For Fixed Sequence Length, ProblemShape = Shape For Variable Sequence Length, ProblemShape = Shape + template + CUTLASS_DEVICE static constexpr Params + get_updated_copies(Params const ¶ms, ProblemShape const &problem_shape, + SequenceLengthShape const &sequence_length_shape, + int const &l_coord, int const &q_head_coord = 0) { + auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = + select<0, 1, 2, 6, 7>(problem_shape); + auto [seq_len_qo, seq_len_kv, seq_len_kv_cache] = sequence_length_shape; + auto q_group_size = num_heads_q / num_heads_kv; + auto kv_head_coord = q_head_coord / q_group_size; + int offset_q = 0, offset_k = 0, offset_v = 0, offset_k_cache = 0, + offset_v_cache = 0; + int total_seq_len_kv_cache = 0; + if constexpr (is_var_len) { + auto qo_cumulative_length = get<3>(problem_shape).cumulative_length; + auto kv_cumulative_length = get<4>(problem_shape).cumulative_length; + auto kv_cached_cumulative_length = + get<5>(problem_shape).cumulative_length; + + offset_q = num_heads_q * head_size_qk * qo_cumulative_length[l_coord] + + q_head_coord * head_size_qk; + + offset_k = num_heads_kv * head_size_qk * kv_cumulative_length[l_coord] + + kv_head_coord * head_size_qk; + offset_v = num_heads_kv * head_size_vo * kv_cumulative_length[l_coord] + + kv_head_coord * head_size_vo; + offset_k_cache = seq_len_kv_cache == 0 + ? 0 + : PagedKV? // For page_kv, there is no batch dimension. + kv_head_coord * head_size_qk + : num_heads_kv * head_size_qk * kv_cached_cumulative_length[l_coord] + kv_head_coord * head_size_qk; + offset_v_cache = seq_len_kv_cache == 0 + ? 0 + : PagedKV? // For page_kv, there is no batch dimension. + kv_head_coord * head_size_vo + : num_heads_kv * head_size_vo * kv_cached_cumulative_length[l_coord] + kv_head_coord * head_size_vo; + total_seq_len_kv_cache = get<5>(problem_shape).total_length; + } else { + offset_q = num_heads_q * head_size_qk * seq_len_qo * l_coord + + q_head_coord * head_size_qk; + + offset_k = num_heads_kv * head_size_qk * seq_len_kv * l_coord + + kv_head_coord * head_size_qk; + offset_v = num_heads_kv * head_size_vo * seq_len_kv * l_coord + + kv_head_coord * head_size_vo; + offset_k_cache = + seq_len_kv_cache == 0 + ? 0 : + PagedKV? + kv_head_coord * head_size_qk + : num_heads_kv * head_size_qk * seq_len_kv_cache * l_coord + kv_head_coord * head_size_qk; + offset_v_cache = + seq_len_kv_cache == 0 + ? 0 : + PagedKV? + kv_head_coord * head_size_vo + : num_heads_kv * head_size_vo * seq_len_kv_cache * l_coord + kv_head_coord * head_size_vo; + total_seq_len_kv_cache = batch * seq_len_kv_cache; + } + + auto q_traits = + static_cast(params.gmem_tiled_copy_q); + const ElementQ *q_ptr = (const ElementQ *)q_traits.base_ptr; + auto k_traits = + static_cast(params.gmem_tiled_copy_k); + const ElementK *k_ptr = (const ElementK *)k_traits.base_ptr; + auto v_traits = + static_cast(params.gmem_tiled_copy_v); + const ElementV *v_ptr = (const ElementV *)v_traits.base_ptr; + auto k_traits_cache = + static_cast(params.gmem_tiled_copy_k_cache); + const ElementK *k_cache_ptr = (const ElementK *)k_traits_cache.base_ptr; + auto v_traits_cache = + static_cast(params.gmem_tiled_copy_v_cache); + const ElementV *v_cache_ptr = (const ElementV *)v_traits_cache.base_ptr; + // NHD format{batch, seq_len, head, dim_head} + // stride {seq_len*head*dim_head, head*dim_head, dim_head, 1} + auto shape_q = + make_shape(static_cast(seq_len_qo), head_size_qk * num_heads_q, 1); + StrideQ stride_q = cutlass::make_cute_packed_stride(StrideQ{}, shape_q); + auto shape_k = make_shape(static_cast(seq_len_kv), + num_heads_kv * head_size_qk, 1); + StrideK stride_k = cutlass::make_cute_packed_stride(StrideK{}, shape_k); + + auto shape_v = make_shape(head_size_vo * num_heads_kv, + static_cast(seq_len_kv), 1); + StrideV stride_v = cutlass::make_cute_packed_stride(StrideV{}, shape_v); + + auto shape_k_cache = make_shape(static_cast(PagedKV? total_seq_len_kv_cache : seq_len_kv_cache), + head_size_qk * num_heads_kv, 1); + StrideK stride_k_cache = + cutlass::make_cute_packed_stride(StrideK{}, shape_k_cache); + auto shape_v_cache = make_shape(head_size_vo * num_heads_kv, + static_cast(PagedKV? total_seq_len_kv_cache : seq_len_kv_cache), 1); + StrideV stride_v_cache = + cutlass::make_cute_packed_stride(StrideV{}, shape_v_cache); + auto tensorQ = make_tensor(make_gmem_ptr(q_ptr + offset_q), + make_layout(shape_q, stride_q)); + auto tensorK = make_tensor(make_gmem_ptr(k_ptr + offset_k), + make_layout(shape_k, stride_k)); + auto tensorV = make_tensor(make_gmem_ptr(v_ptr + offset_v), + make_layout(shape_v, stride_v)); + auto tensorK_cache = + make_tensor(make_gmem_ptr(k_cache_ptr + offset_k_cache), + make_layout(shape_k_cache, stride_k_cache)); + auto tensorV_cache = + make_tensor(make_gmem_ptr(v_cache_ptr + offset_v_cache), + make_layout(shape_v_cache, stride_v_cache)); + XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)}; + XE_Copy_K copyK{XE_Copy_K{}.with(tensorK)}; + XE_Copy_V copyV{XE_Copy_V{}.with(tensorV)}; + XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)}; + XE_Copy_V copyV_cache{XE_Copy_V{}.with(tensorV_cache)}; + + return Params{copyQ, + copyK, + copyV, + params.ptr_q_scale, + params.ptr_k_scale, + params.ptr_v_scale, + copyK_cache, + copyV_cache, + params.ptr_page_table, + params.page_size, + params.num_pages_per_seq, + params.window_left, + params.window_right}; + } +}; + +} // namespace cutlass::flash_attention::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/applications/flash_attention_v2/collective/xe_flash_attn_chunk_prefill_softmax_epilogue.hpp b/applications/flash_attention_v2/collective/xe_flash_attn_chunk_prefill_softmax_epilogue.hpp new file mode 100644 index 0000000000..849f65971b --- /dev/null +++ b/applications/flash_attention_v2/collective/xe_flash_attn_chunk_prefill_softmax_epilogue.hpp @@ -0,0 +1,222 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing online softmax. +*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/detail/layout.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace flash_attention { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template class FlashChunkPrefillSoftmaxEpilogue { + static_assert(cutlass::detail::dependent_false, "Could not find an epilogue specialization."); +}; + + +template +class FlashChunkPrefillSoftmaxEpilogue { +public: + + // + // Type Aliases + // + using DispatchPolicy = epilogue::IntelXeXMX16; + using Element = Element_; + + static constexpr bool CausalMask = CausalMask_; + static constexpr bool LocalMask = LocalMask_; + + using GmemTiledCopyOut = void; + + // Host side epilogue arguments + struct Arguments { + Element const scale; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + static constexpr Params to_underlying_arguments(Arguments const &args) { + constexpr double kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + Element val = args.scale * static_cast(kLog2e); + return Params{val}; + } + + template + static size_t get_workspace_size() { + return 0; + } + + template + static cutlass::Status initialize_workspace() { + return Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE static bool can_implement() { + return true; + } + + CUTLASS_HOST_DEVICE + FlashChunkPrefillSoftmaxEpilogue(Params const ¶ms_) : params(params_) {} + + template + CUTLASS_DEVICE void scale_exp_log2(FragAcc &frag_s, FragMax const &max, FragSum &sum) { + auto g = compat::get_nd_item<1>().get_sub_group(); + const auto max_scale = max * params.scale; + CUTLASS_PRAGMA_UNROLL + for (int indx = 0; indx < Vec * FragsM; indx++) { + const auto max_scale_bcast = group_broadcast(g, max_scale, indx); + CUTLASS_PRAGMA_UNROLL + for (int z = 0; z < FragsN; z++) { + auto base_indx = indx + (z * Vec * FragsM); + if constexpr (LocalMask) { + if ((std::isinf(max_scale_bcast) && max_scale_bcast < 0) || + (std::isinf(frag_s(base_indx)) && frag_s(base_indx) < 0)) { + frag_s(base_indx) = 0.f; + // continue; + } else { + Element eq = frag_s(base_indx) - max_scale_bcast; + frag_s(base_indx) = sycl::native::exp2(eq); + } + } else { + Element eq = frag_s(base_indx) - max_scale_bcast; + frag_s(base_indx) = sycl::native::exp2(eq); + } + sum(indx) += frag_s(base_indx); + } + } + } + + template + CUTLASS_DEVICE void reduce_max(FragSrc &src, FragMax &max) { + auto sg = compat::get_nd_item<1>().get_sub_group(); + CUTLASS_PRAGMA_UNROLL + for (int indx = 0; indx < Vec * FragsM; indx++) { + auto maxptr = group_broadcast(sg, max, indx); + CUTLASS_PRAGMA_UNROLL + for (int z = 0; z < FragsN; z++) { + auto base_indx = indx + (z * Vec * FragsM); + maxptr = sycl::max(maxptr, src(base_indx)); + src(base_indx) *= params.scale; + } + maxptr = reduce_over_group(sg, maxptr, sycl::maximum<>()); + if (indx == sg.get_local_id()[0]) { + max = maxptr; + } + } + } + + template + CUTLASS_DEVICE void operator()(bool is_first, FragAcc &frag_s, FragMax &max, FragSum &sum, FragOut &out) { + auto max_prev = max; + using FragAccLayout = typename FragAcc::layout_type; + using FragOutLayout = typename FragOut::layout_type; + constexpr int Vec = get<0>(FragAccLayout{}.shape()); + constexpr int FragsM = get<1>(FragAccLayout{}.shape()); + constexpr int FragsNAcc = get<2>(FragAccLayout{}.shape()); + constexpr int FragsNOut = size(select<2,3>(FragOutLayout{}.shape())); + reduce_max(frag_s, max); + static_assert(Vec * FragsM % 8 == 0, " No. of attention rows per subgroup should be >= 1 MMA Atom worth of rows."); + if (!is_first) { + auto sg = compat::get_nd_item<1>().get_sub_group(); + Element max_scale{max * params.scale}; + Element exp_scale; + if constexpr (LocalMask) { + if ((std::isinf(max_scale) && max_scale < 0) || (std::isinf(max_prev) && max_prev < 0)) { + exp_scale = 0.f; + } else { + exp_scale = sycl::native::exp2(max_prev * params.scale - max_scale); + } + } else { + exp_scale = sycl::native::exp2(max_prev * params.scale - max_scale); + } + + CUTLASS_PRAGMA_UNROLL + for (int indx = 0; indx < Vec * FragsM; indx++) { + auto max_scale_bcast = group_broadcast(sg, max_scale, indx); + auto exp_scale_bcast = group_broadcast(sg, exp_scale, indx); + sum(indx) *= exp_scale_bcast; + CUTLASS_PRAGMA_UNROLL + for (int z = 0; z < FragsNAcc; z++) { + auto base_indx = indx + (z * Vec * FragsM); + if constexpr (LocalMask) { + if ((std::isinf(max_scale_bcast) && max_scale_bcast < 0) || + (std::isinf(frag_s(base_indx)) && frag_s(base_indx) < 0)) { + frag_s(base_indx) = 0.f; + // continue; + } else { + Element eq = frag_s(base_indx) - max_scale_bcast; + frag_s(base_indx) = sycl::native::exp2(eq); + } + } else { + Element eq = frag_s(base_indx) - max_scale_bcast; + frag_s(base_indx) = sycl::native::exp2(eq); + } + sum(indx) += frag_s(base_indx); + } + CUTLASS_PRAGMA_UNROLL + for (int z = 0; z < FragsNOut; z++) { + auto base_indx = indx + (z * Vec * FragsM); + out(base_indx) *= exp_scale_bcast; + } + } + } else { + scale_exp_log2(frag_s, max, sum); + } + } + Params params; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace flash_attention +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/applications/flash_attention_v2/kernel/tile_scheduler_chunk_prefill.hpp b/applications/flash_attention_v2/kernel/tile_scheduler_chunk_prefill.hpp new file mode 100644 index 0000000000..6d429d52bc --- /dev/null +++ b/applications/flash_attention_v2/kernel/tile_scheduler_chunk_prefill.hpp @@ -0,0 +1,238 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.h" + +namespace cutlass::flash_attention { + +namespace kernel { + +struct XeFlashIndividualTileScheduler { + + struct Params { + dim3 grid; + // FastDivmod divmod_num_heads; + }; + + bool valid_ = true; + Params params; + + CUTLASS_DEVICE + XeFlashIndividualTileScheduler(Params const ¶ms) : params(params) {} + + template + static Params to_underlying_arguments(ProblemSize const &problem_size, + KernelHardwareInfo hw_info, + TileShape const &tile_shape) { + using namespace cute; + // problem_size = [batch, num_heads_q , num_heads_kv, seq_len_qo, + // seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] + + // dim3 grid(size(ceil_div(shape<7>(problem_size), shape<1>(tile_shape))), + // size(ceil_div(shape<3>(problem_size), shape<0>(tile_shape))), + // size(shape<0>(problem_size) * shape<1>(problem_size))); + + int batch = size<0>(problem_size); + int num_heads_q = size<1>(problem_size); + int num_heads_kv = size<2>(problem_size); + int seq_len_qo = + size<3>(problem_size); // if varlen seq_len_qo = max_seq_len + int seq_len_kv = + size<4>(problem_size); // if varlen seq_len_qo = max_seq_len + int seq_len_kv_cache = size<5>(problem_size); + int head_size_qk = size<6>(problem_size); + int head_size_vo = size<7>(problem_size); + auto group_heads_q = num_heads_q / num_heads_kv; + + dim3 grid(size(ceil_div(shape<3>(problem_size), shape<0>(tile_shape))), + size(shape<1>(problem_size)), size(shape<0>(problem_size))); + return Params{grid}; + } + + + template static dim3 get_grid_shape(Params const ¶ms) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { return valid_; } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + return make_coord(BlockIdxX(), BlockIdxY(), BlockIdxZ()); + } + + CUTLASS_DEVICE + XeFlashIndividualTileScheduler &operator++() { + valid_ = false; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct XeFlashPersistentTileScheduler { + + struct Params { + int num_blocks; + FastDivmod divmod_seq_len_block; + FastDivmod divmod_head_size_block; + FastDivmod divmod_num_heads; + + KernelHardwareInfo hw_info; + }; + + int block_idx = 0; + Params params; + + CUTLASS_DEVICE + XeFlashPersistentTileScheduler(Params const ¶ms) + : block_idx(BlockIdxX()), params(params) {} + + template + static Params to_underlying_arguments(ProblemSize const &problem_size, + KernelHardwareInfo hw_info, + TileShape const &tile_shape) { + using namespace cute; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST( + " WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments " + "KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + } + + CUTLASS_TRACE_HOST( + "to_underlying_arguments(): Setting persistent grid SM count to " + << sm_count); + hw_info.sm_count = sm_count; + + // problem_size = [batch, num_heads_q, numhead_kv, seq_len_qo, seq_len_kv, + // seq_len_kv_cache, head_size_qk, head_size_vo] + int num_head_size_blocks = + size(ceil_div(shape<7>(problem_size), shape<1>(tile_shape))); + int num_seq_len_blocks = + size(ceil_div(shape<3>(problem_size), shape<0>(tile_shape))); + int num_blocks = num_seq_len_blocks * num_head_size_blocks * + size(shape<0>(problem_size) * shape<1>(problem_size)); + + return Params{num_blocks, + {num_seq_len_blocks}, + {num_head_size_blocks}, + {shape<1>(problem_size)}, + hw_info}; + } + + template static dim3 get_grid_shape(Params const ¶ms) { + auto queue = compat::get_default_queue(); + auto dev = queue.get_device(); + const size_t maxSubgroups = + dev.template get_info(); + // TODO (Codeplay): revert this back to std::min(params.num_blocks, + // params.hw_info.sm_count) once performance issue is fixed. + dim3 grid( + std::min(params.num_blocks, + ceil_div(params.hw_info.sm_count * maxSubgroups, Num_SGs)), + 1, 1); + return grid; + } + + CUTLASS_DEVICE + bool is_valid() { return block_idx < params.num_blocks; } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int block_decode = block_idx; + int seq_len_block, head_size_block, bidh; + params.divmod_head_size_block(block_decode, head_size_block, block_decode); + params.divmod_seq_len_block(block_decode, seq_len_block, block_decode); + params.divmod_num_heads(block_decode, bidh, block_decode); + return make_coord(head_size_block, seq_len_block, block_decode, bidh); + } + + CUTLASS_DEVICE + XeFlashPersistentTileScheduler &operator++() { + block_idx += GridDimX(); + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// +} // namespace kernel + +struct IndividualScheduler {}; +struct PersistentScheduler {}; + +namespace detail { + +template +struct TileSchedulerSelector { + static_assert(cutlass::detail::dependent_false, + "Could not select a tile scheduler for given parameters."); +}; + +// Default (void) maps to XeFlashIndividualTileScheduler +template +struct TileSchedulerSelector< + void, ArchTag, + cute::enable_if_t>> { + using Scheduler = + typename TileSchedulerSelector::Scheduler; +}; + +template +struct TileSchedulerSelector< + IndividualScheduler, ArchTag, + cute::enable_if_t>> { + using Scheduler = kernel::XeFlashIndividualTileScheduler; +}; + +template +struct TileSchedulerSelector< + PersistentScheduler, ArchTag, + cute::enable_if_t>> { + using Scheduler = kernel::XeFlashPersistentTileScheduler; +}; +} // namespace detail + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::flash_attention diff --git a/applications/flash_attention_v2/kernel/xe_chunk_prefill.hpp b/applications/flash_attention_v2/kernel/xe_chunk_prefill.hpp new file mode 100644 index 0000000000..30b1f4dc86 --- /dev/null +++ b/applications/flash_attention_v2/kernel/xe_chunk_prefill.hpp @@ -0,0 +1,685 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice,this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/kernel_hardware_info.hpp" + +#include "flash_attention_v2/collective/xe_flash_attn_chunk_prefill_mma.hpp" +namespace cutlass::flash_attention::kernel { + +template +class FMHAPrefillChunk; +/////////////////////////////////////////////////////////////////////////////// +template +class FMHAPrefillChunk { + +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + + // ProblemShape: + static_assert( + rank(ProblemShape{}) == 8, + "ProblemShape{} should be "); + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShapeQK = typename CollectiveMainloop::TileShapeQK; + using TileShapePV = typename CollectiveMainloop::TileShapePV; + using TiledMmaQK = typename CollectiveMainloop::TiledMmaQK; + using TiledMmaPV = typename CollectiveMainloop::TiledMmaPV; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementQ = typename CollectiveMainloop::ElementQ; + using StrideQ = typename CollectiveMainloop::StrideQ; + using ElementK = typename CollectiveMainloop::ElementK; + using StrideK = typename CollectiveMainloop::StrideK; + using ElementV = typename CollectiveMainloop::ElementV; + using StrideV = typename CollectiveMainloop::StrideV; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + + using CollectiveSoftmaxEpilogue = CollectiveSoftmaxEpilogue_; + using SoftmaxArguments = typename CollectiveSoftmaxEpilogue::Arguments; + using SoftmaxParams = typename CollectiveSoftmaxEpilogue::Params; + + static_assert(cute::is_void_v or + cute::is_same_v or + cute::is_same_v, + "Unsupported TileScheduler for Intel Xe."); + using TileSchedulerTag = TileScheduler_; + using TileScheduler = + typename detail::TileSchedulerSelector::Scheduler; + using TileSchedulerParams = typename TileScheduler::Params; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementO = typename CollectiveEpilogue::ElementO; + using StrideO = typename CollectiveEpilogue::StrideO; + using ElementLSE = typename CollectiveEpilogue::ElementLSE; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + using TileShapeOutput = typename CollectiveEpilogue::TileShapeOutput; + using TiledMmaOutput = typename CollectiveEpilogue::TiledMmaOutput; + + static_assert( + cute::is_same_v, + "Mainloop and epilogue do not agree on accumulator value type."); + // MSVC requires the cast to fix a warning-as-error. + static constexpr int SharedStorageSize = 0; + + static constexpr bool CausalMask = CollectiveMainloop::CausalMask; + static constexpr bool LocalMask = CollectiveMainloop::LocalMask; + + static_assert(!(CausalMask && LocalMask), "Cannot be both causal and local"); + static constexpr bool PagedKV = CollectiveMainloop::PagedKV; + + + static constexpr int SubgroupSize = + CollectiveMainloop::SubgroupSize; // sub_group size + static constexpr uint32_t MaxThreadsPerBlock = + CollectiveMainloop::MaxThreadsPerBlock; + using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape; // 8,16,16 + + static constexpr int QK_BLK_M = CollectiveMainloop::QK_BLK_M; + static constexpr int QK_BLK_N = CollectiveMainloop::QK_BLK_N; + static constexpr int QK_BLK_K = CollectiveMainloop::QK_BLK_K; + + static constexpr int QK_ATOM_N = CollectiveMainloop::QK_ATOM_N; + static constexpr int QK_ATOM_K = CollectiveMainloop::QK_ATOM_K; + + static constexpr int QK_SG_M = CollectiveMainloop::QK_SG_M; + + static constexpr int Epilogue_BLK_N = get<1>(TileShapeOutput{}); + static constexpr int Epilogue_BLK_K = get<2>(TileShapeOutput{}); + + static constexpr int PV_ATOM_M = CollectiveMainloop::PV_ATOM_M; + static constexpr int PV_ATOM_N = CollectiveMainloop::PV_ATOM_N; + static constexpr int PV_ATOM_K = CollectiveMainloop::PV_ATOM_K; + + static constexpr auto Num_SGs = PV_ATOM_N * PV_ATOM_M * PV_ATOM_K; + static constexpr int Vec = CollectiveMainloop::Vec; + static constexpr int FragsM = CollectiveMainloop::FragsM; + // The FragsN here used for Creation of S matrix so we use the FragsN for S + // shape + static constexpr int FragsN = CollectiveMainloop::FragsNS; + + static constexpr int VSlicer = + get<1>(TileShapeOutput{}) / + (get<1>(TileShapePV{}) * PV_ATOM_N); // ceil_div(FragsNOut,FragsNS); + using AccumeShape = decltype(make_shape( + Int{}, Int{}, get<1>(TileShapePV{}) / get<1>(MmaAtomShape()), + Int{})); + + static constexpr bool is_var_len = CollectiveMainloop::is_var_len; + // Kernel level shared memory storage + struct SharedStorage { + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + EpilogueTensorStorage epilogue; + }; + + // Device side arguments + struct Arguments { + gemm::GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + SoftmaxArguments softmax{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + }; + + // Kernel entry point API + struct Params { + gemm::GemmUniversalMode mode; + ProblemShape problem_shape; + MainloopParams mainloop; + SoftmaxParams softmax; + EpilogueParams epilogue; + TileSchedulerParams scheduler; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the + // aliased type. + static Params to_underlying_arguments(Arguments const &args, + void *workspace) { + (void)workspace; + return {args.mode, + args.problem_shape, + CollectiveMainloop::to_underlying_arguments( + args.problem_shape, args.mainloop, workspace), + CollectiveSoftmaxEpilogue::to_underlying_arguments(args.softmax), + CollectiveEpilogue::to_underlying_arguments( + args.problem_shape, args.epilogue, workspace), + TileScheduler::to_underlying_arguments( + args.problem_shape, args.hw_info, TileShapeOutput{})}; + } + + static bool can_implement(Arguments const &args) { + bool mode_implementable = args.mode == gemm::GemmUniversalMode::kGemm or + (args.mode == gemm::GemmUniversalMode::kBatched && + rank(ProblemShape{}) == 4); + return mode_implementable; + } + + static int get_workspace_size(Arguments const &args) { return 0; } + + static cutlass::Status + initialize_workspace(Arguments const &args, void *workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const ¶ms) { + return TileScheduler::template get_grid_shape(params.scheduler); + } + + static dim3 get_block_shape() { return dim3(MaxThreadsPerBlock, 1, 1); } + + CUTLASS_DEVICE + Shape + get_sequence_length_shape(ProblemShape const &problem_shape, + int const &batch) { + if constexpr (is_var_len) { + return cutlass::fmha::collective::apply_variable_length( + select<3, 4, 5>(problem_shape), batch); + } else { + return select<3, 4, 5>(problem_shape); + } + } + + CUTLASS_DEVICE + void operator()(Params const ¶ms, char *smem_buf) { + SharedStorage &shared_storage = + *reinterpret_cast(smem_buf); + // Preconditions + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + // Separate out problem shape for convenience + + // "ProblemShape{} should be "); + auto batch = get<0>(params.problem_shape); + auto num_heads_q = get<1>(params.problem_shape); + auto num_heads_kv = get<2>(params.problem_shape); + + auto &head_size_qk = get<6>(params.problem_shape); + auto &head_size_vo = get<7>(params.problem_shape); + // Preconditions + static_assert(cute::rank(StrideQ{}) == 3, + "StrideQ must be rank-3: [seq_len_qo, head_size_qk, batch * " + "num_heads_q]."); + static_assert(cute::rank(StrideK{}) == 3, + "StrideK must be rank-3: [head_size_qk, seq_len_kv, batch * " + "num_heads_kv]."); + static_assert(cute::rank(StrideV{}) == 3, + "StrideV must be rank-3: [seq_len_kv, head_size_vo, batch * " + "num_heads_kv]."); + + int thread_idx = int(ThreadIdxX()); + int sub_group_id = thread_idx / SubgroupSize; + + TileScheduler tile_scheduler{params.scheduler}; + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = + tile_scheduler + .get_block_coord(); // head_size_blk_idx, seq_len_blk_idx, + // batch_blk_idx, num_heads_blk_idx + + auto blk_m_coord = get<0>(blk_coord); // seq_len_blk_idx + auto blk_n_coord = 0; // nums_head_blk_idx + auto q_head_coord = get<1>(blk_coord); // q_heads_idx + auto batch_coord = get<2>(blk_coord); // batch_blk_idx + + // For variable sequence length case, batch is considered to be 1 (same + // as group gemm). For fixed sequence length case, the l_coord is the + // weighted sum of both batch_coord and num_heads_coord. Flash Attention + // implementation combines batch and num_heads to calculate the total + // batch_size. iff is_var_len: batch_size = num_heads (as each batch + // would have it's own seq_len_qo and seq_len_kv) iff !is_var_len: + // batch_size = batch * num_heads + // auto blk_l_coord = q_head_coord; + + // Get problem shape for the current batch_blk_idx. For variable + // sequence length, it loads the sequence length from Global memory for + // the given batch_blk_idx and returns the appropriate problem_shape. + // For fixed sequence length, sequence_length_shape == select<3, 4, + // 5>(params.problem_shape). sequence_length_shape = [batch, + // num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, + // head_size_qk, head_size_vo] + auto sequence_length_shape = + get_sequence_length_shape(params.problem_shape, batch_coord); + + auto [seq_len_qo, seq_len_kv, seq_len_kv_cache] = sequence_length_shape; + // int seq_len_kv_total = seq_len_kv_cache + seq_len_kv; + // For variable sequence length case, batch is considered to be 1 (same + // as group gemm). For fixed sequence length case, the l_coord is the + // weighted sum of both batch_coord and num_heads_coord. Flash Attention + // implementation combines batch and num_heads to calculate the total + // batch_size. iff is_var_len: batch_size = num_heads (as each batch + // would have it's own seq_len_qo and seq_len_kv) iff !is_var_len: + // batch_size = batch * num_heads + + // Calculate the seq_len_idx (blk_m_coord * get<0>(TileShapeOutput{})) + // and check if it is still within bounds of the actual seq_len_qo + // (get<0>(sequence_length_shape)). + if (blk_m_coord * get<0>(TileShapeOutput{}) >= + seq_len_qo) { + continue; + } + + const int seq_coord = + cute::min(seq_len_qo, (blk_m_coord * QK_BLK_M + (sub_group_id / PV_ATOM_N) * QK_SG_M) % + seq_len_qo); + auto offset = cute::min(seq_len_qo, seq_len_kv); //(2048, 1024) + auto discard_seq_coord = seq_len_qo - offset; // 1024 + auto full_tile_offset = seq_len_kv - offset; // 0 + + const int seq_len = + CausalMask + ? full_tile_offset + + cute::min(seq_len_kv, seq_coord - discard_seq_coord) + + QK_SG_M + : seq_len_kv; + + const int kv_splits_new = cute::ceil_div(seq_len, QK_BLK_N); + const int kv_splits_cache = cute::ceil_div(seq_len_kv_cache, QK_BLK_N); + const int kv_splits = kv_splits_cache + kv_splits_new; + + int tiles_per_page = params.mainloop.page_size / QK_BLK_N; + + if (CausalMask && seq_coord < discard_seq_coord) { // 1024 =0 + continue; + } + + auto q_group_size = num_heads_q / num_heads_kv; + auto kv_head_coord = q_head_coord / q_group_size; + + // Descale tensors are shaped (batch size * # heads) + // Each head has a seperate scale factor + // Q, K, V tensors have seperate scaling factors + const float q_scale_val = params.mainloop.ptr_q_scale == nullptr + ? 1.f + : params.mainloop.ptr_q_scale[batch_coord * num_heads_q + q_head_coord]; + const float k_scale_val = params.mainloop.ptr_k_scale == nullptr + ? 1.f + : params.mainloop.ptr_k_scale[batch_coord * num_heads_kv + kv_head_coord]; + const float v_scale_val = params.mainloop.ptr_v_scale == nullptr + ? 1.f + : params.mainloop.ptr_v_scale[batch_coord * num_heads_kv + kv_head_coord]; + + Tensor mQ_mkl = cute::get_xe_tensor( + make_shape(seq_len_qo, head_size_qk, 1)); //(m,k,l) + + Tensor mK_nkl = cute::get_xe_tensor( + make_shape(seq_len_kv, head_size_qk, 1)); //(n,k,l) + Tensor mV_nkl = cute::get_xe_tensor( + make_shape(head_size_vo, seq_len_kv, 1)); //(n,k,l) + Tensor mK_cache_nkl = cute::get_xe_tensor( + make_shape(seq_len_kv_cache, head_size_qk, 1)); // (n_cache,k,l) + Tensor mV_cache_nkl = cute::get_xe_tensor( + make_shape(head_size_vo, seq_len_kv_cache, 1)); // (n_cache,k,l) + + // block_size and head_size are the same size. So no coord is needed. + Tensor mQ_mk = mQ_mkl(_, _, 0); + + Tensor mK_nk = mK_nkl(_, _, 0); // (n,k) + Tensor mV_nk = mV_nkl(_, _, 0); + + Tensor mK_cache_nk = mK_cache_nkl(_, _, 0); // (n_cache, k) + Tensor mV_cache_nk = mV_cache_nkl(_, _, 0); // (n_cache, k) + + auto gQ = local_tile(mQ_mk, TileShapeQK{}, make_coord(blk_m_coord, _, _), + Step<_1, X, _1>{}); + auto gK = local_tile(mK_nk, TileShapeQK{}, make_coord(_, _, _), + Step{}); + + auto gV = local_tile(mV_nk, TileShapeOutput{}, + make_coord(_, blk_n_coord, _), Step{}); + auto gK_cache = local_tile(mK_cache_nk, TileShapeQK{}, + make_coord(_, _, _), Step{}); + auto gV_cache = + local_tile(mV_cache_nk, TileShapeOutput{}, + make_coord(_, blk_n_coord, _), Step{}); + + auto mainloop_params = CollectiveMainloop::get_updated_copies( + params.mainloop, params.problem_shape, sequence_length_shape, + batch_coord, q_head_coord); + + + // we limit the horisontal size to two subgroup, the empirical resutls + // show that reading the two cacheline side by side in gives better + // performance and anything after that does not have an effect on + // performance. // (64 here for float b float when possible and loop over + // to cover all the data needed) + auto tiled_prefetch_q = cute::prefetch_selector< + Shape, Int>, + Num_SGs>(mainloop_params.gmem_tiled_copy_q); + auto tiled_prefetch_k = cute::prefetch_selector< + Shape, Int>, + Num_SGs>(mainloop_params.gmem_tiled_copy_k); + auto tiled_prefetch_v = cute::prefetch_selector< + Shape, + Int>, + Num_SGs>(mainloop_params.gmem_tiled_copy_v); + auto tiled_prefetch_k_cache = cute::prefetch_selector< + Shape, Int>, + Num_SGs>(mainloop_params.gmem_tiled_copy_k_cache); + auto tiled_prefetch_v_cache = cute::prefetch_selector< + Shape, + Int>, + Num_SGs>(mainloop_params.gmem_tiled_copy_v_cache); + auto thr_prefetch_Q = tiled_prefetch_q.get_slice(thread_idx); + auto thr_prefetch_K = tiled_prefetch_k.get_slice(thread_idx); + auto thr_prefetch_V = tiled_prefetch_v.get_slice(thread_idx); + auto pQgQ = thr_prefetch_Q.partition_S(gQ); + auto pKgK = thr_prefetch_K.partition_S(gK); + auto pVgV = thr_prefetch_V.partition_S(gV); + // assuming the copy function is the same otherwise this need to have its + // own tile_prefetch + auto pKgK_cache = thr_prefetch_K.partition_S(gK_cache); + auto pVgV_cache = thr_prefetch_V.partition_S(gV_cache); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<3>(pQgQ); i++) { + prefetch(tiled_prefetch_q, pQgQ(_, _, _, i)); + } + auto &prefetch_K = + (seq_len_kv_cache == 0) ? tiled_prefetch_k : tiled_prefetch_k_cache; + auto &pKgK1_ = (seq_len_kv_cache == 0) ? pKgK : pKgK_cache; + + int cached_nblock = 0; + if constexpr (PagedKV) { + int curr_batch_pages = ceil_div(seq_len_kv_cache, mainloop_params.page_size); + int batch_offset = + is_var_len ? mainloop_params.num_pages_per_seq[batch_coord] + : batch_coord * curr_batch_pages; + cached_nblock = + mainloop_params + .ptr_page_table[batch_offset // page table for this batch + ] * tiles_per_page; // base block idx of physical page + } + // The headsize for both cached and non-cached version is the same + for (int j = 0; j < size<4>(pKgK1_); j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = cached_nblock; i < cached_nblock + DispatchPolicy::Stages; + i++) { + prefetch(prefetch_K, pKgK1_(_, _, _, i, j)); + } + } + + // Allocate the tiled_mma and the accumulators for the (M,N) + // workgroup_shape + Tensor out_reg = make_tensor(AccumeShape{}); + + // There are 16 workitem and 16 max per subgroup, each worktime containt 1 + // max and cumulatively, they calculate the max per subgroup + ElementAccumulator max_reg{-INFINITY}; + // The sum reg each contains a 2d tesnor for 8 x 2 This is number of + // sequence lenght process per subgroup + Tensor sum_reg = + make_tensor(Shape, Int>{}); + + clear(sum_reg); + clear(out_reg); + // Perform the collective scoped MMA + CollectiveMainloop collective_mma; + // when causal mask is true. It is not possible to set the scope + // of the barrier to workgroup level as the number n block is + // different for each subgroup due to triangular nature of causal based + // operation + static constexpr int barrier_scope = CausalMask ? 3 : 2; + CUTLASS_PRAGMA_UNROLL + for (int split = 0; split < kv_splits - static_cast(CausalMask); split++) { + barrier_arrive(barrier_scope); + + bool is_KV_cache = split < kv_splits_cache; + // 1) Load KV (performed inside mmaQK) + auto gK_ = is_KV_cache ? gK_cache(_, _, cached_nblock, _) + : gK(_, _, split - kv_splits_cache, _); + auto gV_ = is_KV_cache ? gV_cache(_, _, cached_nblock) + : gV(_, _, split - kv_splits_cache); + // 2) Create Tensor S + Tensor tSr = make_tensor( + Shape, Int, Int>{}); + clear(tSr); + // 3) Perform GEMM S = Q*K + // Then modify layout to LayoutQ = ((seq_leq_q, group_head_q), + // head_size_qk, batch* num_heads_q / group_head_q), which can be merged + // into one gemm for (int i = 0; i < q_group_size; ++i) { + collective_mma.mmaQK(tSr, gQ, gK_, tSr, + ceil_div(head_size_qk, QK_BLK_K), mainloop_params, + is_KV_cache, q_scale_val, k_scale_val); + + if constexpr (LocalMask) { + // Sliding windows + // mask the elements of each tile where j - left > i || j + right < i + const int item_id = thread_idx % SubgroupSize; + int col_idx; + if (split < kv_splits_cache) { + col_idx = item_id + split * cute::min(QK_BLK_N, seq_len_kv_cache) ; + } else { + col_idx = item_id + seq_len_kv_cache + (split - kv_splits_cache) * cute::min(QK_BLK_N, seq_len_kv); + } + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < FragsN; + n++, col_idx += get<1>(MmaAtomShape())) { // 4 + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < FragsM; m++) { // 2 + int row_idx = m * Vec + seq_coord; + int col_ref = seq_len_kv_cache + seq_len_kv - seq_len_qo; + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < Vec; row++) { // 8 + bool left_mask = col_idx < cute::max(0, row + row_idx + col_ref - mainloop_params.window_left); + bool right_mask = col_idx > cute::min(seq_len_kv_cache + seq_len_kv, row + row_idx + col_ref + mainloop_params.window_right); + if (left_mask || right_mask) { + tSr(row, m, n) = ElementAccumulator{-INFINITY}; + } + } + } + } + } + + if constexpr(!(CausalMask || LocalMask) && PagedKV) { + // Processing Not divisible, mask padding + const int item_id = thread_idx % SubgroupSize; + int col_idx = item_id + split * cute::min(QK_BLK_N, seq_len_kv_cache + seq_len_kv); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < FragsN; n++, col_idx += get<1>(MmaAtomShape())) { // 4 + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < FragsM; m++) { // 2 + int row_idx = m * Vec + seq_coord; + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < Vec; row++) { // 8 + if (col_idx >= seq_len_kv_cache + seq_len_kv || row_idx + row >= seq_len_qo) { + tSr(row, m, n) = ElementAccumulator{-INFINITY}; + } + } + } + } + } + auto &tiled_prefetch_v_ = + is_KV_cache ? tiled_prefetch_v_cache + : tiled_prefetch_v; + auto &pVgV_ = is_KV_cache ? pVgV_cache : pVgV; + int v_prefetch_idx = is_KV_cache ? PagedKV ? cached_nblock : split + : split - kv_splits_cache; + for (int i = 0; i < size<1>(pVgV_); i++) { + prefetch(tiled_prefetch_v_, pVgV_(_, i, _, v_prefetch_idx)); + } + int next_cached_nblock = split + 1; + bool is_next_KV_cache = next_cached_nblock < kv_splits_cache; + if constexpr (PagedKV) { + if (is_next_KV_cache) { + int curr_batch_pages = ceil_div(seq_len_kv_cache, mainloop_params.page_size); + int next_page_logical_idx = + next_cached_nblock * QK_BLK_N / params.mainloop.page_size; + int batch_offset = + is_var_len ? mainloop_params.num_pages_per_seq[batch_coord] + : batch_coord * curr_batch_pages; + bool valid_page = next_page_logical_idx < curr_batch_pages; + // get physical page idx from page table + if (valid_page) { + next_cached_nblock = + params.mainloop.ptr_page_table + [batch_offset + // page table for this batch + next_page_logical_idx // split (tile idx) to logical + // page idx + ] * tiles_per_page + // base block idx of physical page + next_cached_nblock % tiles_per_page; // offset within page + } else { + next_cached_nblock = + curr_batch_pages * + tiles_per_page; // push idx out of bounds to respect the + // boundary between batches + } + } + } + + // 4) Fused softmax + CollectiveSoftmaxEpilogue softmax(params.softmax); + softmax(split == 0, tSr, max_reg, sum_reg, out_reg); + + // 5) Perform GEMM O = S*V + collective_mma.template mmaPV(out_reg, tSr, gV_, out_reg, + mainloop_params, is_KV_cache, v_scale_val); + // ... prefetch next tile ... + // Prefetch the next Q tile + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<3>(pQgQ); i++) { + prefetch(tiled_prefetch_q, pQgQ(_, _, _, i)); + } + + is_KV_cache = is_next_KV_cache; + cached_nblock = next_cached_nblock; + // Prefetch the next K tile + // there is no need to gaurd it with if statememt as prefetch will + // ignore out of bound reading + if constexpr (PagedKV) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<4>(pKgK_cache); j++) { + prefetch(tiled_prefetch_k_cache, pKgK_cache(_, _, _, cached_nblock, j)); + } + } else { + bool sel_prefetch_k = + (split + DispatchPolicy::Stages) < kv_splits_cache; + auto &prefetch_k_selector = + sel_prefetch_k ? tiled_prefetch_k_cache : tiled_prefetch_k; + auto &pKgK_ = sel_prefetch_k ? pKgK_cache : pKgK; + int k_prefetch_idx = + sel_prefetch_k + ? PagedKV ? cached_nblock : split + DispatchPolicy::Stages + : split + DispatchPolicy::Stages - kv_splits_cache; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<4>(pKgK_); j++) { + prefetch(prefetch_k_selector, pKgK_(_, _, _, k_prefetch_idx, j)); + } + } + barrier_wait(barrier_scope); + } + + if constexpr (CausalMask) { + // BAND Matrix + // 1) Load K (performed inside mmaQK) + // 2) Create Tensor S + Tensor tSr = make_tensor( + Shape, Int, Int>{}); + clear(tSr); + // 3) Perform GEMM S = Q*K + collective_mma.mmaQK(tSr, gQ, gK(_, _, kv_splits_new - 1, _), tSr, + ceil_div(head_size_qk, QK_BLK_K), mainloop_params, + false, q_scale_val, k_scale_val); + // we only need one block ahead, there is enough gap to prefetch it + // while doing softmax. because the gap between the two MMA is big, + // prefetching it the same way as cutlass K matrix does not make sense + for (int i = 0; i < size<1>(pVgV); i++) { + prefetch(tiled_prefetch_v, pVgV(_, i, _, kv_splits_new - 1)); + } + // mask the elements of each tile where j > i + const int item_id = thread_idx % SubgroupSize; + int col_idx = item_id + (kv_splits_new - 1) * QK_BLK_N; + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < FragsN; + n++, col_idx += get<1>(MmaAtomShape())) { // 4 + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < FragsM; m++) { // 2 + int row_idx = m * Vec + seq_coord; + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < Vec; row++, row_idx++) { // 8 + if (col_idx - full_tile_offset > row_idx - discard_seq_coord) { + tSr(row, m, n) = ElementAccumulator{-INFINITY}; + } + } + } + } + + CollectiveSoftmaxEpilogue softmax(params.softmax); + softmax((kv_splits - 1) == 0, tSr, max_reg, sum_reg, out_reg); + collective_mma.template mmaPV(out_reg, tSr, + gV(_, _, kv_splits_new - 1), + out_reg, mainloop_params, false, v_scale_val); + } + + + // Epilogue + auto epilogue_params = + CollectiveEpilogue::template get_updated_copies( + params.epilogue, params.problem_shape, sequence_length_shape, + batch_coord, q_head_coord); + CollectiveEpilogue epilogue{epilogue_params, shared_storage.epilogue}; + auto blk_coord_mnkl = make_coord(blk_m_coord, blk_n_coord, _, 0); + epilogue(params.problem_shape, sequence_length_shape, blk_coord_mnkl, + out_reg, max_reg, sum_reg); + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::flash_attention::kernel diff --git a/benchmarks/flash_attention/flash_attention_decode/benchmark_runner.hpp b/benchmarks/flash_attention/flash_attention_decode/benchmark_runner.hpp index 2784dbb859..37cf8566ea 100644 --- a/benchmarks/flash_attention/flash_attention_decode/benchmark_runner.hpp +++ b/benchmarks/flash_attention/flash_attention_decode/benchmark_runner.hpp @@ -191,9 +191,9 @@ template struct BenchmarkRunnerFMHADecode { int max_seq_len_q = static_cast(get<3>(problem_size)); int max_seq_len_kv = static_cast(get<4>(problem_size)); int max_seq_len_kv_cache = static_cast(get<5>(problem_size)); - get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, cumulative_seqlen_q.data()}; - get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, cumulative_seqlen_kv.data()}; - get<5>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv_cache, cumulative_seqlen_kv_cache.data()}; + get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, 0, cumulative_seqlen_q.data()}; + get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, 0, cumulative_seqlen_kv.data()}; + get<5>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv_cache, 0, cumulative_seqlen_kv_cache.data()}; } auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = cute::select<0,1,2,6,7>(problem_size); diff --git a/benchmarks/flash_attention/flash_attention_prefill/benchmark_runner.hpp b/benchmarks/flash_attention/flash_attention_prefill/benchmark_runner.hpp index d7d60d71fd..66b19bc72d 100644 --- a/benchmarks/flash_attention/flash_attention_prefill/benchmark_runner.hpp +++ b/benchmarks/flash_attention/flash_attention_prefill/benchmark_runner.hpp @@ -165,8 +165,8 @@ template struct BenchmarkRunnerFMHA { if constexpr (isVarLen) { int max_seq_len_q = static_cast(get<3>(problem_size)); int max_seq_len_kv = static_cast(get<4>(problem_size)); - get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, cumulative_seqlen_q.data()}; - get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, cumulative_seqlen_kv.data()}; + get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, 0, cumulative_seqlen_q.data()}; + get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, 0, cumulative_seqlen_kv.data()}; } auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = cute::select<0,1,2,5,6>(problem_size); diff --git a/benchmarks/flash_attention/flash_attention_prefill_cachedKV/benchmark_runner.hpp b/benchmarks/flash_attention/flash_attention_prefill_cachedKV/benchmark_runner.hpp index 801630bd49..75b05ae253 100644 --- a/benchmarks/flash_attention/flash_attention_prefill_cachedKV/benchmark_runner.hpp +++ b/benchmarks/flash_attention/flash_attention_prefill_cachedKV/benchmark_runner.hpp @@ -176,9 +176,9 @@ template struct BenchmarkRunnerFMHA { int max_seq_len_q = static_cast(get<3>(problem_size)); int max_seq_len_kv = static_cast(get<4>(problem_size)); int max_seq_len_kv_cache = static_cast(get<5>(problem_size)); - get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, cumulative_seqlen_q.data()}; - get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, cumulative_seqlen_kv.data()}; - get<5>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv_cache, cumulative_seqlen_kv_cache.data()}; + get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, 0, cumulative_seqlen_q.data()}; + get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, 0, cumulative_seqlen_kv.data()}; + get<5>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv_cache, 0, cumulative_seqlen_kv_cache.data()}; } auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = cute::select<0,1,2,6,7>(problem_size); diff --git a/examples/06_bmg_flash_attention/06_bmg_chunk_prefill.cpp b/examples/06_bmg_flash_attention/06_bmg_chunk_prefill.cpp new file mode 100644 index 0000000000..e3e3c2d9cb --- /dev/null +++ b/examples/06_bmg_flash_attention/06_bmg_chunk_prefill.cpp @@ -0,0 +1,116 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Flash Attention V2 Prefill for Intel BMG + + This example constructs and executes a Flash Attention Prefill with KV cache on Intel BMG. The + definition of the GEMM, options etc for this example are defined in the associated + bmg_flash_attn_cachedKV_runner.hpp header file. + + See https://arxiv.org/pdf/2307.08691 for details of Flash Attention V2 algorithm + + To run this example: + $ ./examples/sycl/06_bmg_flash_attention_cachedKV/06_bmg_prefill_attention_cachedKV --seq_len_qo=512 + --seq_len_kv=512 --seq_len_kv_cache=512 --head_size_vo=128 --head_size_qk=128 + + Causal masking of the first matrix multiplication is supported (`--is_causal`) + + To build & run this example (from your build dir): + + $ ninja 06_bmg_prefill_attention_cachedKV + $ ./examples/sycl/06_bmg_flash_attention_cachedKV/06_bmg_prefill_attention_cachedKV + + Call with `--help` for information about available options +*/ + +#include "bmg_flash_chunk_prefill_runner.hpp" + +int main(int argc, const char **argv) { + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // Define the work-group tile shape depending on the head-size of the second matmul + // Shape<_SequenceLenthOutputBLOCK, _HeadSizeout(NV), SequenceLengthKVBLOCK_KN/KV, HeadSizeQKBLOCK_KQK, HEADSIZEOutSlicerBlock> + // +#if !defined(HEAD_DIM) + std::cerr << "HEAD_DIM must be defined" << std::endl; + return -1; +#endif + if (options.head_size_vo != HEAD_DIM) { + std::cerr << "head_size_vo must be " << HEAD_DIM << ", but got " << options.head_size_vo << std::endl; + return -1; + } + + constexpr int PipelineStages = 2; +#if HEAD_DIM == 64 + using ShapeQK = Shape<_128, _64, _64>; + using ShapePV = Shape<_128, _32, _64>; + using ShapeOutPut = Shape<_128, _64, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; +#elif HEAD_DIM == 96 + using ShapeQK = Shape<_128, _64, _32>; + using ShapePV = Shape<_128, _32, _64>; + using ShapeOutPut = Shape<_128, _96, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; +#elif HEAD_DIM == 128 + using ShapeQK = Shape<_128, _64, _64>; + using ShapePV = Shape<_128, _32, _64>; + using ShapeOutPut = Shape<_128, _128, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; +#elif HEAD_DIM == 192 + using ShapeQK = Shape<_256, _64, _64>; + using ShapePV = Shape<_256, _32, _64>; + using ShapeOutPut = Shape<_256, _192, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; +#endif + if (options.is_causal) { + FMHAConfig::run(options); + } else if (options.is_local_mask) { + FMHAConfig::run(options); + } else { + FMHAConfig::run(options); + } +} diff --git a/examples/06_bmg_flash_attention/06_bmg_chunk_prefill_fp8.cpp b/examples/06_bmg_flash_attention/06_bmg_chunk_prefill_fp8.cpp new file mode 100644 index 0000000000..c3336d633e --- /dev/null +++ b/examples/06_bmg_flash_attention/06_bmg_chunk_prefill_fp8.cpp @@ -0,0 +1,147 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief fp8 Chunk Prefill for Intel BMG + + This example constructs and executes a FP8 Flash Attention Chunk Prefill on Intel BMG. The + definition of the GEMM, options etc for this example are defined in the associated + bmg_flash_chunk_prefill_runner.hpp header file. + + See https://arxiv.org/pdf/2307.08691 for details of Flash Attention V2 algorithm + + To build & run this example (from your build dir): + + $ ninja 06_bmg_chunk_prefill_fp8_hdim128 + $ ./examples/06_bmg_flash_attention/06_bmg_chunk_prefill_fp8_hdim128 + + Call with `--help` for information about available options +*/ + +#include "bmg_flash_chunk_prefill_runner.hpp" + +int main(int argc, const char **argv) { + // + // Parse options + // + + Options options; + + // Pre-parse command line to get batch and head counts for sizing the default vectors. + cutlass::CommandLine cmd(argc, argv); + int batch = 32; + int num_heads_q = 16; + int num_heads_kv = 0; + cmd.get_cmd_line_argument("batch", batch, 32); + cmd.get_cmd_line_argument("num_heads_q", num_heads_q, 16); + cmd.get_cmd_line_argument("num_heads_kv", num_heads_kv, num_heads_q); + + // Set default scale values for this specific FP8 test using the correct 2D vector shapes. + options.q_scale.assign(batch, std::vector(num_heads_q, 1.5f)); + options.k_scale.assign(batch, std::vector(num_heads_kv, 2.5f)); + options.v_scale.assign(batch, std::vector(num_heads_kv, 1.9f)); + + // Now, parse all arguments. The defaults we just set will be preserved unless + // the user has provided their own scale values via the command line. + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + +#if !defined(HEAD_DIM) + std::cerr << "HEAD_DIM must be defined" << std::endl; + return -1; +#endif + if (options.head_size_vo != HEAD_DIM) { + std::cerr << "head_size_vo must be " << HEAD_DIM << ", but got " << options.head_size_vo << std::endl; + return -1; + } + + // ================================================================================================= + // FP8 Type Definitions + // ================================================================================================= + using ElementInputQ = cutlass::float_e5m2_t; // <- data type of elements in input matrix A + using ElementInputKV = cutlass::float_e5m2_t; // <- data type of elements in input matrix B + using MMAOperation = XE_8x16x16_F32BF16BF16F32_TT; + using GmemTiledCopyQ = XE_2D_U8x8x32_LD_N; + using GmemTiledCopyK = XE_2D_U8x16x16_LD_T; // _T designates a transposed block load operation + using GmemTiledCopyV = XE_2D_U8x32x32_LD_V; + + constexpr int PipelineStages = 2; + + // ================================================================================================= + // Tile Shape Definitions + // ================================================================================================= +#if HEAD_DIM == 64 + using ShapeQK = Shape<_128, _64, _64>; + using ShapePV = Shape<_128, _32, _64>; + using ShapeOutPut = Shape<_128, _64, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; +#elif HEAD_DIM == 96 + using ShapeQK = Shape<_128, _64, _32>; + using ShapePV = Shape<_128, _32, _64>; + using ShapeOutPut = Shape<_128, _96, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; +#elif HEAD_DIM == 128 + using ShapeQK = Shape<_128, _64, _64>; + using ShapePV = Shape<_128, _32, _64>; + using ShapeOutPut = Shape<_128, _128, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; +#elif HEAD_DIM == 192 + using ShapeQK = Shape<_256, _64, _64>; + using ShapePV = Shape<_256, _32, _64>; + using ShapeOutPut = Shape<_256, _192, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; +#endif + + // ================================================================================================= + // Kernel Launch + // ================================================================================================= + if (options.is_causal) { + FMHAConfig::run(options); + } else if (options.is_local_mask) { + FMHAConfig::run(options); + } else { + FMHAConfig::run(options); + } +} diff --git a/examples/06_bmg_flash_attention/CMakeLists.txt b/examples/06_bmg_flash_attention/CMakeLists.txt index 39752da4ed..e73aa4131f 100644 --- a/examples/06_bmg_flash_attention/CMakeLists.txt +++ b/examples/06_bmg_flash_attention/CMakeLists.txt @@ -1,4 +1,5 @@ # Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. +# Copyright (C) 2025 Intel Corporation, All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -63,6 +64,19 @@ foreach(HEAD_DIM 64 96 128 192) cutlass_example_add_executable( 06_bmg_decode_attention_fp8_hdim${HEAD_DIM} 06_bmg_decode_attention_fp8.cpp + ) + + cutlass_example_add_executable( + 06_bmg_chunk_prefill_hdim${HEAD_DIM} + 06_bmg_chunk_prefill.cpp + TEST_COMMAND_OPTIONS + TEST_NO_PAGED + TEST_PAGED + ) + + cutlass_example_add_executable( + 06_bmg_chunk_prefill_fp8_hdim${HEAD_DIM} + 06_bmg_chunk_prefill_fp8.cpp TEST_COMMAND_OPTIONS TEST_NO_PAGED TEST_PAGED @@ -72,4 +86,6 @@ foreach(HEAD_DIM 64 96 128 192) target_compile_definitions(06_bmg_decode_attention_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM}) target_compile_definitions(06_bmg_prefill_attention_fp8_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM}) target_compile_definitions(06_bmg_decode_attention_fp8_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM}) + target_compile_definitions(06_bmg_chunk_prefill_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM}) + target_compile_definitions(06_bmg_chunk_prefill_fp8_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM}) endforeach() diff --git a/examples/06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp b/examples/06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp index 6bb7ad487f..74cda67a5f 100644 --- a/examples/06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp +++ b/examples/06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp @@ -222,9 +222,9 @@ template struct ExampleRunner { int max_seq_len_q = static_cast(get<3>(problem_size)); int max_seq_len_kv = static_cast(get<4>(problem_size)); int max_seq_len_kv_cache = static_cast(get<5>(problem_size)); - get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, cumulative_seqlen_q.data()}; - get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, cumulative_seqlen_kv.data()}; - get<5>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv_cache, cumulative_seqlen_kv_cache.data()}; + get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, 0, cumulative_seqlen_q.data()}; + get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, 0, cumulative_seqlen_kv.data()}; + get<5>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv_cache, 0, cumulative_seqlen_kv_cache.data()}; } auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = cute::select<0,1,2,6,7>(problem_size); diff --git a/examples/06_bmg_flash_attention/bmg_flash_attn_prefill_cachedKV_runner.hpp b/examples/06_bmg_flash_attention/bmg_flash_attn_prefill_cachedKV_runner.hpp index f87cc8af2f..5adb778722 100644 --- a/examples/06_bmg_flash_attention/bmg_flash_attn_prefill_cachedKV_runner.hpp +++ b/examples/06_bmg_flash_attention/bmg_flash_attn_prefill_cachedKV_runner.hpp @@ -214,9 +214,9 @@ template struct ExampleRunner { int max_seq_len_q = static_cast(get<3>(problem_size)); int max_seq_len_kv = static_cast(get<4>(problem_size)); int max_seq_len_kv_cache = static_cast(get<5>(problem_size)); - get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, cumulative_seqlen_q.data()}; - get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, cumulative_seqlen_kv.data()}; - get<5>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv_cache, cumulative_seqlen_kv_cache.data()}; + get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, 0, cumulative_seqlen_q.data()}; + get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, 0, cumulative_seqlen_kv.data()}; + get<5>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv_cache, 0, cumulative_seqlen_kv_cache.data()}; } auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = cute::select<0,1,2,6,7>(problem_size); diff --git a/examples/06_bmg_flash_attention/bmg_flash_attn_prefill_runner.hpp b/examples/06_bmg_flash_attention/bmg_flash_attn_prefill_runner.hpp index 264eae22e6..5e0086976e 100644 --- a/examples/06_bmg_flash_attention/bmg_flash_attn_prefill_runner.hpp +++ b/examples/06_bmg_flash_attention/bmg_flash_attn_prefill_runner.hpp @@ -193,8 +193,8 @@ template struct ExampleRunner { if constexpr (isVarLen) { int max_seq_len_q = static_cast(get<3>(problem_size)); int max_seq_len_kv = static_cast(get<4>(problem_size)); - get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, cumulative_seqlen_q.data()}; - get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, cumulative_seqlen_kv.data()}; + get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, 0, cumulative_seqlen_q.data()}; + get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, 0, cumulative_seqlen_kv.data()}; } auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = cute::select<0,1,2,5,6>(problem_size); diff --git a/examples/06_bmg_flash_attention/bmg_flash_chunk_prefill_runner.hpp b/examples/06_bmg_flash_attention/bmg_flash_chunk_prefill_runner.hpp new file mode 100644 index 0000000000..b1ecb8693b --- /dev/null +++ b/examples/06_bmg_flash_attention/bmg_flash_chunk_prefill_runner.hpp @@ -0,0 +1,1019 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "flash_attention_v2/collective/fmha_fusion.hpp" +#include "flash_attention_v2/kernel/tile_scheduler_chunk_prefill.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "flash_attention_v2/kernel/xe_chunk_prefill.hpp" +#include "flash_attention_v2/collective/xe_flash_attn_chunk_prefill_epilogue.hpp" +#include "flash_attention_v2/collective/xe_flash_attn_chunk_prefill_softmax_epilogue.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/sycl_event_manager.hpp" +#include "cutlass/fp8_to_fp16.h" + +#include +#include + +#include "helper.h" +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "sycl_common.hpp" + +using namespace cute; + +// Helper to check for FP8 types +template +constexpr bool is_fp8_v = std::is_same_v || std::is_same_v; + +// Command line options parsing +struct Options { + + bool help; + bool error; + bool is_causal; + bool is_local_mask; + bool varlen = false; + bool use_paged_kv = false; + std::string scheduler; + + int batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, page_size, head_size_qk, head_size_vo, iterations, window_left, window_right; + float softmax_scale; + + // Add scale factors to options, now per-batch, per-head + std::vector> q_scale; + std::vector> k_scale; + std::vector> v_scale; + + Options() + : help(false), error(false), is_causal(false), is_local_mask(false), varlen(false), use_paged_kv(false), batch(32), num_heads_q(16), num_heads_kv(16), seq_len_qo(512), head_size_qk(128), + seq_len_kv(512), seq_len_kv_cache(512), page_size(128), head_size_vo(128), iterations(100), window_left(-1), window_right(-1), softmax_scale(1.f), + scheduler("Individual") {} + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + if (cmd.check_cmd_line_flag("is_causal")) { + is_causal = true; + } + + if (cmd.check_cmd_line_flag("varlen")) { + varlen = true; + } + + cmd.get_cmd_line_argument("scheduler", scheduler, std::string("Individual")); + + cmd.get_cmd_line_argument("batch", batch, 32); + cmd.get_cmd_line_argument("num_heads_q", num_heads_q, 16); + cmd.get_cmd_line_argument("num_heads_kv", num_heads_kv, num_heads_q); + cmd.get_cmd_line_argument("seq_len_qo", seq_len_qo, 512); + cmd.get_cmd_line_argument("seq_len_kv", seq_len_kv, seq_len_qo); + cmd.get_cmd_line_argument("seq_len_kv_cache", seq_len_kv_cache, 512); + cmd.get_cmd_line_argument("head_size_vo", head_size_vo, HEAD_DIM); + cmd.get_cmd_line_argument("head_size_qk", head_size_qk, head_size_vo); + cmd.get_cmd_line_argument("window_left", window_left, -1); + cmd.get_cmd_line_argument("window_right", window_right, -1); + cmd.get_cmd_line_argument("iterations", iterations, 100); + + // Add command line parsing for scale factors + float q_scale_val = 1.0f, k_scale_val = 1.0f, v_scale_val = 1.0f; + bool q_scale_provided = cmd.check_cmd_line_flag("q_scale"); + bool k_scale_provided = cmd.check_cmd_line_flag("k_scale"); + bool v_scale_provided = cmd.check_cmd_line_flag("v_scale"); + + cmd.get_cmd_line_argument("q_scale", q_scale_val, 1.0f); + cmd.get_cmd_line_argument("k_scale", k_scale_val, 1.0f); + cmd.get_cmd_line_argument("v_scale", v_scale_val, 1.0f); + + // If scale vectors are uninitialized or have the wrong size, resize and assign. + // This allows pre-setting them with specific defaults before calling parse. + if (q_scale_provided || q_scale.size() != batch || (batch > 0 && q_scale[0].size() != num_heads_q)) { + q_scale.assign(batch, std::vector(num_heads_q, q_scale_val)); + } + if (k_scale_provided || k_scale.size() != batch || (batch > 0 && k_scale[0].size() != num_heads_kv)) { + k_scale.assign(batch, std::vector(num_heads_kv, k_scale_val)); + } + if (v_scale_provided || v_scale.size() != batch || (batch > 0 && v_scale[0].size() != num_heads_kv)) { + v_scale.assign(batch, std::vector(num_heads_kv, v_scale_val)); + } + + + if (cmd.check_cmd_line_flag("use_paged_kv")) { + use_paged_kv = true; + cmd.get_cmd_line_argument("page_size", page_size, 128); + seq_len_kv = 0; // seq_len_kv is not used when use paged kv + if (page_size % 128 != 0) { + std::cerr << "Invalid: page_size must be a multiple of 128" << std::endl; + return; + } + if (seq_len_kv_cache % page_size != 0) { + std::cerr << "Invalid: seq_len_kv_cache must be divisible by page_size" << std::endl; + return; + } + } + if (window_left > -1 && window_right > -1) { + is_local_mask = true; + } + softmax_scale = 1 / sqrt(static_cast(head_size_qk)); + } + + /// Prints the usage statement. + std::ostream &print_usage(std::ostream &out) const { + + out << "BMG Flash Attention v2 Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --is_causal Apply Causal Mask to the output of first Matmul\n" + << " --window_left= Set the left borders of the window, If set to -1, calculate all seq_len\n" + << " --window_right= Set the left borders of the window, If set to -1, calculate all seq_len\n" + << " --varlen Enable variable sequence length\n" + << " --scheduler=\"Value\" Choose between Individual or Persistent Scheduler\n" + << " --batch= Sets the Batch Size of the Multi-Head Self Attention module\n" + << " --num_heads_q= Sets the Number of Attention Heads for Key-Value pair the Multi-Head Self Attention module\n" + << " --num_heads_kv= Sets the Number of Attention Heads for Query input in the Multi-Head Self Attention module\n" + << " --seq_len_qo= Sets the Sequence length of the Query input in Multi-Head Self Attention module\n" + << " --seq_len_kv= Sets the Sequence length of the Key-Value pair in Multi-Head Self Attention module\n" + << " --seq_len_kv_cache= Sets the Sequence length of the cached Key-Value pair in Multi-Head Self Attention module\n" + << " --use_paged_kv Use paged (non-contiguous) KV cache. Default is contiguous KV Cache\n" + << " --page_size= Block size for paged KV cache. Default is 128\n" + << " --head_size_qk= Sets the Attention Head dimension of the 1st Matrix Multiplication in Multi-Head Self Attention module\n" + << " --head_size_vo= Sets the Attention Head dimension of the 2nd Matrix Multiplication in Multi-Head Self Attention module\n" + << " --iterations= Iterations\n\n" + << " --q_scale= FP8 quantization scale for Q\n" + << " --k_scale= FP8 quantization scale for K\n" + << " --v_scale= FP8 quantization scale for V\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Flash Attention takes 3 input matrices: (K)eys, (Q)ueries and (V)alues. +using LayoutQ = cutlass::layout::RowMajor; +using LayoutK = cutlass::layout::ColumnMajor; +using LayoutV = cutlass::layout::RowMajor; +using LayoutO = cutlass::layout::RowMajor; + +template struct ExampleRunner { + + using StrideQ = typename FMHAChunkPrefillKernel::StrideQ; + using StrideK = typename FMHAChunkPrefillKernel::StrideK; + using StrideV = typename FMHAChunkPrefillKernel::StrideV; + using StrideO = typename FMHAChunkPrefillKernel::StrideO; + + using ElementQ = typename FMHAChunkPrefillKernel::ElementQ; + using ElementK = typename FMHAChunkPrefillKernel::ElementK; + using ElementV = typename FMHAChunkPrefillKernel::ElementV; + using ElementAcc = typename FMHAChunkPrefillKernel::ElementAccumulator; + + using CollectiveEpilogue = typename FMHAChunkPrefillKernel::CollectiveEpilogue; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename FMHAChunkPrefillKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideQ stride_Q; + StrideK stride_K; + StrideV stride_V; + StrideK stride_K_cache; + StrideV stride_V_cache; + StrideO stride_O; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_Q; + cutlass::DeviceAllocation block_K; + cutlass::DeviceAllocation block_V; + cutlass::DeviceAllocation block_K_cache; + cutlass::DeviceAllocation block_V_cache; + cutlass::DeviceAllocation block_O; + cutlass::DeviceAllocation block_ref_O; + + // Add device allocations for scale factors + cutlass::DeviceAllocation block_q_scale; + cutlass::DeviceAllocation block_k_scale; + cutlass::DeviceAllocation block_v_scale; + + std::vector cumulative_seqlen_q; + std::vector cumulative_seqlen_kv; + std::vector cumulative_seqlen_kv_cache; + cutlass::DeviceAllocation device_cumulative_seqlen_q; + cutlass::DeviceAllocation device_cumulative_seqlen_kv; + cutlass::DeviceAllocation device_cumulative_seqlen_kv_cache; + + struct PagedKVParams { + cutlass::DeviceAllocation page_table; + int page_size = 0; + cutlass::DeviceAllocation num_pages_per_seq; + }; + PagedKVParams paged_kv_cache; + + // + // Methods + // + +template +void run_conversion_kernel(SrcType* src_ptr_in, DstType* dst_ptr_in, int64_t num_elements, + const float* scales, int batch, int64_t seq_len, int64_t num_heads, int64_t head_size) { + sycl::queue queue = compat::get_default_queue(); + int64_t num_threads = 256; + int64_t num_blocks = ceil_div(num_elements, num_threads); + + queue.submit([&](sycl::handler& cgh) { + // Correctly cast the source pointer to uint8_t* for the kernel + uint8_t* src_ptr = reinterpret_cast(src_ptr_in); + DstType* dst_ptr = dst_ptr_in; + cgh.parallel_for(sycl::nd_range<1>(num_blocks * num_threads, num_threads), [=](sycl::nd_item<1> item) { + int64_t idx = item.get_global_id(0); + if (idx < num_elements) { + // Use the appropriate scale for this element based on its head + // Assuming tensor layout is [seq_len, num_heads, head_size] + int64_t element_offset = idx; + int64_t head_size_offset = element_offset % head_size; + element_offset /= head_size; + int64_t head_offset = element_offset % num_heads; + int64_t seq_offset = element_offset / num_heads; + + // Get the correct scale for this head + float scale = scales[batch * num_heads + head_offset]; + + // Create tensors with the correct types for convert_and_descale + auto src_tensor = make_tensor(src_ptr + idx, make_shape(1)); + auto dst_tensor = make_tensor(dst_ptr + idx, make_shape(1)); + convert_and_descale(src_tensor, dst_tensor, scale); + } + }); + }); +} + +bool verify(ProblemShapeType problem_size, Options options) { + std::vector host_O(block_ref_O.size()); + + auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = cute::select<0,1,2,6,7>(problem_size); + int seq_len_qo, seq_len_kv, seq_len_kv_cache; + + int offset_q = 0; + int offset_k = 0; + int offset_v = 0; + int offset_k_cache = 0; + int offset_v_cache = 0; + int offset_o = 0; + + using namespace cutlass; + using RefElement = bfloat16_t; + DeviceAllocation block_Q_ref, block_K_ref, block_V_ref; + + // loop over the batch dimension to compute the output + // to avoid the risk of running out of device memory + int q_group_size = num_heads_q / num_heads_kv; + for (int b = 0; b < batch; b++) { + if constexpr (isVarLen) { + auto logical_problem_shape = cutlass::fmha::collective::apply_variable_length(problem_size, b); + seq_len_qo = get<3>(logical_problem_shape); + seq_len_kv = get<4>(logical_problem_shape); + seq_len_kv_cache = get<5>(logical_problem_shape); + } else { + seq_len_qo = get<3>(problem_size); + seq_len_kv = get<4>(problem_size); + seq_len_kv_cache = get<5>(problem_size); + } + + ElementQ* q_ptr_orig = block_Q.get() + offset_q; + ElementK* k_ptr_orig; + ElementV* v_ptr_orig; + + void* q_ptr = q_ptr_orig; + void* k_ptr; + void* v_ptr; + + int seq_len_kv_total = seq_len_kv_cache + seq_len_kv; + cutlass::DeviceAllocation block_K_concat; + cutlass::DeviceAllocation block_V_concat; + + if (seq_len_kv_cache > 0) { // use_kv_cache + if (options.use_paged_kv) { + int num_pages = paged_kv_cache.page_table.size(); + std::vector host_page_table(paged_kv_cache.page_table.size()); + std::vector host_num_pages_per_seq(paged_kv_cache.num_pages_per_seq.size()); + compat::memcpy(host_page_table.data(), paged_kv_cache.page_table.get(), paged_kv_cache.page_table.size()); + compat::memcpy(host_num_pages_per_seq.data(), paged_kv_cache.num_pages_per_seq.get(), paged_kv_cache.num_pages_per_seq.size()); + + int curr_batch_pages = isVarLen ? host_num_pages_per_seq[b + 1] - host_num_pages_per_seq[b] : ceil_div(seq_len_kv_cache, paged_kv_cache.page_size); + int batch_offset = isVarLen ? host_num_pages_per_seq[b] : b * curr_batch_pages; + block_K_concat.reset((seq_len_kv + curr_batch_pages * paged_kv_cache.page_size) * num_heads_kv * head_size_qk); + block_V_concat.reset((seq_len_kv + curr_batch_pages * paged_kv_cache.page_size) * num_heads_kv * head_size_vo); + + for (int p = 0; p < curr_batch_pages; p++) { + int page_idx = host_page_table[batch_offset + p]; + // copy the page from KV cache to the concatenated buffer + compat::memcpy( + block_K_concat.get() + p * paged_kv_cache.page_size * num_heads_kv * head_size_qk, + block_K_cache.get() + page_idx * paged_kv_cache.page_size * num_heads_kv * head_size_qk, + paged_kv_cache.page_size * num_heads_kv * head_size_qk + ); + compat::memcpy( + block_V_concat.get() + p * paged_kv_cache.page_size * num_heads_kv * head_size_vo, + block_V_cache.get() + page_idx * paged_kv_cache.page_size * num_heads_kv * head_size_vo, + paged_kv_cache.page_size * num_heads_kv * head_size_vo + ); + } + if (seq_len_kv > 0) { + compat::memcpy( + // block_K_concat.get() + curr_batch_pages * paged_kv_cache.page_sze * num_heads_kv *head_size_qk, + block_K_concat.get() + seq_len_kv_cache * num_heads_kv * head_size_qk, + block_K.get() + offset_k, + seq_len_kv * num_heads_kv * head_size_qk + ); + compat::memcpy( + block_V_concat.get() + seq_len_kv_cache * num_heads_kv * head_size_vo, + block_V.get() + offset_v, + seq_len_kv * num_heads_kv * head_size_vo + ); + } + compat::wait(); + } else { + block_K_concat.reset(seq_len_kv_total * num_heads_kv * head_size_qk); + block_V_concat.reset(seq_len_kv_total * num_heads_kv * head_size_vo); + // Concatenate K_cache and K + compat::memcpy( + block_K_concat.get(), + block_K_cache.get() + offset_k_cache, + seq_len_kv_cache * num_heads_kv * head_size_qk + ); + compat::memcpy( + block_K_concat.get() + seq_len_kv_cache * num_heads_kv * head_size_qk, + block_K.get() + offset_k, + seq_len_kv * num_heads_kv * head_size_qk + ); + // Concatenate V_cache and V + compat::memcpy( + block_V_concat.get(), + block_V_cache.get() + offset_v_cache, + seq_len_kv_cache * num_heads_kv * head_size_vo + ); + compat::memcpy( + block_V_concat.get() + seq_len_kv_cache * num_heads_kv * head_size_vo, + block_V.get() + offset_v, + seq_len_kv * num_heads_kv * head_size_vo + ); + // compat::wait(); + } + k_ptr_orig = block_K_concat.get(); + v_ptr_orig = block_V_concat.get(); + } else { + k_ptr_orig = block_K.get() + offset_k; + v_ptr_orig = block_V.get() + offset_v; + } + + k_ptr = k_ptr_orig; + v_ptr = v_ptr_orig; + + if constexpr (is_fp8_v) { + block_Q_ref.reset(seq_len_qo * num_heads_q * head_size_qk); + run_conversion_kernel( + q_ptr_orig, block_Q_ref.get(), block_Q_ref.size(), + block_q_scale.get(), b, seq_len_qo, num_heads_q, head_size_qk); + q_ptr = block_Q_ref.get(); + } + if constexpr (is_fp8_v) { + block_K_ref.reset(seq_len_kv_total * num_heads_kv * head_size_qk); + run_conversion_kernel( + k_ptr_orig, block_K_ref.get(), block_K_ref.size(), + block_k_scale.get(), b, seq_len_kv_total, num_heads_kv, head_size_qk); + k_ptr = block_K_ref.get(); + } + if constexpr (is_fp8_v) { + block_V_ref.reset(seq_len_kv_total * num_heads_kv * head_size_vo); + run_conversion_kernel( + v_ptr_orig, block_V_ref.get(), block_V_ref.size(), + block_v_scale.get(), b, seq_len_kv_total, num_heads_kv, head_size_vo); + v_ptr = block_V_ref.get(); + } + compat::wait(); + + for (int q_group = 0; q_group < num_heads_q / q_group_size; q_group++) { + for (int q_head = 0; q_head < q_group_size; q_head++) { + cutlass::DeviceAllocation block_S; + block_S.reset(seq_len_qo * seq_len_kv_total); + + int head_offset_q = (q_group * q_group_size + q_head) * head_size_qk; + int head_offset_k = q_group * head_size_qk; + int head_offset_v = q_group * head_size_vo; + + cutlass::TensorRef ref_Q_head(reinterpret_cast(q_ptr) + head_offset_q, LayoutQ(num_heads_q * head_size_qk)); + cutlass::TensorRef ref_K_head(reinterpret_cast(k_ptr) + head_offset_k, LayoutK(num_heads_kv * head_size_qk)); + cutlass::TensorRef ref_V_head(reinterpret_cast(v_ptr) + head_offset_v, LayoutV(num_heads_kv * head_size_vo)); + cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv_total})); + + cutlass::reference::device::GemmComplex( + {seq_len_qo, seq_len_kv_total, head_size_qk}, + ElementAccumulator{1}, + ref_Q_head, + cutlass::ComplexTransform::kNone, + ref_K_head, + cutlass::ComplexTransform::kNone, + ElementAccumulator{0}, + ref_S, + ref_S + ); + compat::wait(); + + std::vector host_S(block_S.size()); + compat::memcpy(host_S.data(), block_S.get(), host_S.size()); + + // delete this memory as it is no longer needed + block_S.reset(); + auto offset = cute::min(seq_len_qo, seq_len_kv); + auto discard_seq_coord = seq_len_qo - offset; + auto full_tile_offset = seq_len_kv - offset; + // apply mask to S + for (int row = 0; row < seq_len_qo; row++) { + for (int col = 0; col < seq_len_kv_total; col++) { + // causal mask + if (options.is_causal && (col - full_tile_offset > row + seq_len_kv_cache - discard_seq_coord)) { + host_S[col + row * seq_len_kv_total] = ElementAccumulator{-INFINITY}; + } + // sliding window mask + int col_ref = seq_len_kv_cache + seq_len_kv - seq_len_qo; + bool left_mask = col < cute::max(0, col_ref + row - options.window_left); + bool right_mask = col > cute::min(seq_len_kv_total, col_ref + row + options.window_right); + if (options.is_local_mask && (left_mask || right_mask)) { + host_S[col + row * seq_len_kv_total] = ElementAccumulator{-INFINITY}; + } + } + } + + // compute max element per row of S + std::vector max_vec(seq_len_qo, ElementAccumulator{-INFINITY}); + for (int row = 0; row < seq_len_qo; row++) { + int idx = row * seq_len_kv_total; + int max_idx = row; + max_vec[max_idx] = host_S[idx++]; + for (int col = 1; col < seq_len_kv_total; col++, idx++) { + if (max_vec[max_idx] < host_S[idx]) + max_vec[max_idx] = host_S[idx]; + } + } + // compute exp of S + for (int row = 0; row < seq_len_qo; row++) { + int idx = row * seq_len_kv_total; + int max_idx = row; + for (int col = 0; col < seq_len_kv_total; col++, idx++) { + host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / options.softmax_scale); + } + } + + // compute sum per row of S + std::vector sum_vec(seq_len_qo, ElementAccumulator{0}); + for (int row = 0; row < seq_len_qo; row++) { + int idx = row * seq_len_kv_total; + int sum_idx = row; + for (int col = 0; col < seq_len_kv_total; col++, idx++) { + sum_vec[sum_idx] += host_S[idx]; + } + + // scale each row with the sum to compute softmax + idx = row * seq_len_kv_total; + sum_idx = row; + int col_ref = seq_len_kv_cache + seq_len_kv - seq_len_qo; + for (int col = 0; col < seq_len_kv_total; col++, idx++) { + if (options.is_causal && row < discard_seq_coord) { + host_S[idx] = 0; + } else if (options.is_local_mask && (col < cute::max(0, col_ref + row - options.window_left) + || col > cute::min(seq_len_kv_total, col_ref + row + options.window_right))) { + host_S[idx] = 0; + } else { + host_S[idx] /= sum_vec[sum_idx]; + } + } + } + std::vector host_P(host_S.size()); + for (int p = 0; p < host_P.size(); p++) + host_P[p] = static_cast(host_S[p]); + + cutlass::DeviceAllocation block_P; + block_P.reset(host_P.size()); + + compat::memcpy(block_P.get(), host_P.data(), host_P.size()); + + cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len_qo, seq_len_kv_total})); + + cutlass::DeviceAllocation block_acc; + block_acc.reset(seq_len_qo * head_size_vo); + cutlass::TensorRef ref_acc(block_acc.get(), LayoutO::packed({seq_len_qo, head_size_vo})); + + cutlass::reference::device::GemmComplex( + {seq_len_qo, head_size_vo, seq_len_kv_total}, + ElementAccumulator{1}, + ref_P, + cutlass::ComplexTransform::kNone, + ref_V_head, + cutlass::ComplexTransform::kNone, + ElementAccumulator{0}, + ref_acc, + ref_acc + ); + + compat::wait(); + // delete this memory as it is no longer needed + block_P.reset(); + + std::vector vec_acc(block_acc.size()); + compat::memcpy(vec_acc.data(), block_acc.get(), vec_acc.size()); + + // delete this memory as it is no longer needed + block_acc.reset(); + for (int seq = 0; seq < seq_len_qo; seq++) { + for (int hvo = 0; hvo < head_size_vo; hvo++) { + int idx = offset_o + seq * num_heads_q * head_size_vo + (q_group * q_group_size + q_head) * head_size_vo + hvo; + host_O[idx] = static_cast(vec_acc[seq * head_size_vo + hvo]); + } + } + } // end of q_group loop + } // end of q_head loop + offset_q += seq_len_qo * num_heads_q * head_size_qk; + offset_k += seq_len_kv * num_heads_kv * head_size_qk; + offset_v += seq_len_kv * num_heads_kv * head_size_vo; + offset_k_cache += seq_len_kv_cache * num_heads_kv * head_size_qk; + offset_v_cache += seq_len_kv_cache * num_heads_kv * head_size_vo; + offset_o += seq_len_qo * num_heads_q * head_size_vo; + } // end of batch loop + + compat::wait(); + compat::memcpy(block_ref_O.get(), host_O.data(), host_O.size()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_O.get(), block_O.get(), + block_O.size(), ElementOutput{0.5}, ElementOutput{0.5}); + + return passed; + } + + template + auto initialize_varlen(const ProblemShape& problem_size) { + int num_batches = get<0>(problem_size); + int seq_len_kv_cache = get<5>(problem_size); + + // generate Q as --b times + // gaussian (--Q, --Q / 2) sampled positive + // track cumulative + std::mt19937 rng(0x202305151552ull); + std::normal_distribution dist_q(get<3>(problem_size), get<3>(problem_size) / 2); + std::normal_distribution dist_kv(get<4>(problem_size), get<4>(problem_size) / 2); + std::normal_distribution dist_kv_cache(get<5>(problem_size), get<5>(problem_size) / 2); + + // Use Cacheline Size to calculate alignment + constexpr int cacheline_bytes = 64; + constexpr int AlignmentQ = cacheline_bytes / sizeof(ElementQ); // Alignment of Q matrix in units of elements + constexpr int AlignmentKV = cacheline_bytes / sizeof(ElementK); // Alignment of Kand V matrix in units of elements + + auto generate_positive_int = [](auto& dist, auto& gen) { + int result = 0; + do { + result = static_cast(dist(gen)); + } while (result <= 0); + return result; + }; + + cumulative_seqlen_q = {0}; + cumulative_seqlen_kv = {0}; + cumulative_seqlen_kv_cache = {0}; + + int total_seqlen_q = 0; + int total_seqlen_kv = 0; + int total_seqlen_kv_cache = 0; + int max_seqlen_q = 0; + int max_seqlen_kv = 0; + int max_seqlen_kv_cache = 0; + + for (int i = 0; i < num_batches; i++) { + int seqlen_q = cutlass::round_up(generate_positive_int(dist_q, rng), AlignmentQ); + int seqlen_kv = cute::get<4>(problem_size) == 0 ? 0 : cutlass::round_up(generate_positive_int(dist_kv, rng), AlignmentKV); + int seqlen_kv_cache = cute::get<5>(problem_size) == 0 ? 0 : cutlass::round_up(generate_positive_int(dist_kv_cache, rng), AlignmentKV); + + total_seqlen_q += seqlen_q; + total_seqlen_kv += seqlen_kv; + total_seqlen_kv_cache += seqlen_kv_cache; + + max_seqlen_q = std::max(max_seqlen_q, seqlen_q); + max_seqlen_kv = std::max(max_seqlen_kv, seqlen_kv); + max_seqlen_kv_cache = std::max(max_seqlen_kv_cache, seqlen_kv_cache); + + cumulative_seqlen_q.push_back(cumulative_seqlen_q.back() + seqlen_q); + cumulative_seqlen_kv.push_back(cumulative_seqlen_kv.back() + seqlen_kv); + cumulative_seqlen_kv_cache.push_back(cumulative_seqlen_kv_cache.back() + seqlen_kv_cache); + } + + ProblemShape problem_size_for_init = problem_size; + get<0>(problem_size_for_init) = 1; + get<3>(problem_size_for_init) = total_seqlen_q; + get<4>(problem_size_for_init) = total_seqlen_kv; + get<5>(problem_size_for_init) = total_seqlen_kv_cache; + + ProblemShapeType problem_size_for_launch; + + get<3>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{max_seqlen_q, total_seqlen_q}; + get<4>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{max_seqlen_kv, total_seqlen_kv}; + get<5>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{max_seqlen_kv_cache, total_seqlen_kv_cache}; + get<6>(problem_size_for_launch) = get<6>(problem_size); + get<7>(problem_size_for_launch) = get<7>(problem_size); + get<0>(problem_size_for_launch) = get<0>(problem_size); + get<1>(problem_size_for_launch) = get<1>(problem_size); + get<2>(problem_size_for_launch) = get<2>(problem_size); + + + return cute::make_tuple(problem_size_for_init, problem_size_for_launch); + } + + /// Initialize operands to be used in the GEMM and reference GEMM + ProblemShapeType initialize(const Options &options) { + auto problem_shape_in = + cute::make_tuple(options.batch, options.num_heads_q, options.num_heads_kv, options.seq_len_qo, options.seq_len_kv, options.seq_len_kv_cache, options.head_size_qk, options.head_size_vo); + + ProblemShapeType problem_shape; + decltype(problem_shape_in) problem_size; + + if constexpr (isVarLen) { + auto [problem_shape_init, problem_shape_launch] = initialize_varlen(problem_shape_in); + problem_shape = problem_shape_launch; + problem_size = problem_shape_init; + } + else { + problem_size = problem_shape_in; + problem_shape = problem_shape_in; + } + + auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_size; + + stride_Q = cutlass::make_cute_packed_stride(StrideQ{}, cute::make_shape(seq_len_qo, num_heads_q * head_size_qk, batch)); + stride_K = cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len_kv, num_heads_kv * head_size_qk, batch)); + stride_V = cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(head_size_vo * num_heads_kv, seq_len_kv, batch)); + + stride_K_cache = cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len_kv_cache, num_heads_kv * head_size_qk, batch)); + stride_V_cache = cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(head_size_vo * num_heads_kv, seq_len_kv_cache, batch)); + stride_O = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo, num_heads_q * head_size_vo, batch)); + + block_Q.reset(batch * num_heads_q * seq_len_qo * head_size_qk); + block_K.reset(batch * num_heads_kv * seq_len_kv * head_size_qk); + block_V.reset(batch * num_heads_kv * seq_len_kv * head_size_vo); + if (!options.use_paged_kv) { + block_K_cache.reset(batch * num_heads_kv * seq_len_kv_cache * head_size_qk); + block_V_cache.reset(batch * num_heads_kv * seq_len_kv_cache * head_size_vo); + } + block_O.reset(batch * num_heads_q * seq_len_qo * head_size_vo); + block_ref_O.reset(batch * num_heads_q * seq_len_qo * head_size_vo); + + // Initialize scale tensors + if constexpr (is_fp8_v) { + // Flatten the 2D host vector to a 1D host vector for copying to the device + std::vector q_scale_host, k_scale_host, v_scale_host; + q_scale_host.reserve(options.q_scale.size() * (options.q_scale.empty() ? 0 : options.q_scale[0].size())); + for (const auto& batch_scales : options.q_scale) { + q_scale_host.insert(q_scale_host.end(), batch_scales.begin(), batch_scales.end()); + } + + k_scale_host.reserve(options.k_scale.size() * (options.k_scale.empty() ? 0 : options.k_scale[0].size())); + for (const auto& batch_scales : options.k_scale) { + k_scale_host.insert(k_scale_host.end(), batch_scales.begin(), batch_scales.end()); + } + + v_scale_host.reserve(options.v_scale.size() * (options.v_scale.empty() ? 0 : options.v_scale[0].size())); + for (const auto& batch_scales : options.v_scale) { + v_scale_host.insert(v_scale_host.end(), batch_scales.begin(), batch_scales.end()); + } + + block_q_scale.reset(q_scale_host.size()); + block_k_scale.reset(k_scale_host.size()); + block_v_scale.reset(v_scale_host.size()); + + block_q_scale.copy_from_host(q_scale_host.data()); + block_k_scale.copy_from_host(k_scale_host.data()); + block_v_scale.copy_from_host(v_scale_host.data()); + } + + if (options.use_paged_kv) { + paged_kv_cache.page_size = options.page_size; + std::vector num_pages_per_seq{0}; + int num_pages = 0; + for(int b = 0; b < cute::get<0>(problem_shape); b++) { + int seq_len_cache = isVarLen ? cumulative_seqlen_kv_cache[b + 1] - cumulative_seqlen_kv_cache[b] : seq_len_kv_cache; + int pages_per_seq = ceil_div(seq_len_cache, paged_kv_cache.page_size); + num_pages_per_seq.push_back(num_pages_per_seq.back() + pages_per_seq); + num_pages += pages_per_seq; + } + paged_kv_cache.page_table.reset(num_pages); + + + // initialize block table with random mapping for non-contiguous layout + std::vector page_mapping(num_pages); + for (int b = 0; b < cute::get<0>(problem_shape); ++b) { + std::vector physical_pages(num_pages_per_seq[b + 1] - num_pages_per_seq[b]); + std::iota(physical_pages.begin(), physical_pages.end(), 0); + // shuffle physical pages + std::shuffle(physical_pages.begin(), physical_pages.end(), std::mt19937{ std::random_device{}() }); + for (int blk = 0; blk < physical_pages.size(); ++blk) { + int logical_idx = num_pages_per_seq[b] + blk; + page_mapping[logical_idx] = physical_pages[blk]; + } + } + compat::memcpy(paged_kv_cache.page_table.get(), page_mapping.data(), page_mapping.size() * sizeof(int)); + + paged_kv_cache.num_pages_per_seq.reset(num_pages_per_seq.size()); + compat::memcpy(paged_kv_cache.num_pages_per_seq.get(), num_pages_per_seq.data(), num_pages_per_seq.size() * sizeof(int)); + + block_K_cache.reset(num_pages * paged_kv_cache.page_size * num_heads_kv * head_size_qk); + block_V_cache.reset(num_pages * paged_kv_cache.page_size * num_heads_kv * head_size_vo); + } + + initialize_block(block_Q, seed + 2023); + initialize_block(block_K, seed + 2022); + initialize_block(block_V, seed + 2021); + initialize_block(block_K_cache, seed + 2024); + initialize_block(block_V_cache, seed + 2025); + + if (!cumulative_seqlen_q.empty()) { + device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size()); + device_cumulative_seqlen_q.copy_from_host( + cumulative_seqlen_q.data(), cumulative_seqlen_q.size()); + } + + if (!cumulative_seqlen_kv.empty()) { + device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size()); + device_cumulative_seqlen_kv.copy_from_host( + cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size()); + } + + if (!cumulative_seqlen_kv_cache.empty()) { + device_cumulative_seqlen_kv_cache.reset(cumulative_seqlen_kv_cache.size()); + device_cumulative_seqlen_kv_cache.copy_from_host( + cumulative_seqlen_kv_cache.data(), cumulative_seqlen_kv_cache.size()); + } + + if constexpr (isVarLen) { + get<3>(problem_shape).max_length = get<3>(problem_shape).max_length; + get<3>(problem_shape).total_length = get<3>(problem_shape).total_length; + get<3>(problem_shape).cumulative_length = device_cumulative_seqlen_q.get(); + + get<5>(problem_shape).max_length = get<5>(problem_shape).max_length; + get<5>(problem_shape).total_length = get<5>(problem_shape).total_length; + get<5>(problem_shape).cumulative_length = device_cumulative_seqlen_kv_cache.get(); + + get<4>(problem_shape).max_length = get<4>(problem_shape).max_length; + get<4>(problem_shape).total_length = get<4>(problem_shape).total_length; + get<4>(problem_shape).cumulative_length = device_cumulative_seqlen_kv.get(); + + } + + return problem_shape; + } + + // Note that the GemmUniversalAdapter currently doesn't support flash attention, which is why this + // secondary `run` function is required to launch the kernel. + static void run(typename FMHAChunkPrefillKernel::Params params) { + dim3 const block = FMHAChunkPrefillKernel::get_block_shape(); + dim3 const grid = FMHAChunkPrefillKernel::get_grid_shape(params); + + // configure smem size and carveout + int smem_size = FMHAChunkPrefillKernel::SharedStorageSize; + + const auto sycl_block = compat::dim3(block.x, block.y, block.z); + const auto sycl_grid = compat::dim3(grid.x, grid.y, grid.z); + +// Launch parameters depend on whether SYCL compiler supports work-group scratch memory extension +#if !defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY) + using namespace compat::experimental; + auto event = launch>( + launch_policy{sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)}, + kernel_properties{sycl_exp::sub_group_size}}, + params); +#else + compat::experimental::launch_properties launch_props { + sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size), + }; + compat::experimental::kernel_properties kernel_props{ + sycl::ext::oneapi::experimental::sub_group_size + }; + compat::experimental::launch_policy policy{sycl_grid, sycl_block, launch_props, kernel_props}; + auto event = compat::experimental::launch>(policy, params); +#endif + + EventManager::getInstance().addEvent(event); + } + + cutlass::Status run( + const Options &options, + const cutlass::KernelHardwareInfo &hw_info + ) { + + ProblemShapeType problem_size = initialize(options); + + typename FMHAChunkPrefillKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_Q.get(), stride_Q, + block_K.get(), stride_K, + block_V.get(), stride_V, + is_fp8_v ? block_q_scale.get() : nullptr, + is_fp8_v ? block_k_scale.get() : nullptr, + is_fp8_v ? block_v_scale.get() : nullptr, + block_K_cache.get(), stride_K_cache, + block_V_cache.get(), stride_V_cache, + options.use_paged_kv ? paged_kv_cache.page_table.get() : nullptr, + options.use_paged_kv ? paged_kv_cache.page_size : 0, + options.use_paged_kv ? paged_kv_cache.num_pages_per_seq.get() : nullptr, + options.window_left, + options.window_right}, + {options.softmax_scale}, + {block_O.get(), stride_O}, + hw_info}; + + // Define device-global scratch memory + size_t workspace_size = FMHAChunkPrefillKernel::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + if (!FMHAChunkPrefillKernel::can_implement(arguments)) { + std::cout << "Invalid Problem Size: " << options.batch << 'x' << options.num_heads_q << 'x' << + options.seq_len_qo << 'x' << options.seq_len_kv << 'x' << options.head_size_qk << 'x' << options.head_size_vo + << (options.is_causal ? "xCausal" : "xNonCausal") << (options.is_local_mask ? "xLocalMask" : "xNonLocalMask") << std::endl; + return cutlass::Status::kErrorInvalidProblem; + } + + // Initialize the workspace + CUTLASS_CHECK(FMHAChunkPrefillKernel::initialize_workspace(arguments, workspace.get())); + + // Convert host-side arguments to device-side arguments to be passed to the kernel + auto params = FMHAChunkPrefillKernel::to_underlying_arguments(arguments, workspace.get()); + + // Run the Flash Attention implementation. + run(params); + + compat::wait(); + + // Verify that the result is correct + bool passed = verify(problem_size, options); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if (!passed) { + return cutlass::Status::kErrorInternal; + } + + if (options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + run(params); + } + compat::wait(); + + auto offset = cute::min(options.seq_len_qo, options.seq_len_kv); + auto discard_seq_coord = options.seq_len_qo - offset; + auto full_tile_offset = options.seq_len_kv - offset; + // offset + 1 is going to be ceil_div + auto effective_seq_len_kv = options.seq_len_kv_cache + (options.is_causal ? full_tile_offset + ((offset + 1) / 2.0) : + options.is_local_mask ? (options.window_left + options.window_right) + : options.seq_len_kv); + auto effective_seq_len_qo = options.is_causal ? options.seq_len_qo - discard_seq_coord : options.seq_len_qo; + double cute_time = timer.seconds() / options.iterations; + double flops_qk = 2.0 * options.batch * options.num_heads_q * effective_seq_len_qo * effective_seq_len_kv * options.head_size_qk; + double flops_pv = 2.0 * options.batch * options.num_heads_q * effective_seq_len_qo * options.head_size_vo * effective_seq_len_kv; + double tflops = ((flops_qk + flops_pv) * 1e-12) / cute_time; + double gbps_qk = options.batch * (sizeof(ElementQ) * options.num_heads_q * effective_seq_len_qo * options.head_size_qk + + sizeof(ElementK) * options.num_heads_kv * effective_seq_len_kv * options.head_size_qk); + double gbps_pv = sizeof(ElementV) * options.batch * options.num_heads_kv * effective_seq_len_kv * options.head_size_vo + + sizeof(ElementOutput) * options.batch * options.num_heads_q * effective_seq_len_qo * options.head_size_vo; + double gbps = ((gbps_qk + gbps_pv) * 1e-9) / (cute_time); + std::cout << "Batch: " << options.batch << "\tNumHeads_q: " << options.num_heads_q << "\tNumHeads_kv: " << options.num_heads_kv << "\tSeq Length QO: " << options.seq_len_qo + << "\tSeq Length KV: " << options.seq_len_kv << "\tSeq Length KV Cache: " << options.seq_len_kv_cache + << "\tHead Size QK: " << options.head_size_qk << "\tHead Size VO: " << options.head_size_vo + << "\tCausal Mask: " << (options.is_causal ? "true" : "false") << "\tVariable Sequence Length: " << (options.varlen ? "true" : "false") + << "\t Scheduler: " << options.scheduler << "\t Paged KV cache: " << (options.use_paged_kv ? "true" : "false"); + printf("\nPerformance: %4.3f GB/s, %4.3f TFlop/s, %6.4f ms\n\n", gbps, tflops, cute_time * 1000); + } + + return cutlass::Status::kSuccess; + } +}; + +// The default configuration for BF16 +template < + bool Causal, + bool LocalMask, + typename TileShapeQK, + typename TileShapePV, + typename TileShapeOutput, + typename SubgroupLayout, + int PipelineStages, + typename ElementInputQ = bfloat16_t, + typename ElementInputKV = bfloat16_t, + typename MMAOperation = XE_8x16x16_F32BF16BF16F32_TT, + typename GmemTiledCopyQ = XE_2D_U16x8x32_LD_N, + typename GmemTiledCopyK = XE_2D_U16x16x16_LD_T, + typename GmemTiledCopyV = XE_2D_U16x16x32_LD_V, + typename StrideO = cutlass::gemm::TagToStrideC_t, + typename ElementAccumulator = float, + typename ElementComputeEpilogue = float, + typename ElementOutput = bfloat16_t, + typename GmemTiledCopyStore = XE_2D_U16x8x16_ST_N +> +struct FMHAConfig { + + template + static int run(const Options &options) { + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + using CollectiveEpilogue = cutlass::flash_attention::collective::FlashChunkPrefillEpilogue< + EpilogueDispatchPolicy, MMAOperation, TileShapeOutput, SubgroupLayout, ElementComputeEpilogue, ElementOutput, + StrideO, ElementOutput, GmemTiledCopyStore>; + using CollectiveSoftmaxEpilogue = cutlass::flash_attention::collective::FlashChunkPrefillSoftmaxEpilogue; + + using ProblemShapeRegular = cute::tuple; + using namespace cutlass::fmha::collective; + using ProblemShapeVarlen = cute::tuple; + using ProblemShapeType = std::conditional_t; + + // Mainloop + using CollectiveMainloop = cutlass::flash_attention::collective::FlashChunkPrefillMma< + GEMMDispatchPolicy, ProblemShapeType, ElementInputQ, cutlass::gemm::TagToStrideA_t, ElementInputKV, + cutlass::gemm::TagToStrideB_t, ElementInputKV, cutlass::gemm::TagToStrideB_t, MMAOperation, TileShapeQK, TileShapePV, SubgroupLayout, + GmemTiledCopyQ, // Q + GmemTiledCopyK, // K + GmemTiledCopyV, // V, + Causal, + LocalMask, + PagedKV>; + + using FMHAChunkPrefillKernel = cutlass::flash_attention::kernel::FMHAPrefillChunk; + + ExampleRunner runner; + + CUTLASS_CHECK(runner.run(options, hw_info)); + // runner.run(options, hw_info); + return 0; + } + + static int run(const Options &options) { + + if (options.use_paged_kv && !options.varlen) { + return run(options); + } else if(!options.use_paged_kv && options.varlen) { + return run(options); + } else if(!options.use_paged_kv && !options.varlen) { + return run(options); + } else { + return run(options); + } + } +}; diff --git a/include/cutlass/fp8_to_fp16.h b/include/cutlass/fp8_to_fp16.h index b2dd1c564e..7720bf09d0 100644 --- a/include/cutlass/fp8_to_fp16.h +++ b/include/cutlass/fp8_to_fp16.h @@ -38,7 +38,106 @@ #include #include #include -#include +#include +#include + +// Helper device function for E4M3 -> BFLOAT16 bitwise conversion +CUTLASS_DEVICE inline uint16_t +fp8_e4m3_to_fp16_bitwise(uint8_t const& src) { + // E4M3 (1-4-3) constants + constexpr uint32_t e4m3_exp_bias = 7; + // BFLOAT16 (1-8-7) constants + constexpr uint32_t bf16_exp_bias = 127; + + // Unpack FP8 bits + uint16_t sign = static_cast(src & 0x80); + uint16_t exponent = static_cast(src & 0x78) >> 3; + uint16_t mantissa = static_cast(src & 0x07); + + // Reconstruct BFLOAT16 bits + uint16_t bf16_sign = sign << 8; + // Re-bias exponent and shift to BFLOAT16 position + uint16_t bf16_exponent = (exponent - e4m3_exp_bias + bf16_exp_bias) << 7; + // Shift mantissa to BFLOAT16 position + uint16_t bf16_mantissa = mantissa << 4; + + return bf16_sign | bf16_exponent | bf16_mantissa; +} + +// Helper device function for E5M2 -> BFLOAT16 bitwise conversion +CUTLASS_DEVICE inline uint16_t +fp8_e5m2_to_fp16_bitwise(uint8_t const& src) { + // E5M2 (1-5-2) constants + constexpr uint32_t e5m2_exp_bias = 15; + // BFLOAT16 (1-8-7) constants + constexpr uint32_t bf16_exp_bias = 127; + + // Unpack FP8 bits + uint16_t sign = static_cast(src & 0x80); + uint16_t exponent = static_cast(src & 0x7C) >> 2; + uint16_t mantissa = static_cast(src & 0x03); + + // Reconstruct BFLOAT16 bits + uint16_t bf16_sign = sign << 8; + // Re-bias exponent and shift to BFLOAT16 position + uint16_t bf16_exponent = (exponent - e5m2_exp_bias + bf16_exp_bias) << 7; + // Shift mantissa to BFLOAT16 position + uint16_t bf16_mantissa = mantissa << 5; + + return bf16_sign | bf16_exponent | bf16_mantissa; +} + + +template < + typename Encoding, + int VectorizeSize = 8, + typename SrcTensor, + typename DstTensor +> +CUTLASS_DEVICE void +convert_and_descale( + SrcTensor const& src, + DstTensor& dst, + float scale) { + + using SrcVec_u8 = sycl::vec; + using DstVec_u16 = sycl::vec; + + auto src_ptr = reinterpret_cast(src.data()); + auto dst_ptr = reinterpret_cast(dst.data()); + + // Create a SCALAR bfloat16_t for scaling + const cutlass::bfloat16_t scale_bf16 = static_cast(scale); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < cute::size(src) / VectorizeSize; ++i) { + SrcVec_u8 const src_vec_u8 = src_ptr[i]; + DstVec_u16 result_vec_u16; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < VectorizeSize; ++j) { + // 1. Convert FP8 bits to BFLOAT16 bits + uint16_t val_bf16_bits; + if constexpr (std::is_same_v) { + val_bf16_bits = fp8_e4m3_to_fp16_bitwise(src_vec_u8[j]); + } else { + val_bf16_bits = fp8_e5m2_to_fp16_bitwise(src_vec_u8[j]); + } + + // 2. Reinterpret bits as bfloat16_t to perform math + cutlass::bfloat16_t val_bf16 = reinterpret_cast(val_bf16_bits); + + // 3. Apply scaling + val_bf16 *= scale_bf16; + + // 4. Reinterpret back to bits for storage + result_vec_u16[j] = reinterpret_cast(val_bf16); + } + + // 5. Store the final vector of bits + dst_ptr[i] = result_vec_u16; + } +} template CUTLASS_DEVICE void diff --git a/test/unit/flash_attention/flash_attention_decode/flash_decode_testbed_3x.hpp b/test/unit/flash_attention/flash_attention_decode/flash_decode_testbed_3x.hpp index 30a09d7a69..a837f86b85 100644 --- a/test/unit/flash_attention/flash_attention_decode/flash_decode_testbed_3x.hpp +++ b/test/unit/flash_attention/flash_attention_decode/flash_decode_testbed_3x.hpp @@ -411,9 +411,9 @@ struct TestbedImpl { int max_seq_len_q = static_cast(cute::get<3>(problem_size)); int max_seq_len_kv = static_cast(cute::get<4>(problem_size)); int max_seq_len_kv_cache = static_cast(cute::get<5>(problem_size)); - cute::get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, cumulative_seqlen_q.data()}; - cute::get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, cumulative_seqlen_kv.data()}; - cute::get<5>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv_cache, cumulative_seqlen_kv_cache.data()}; + cute::get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, 0, cumulative_seqlen_q.data()}; + cute::get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, 0, cumulative_seqlen_kv.data()}; + cute::get<5>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv_cache, 0, cumulative_seqlen_kv_cache.data()}; } auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = cute::select<0,1,2,6,7>(problem_size); diff --git a/test/unit/flash_attention/flash_attention_prefill/flash_prefill_testbed_3x.hpp b/test/unit/flash_attention/flash_attention_prefill/flash_prefill_testbed_3x.hpp index ece31f6f7a..13ca1ee3dc 100644 --- a/test/unit/flash_attention/flash_attention_prefill/flash_prefill_testbed_3x.hpp +++ b/test/unit/flash_attention/flash_attention_prefill/flash_prefill_testbed_3x.hpp @@ -377,8 +377,8 @@ struct TestbedImpl { if constexpr (isVarLen) { int max_seq_len_q = static_cast(cute::get<3>(problem_size)); int max_seq_len_kv = static_cast(cute::get<4>(problem_size)); - cute::get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, cumulative_seqlen_q.data()}; - cute::get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, cumulative_seqlen_kv.data()}; + cute::get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, 0, cumulative_seqlen_q.data()}; + cute::get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, 0, cumulative_seqlen_kv.data()}; } auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = cute::select<0,1,2,5,6>(problem_size); diff --git a/test/unit/flash_attention/flash_attention_prefill_cachedkv/flash_prefill_cachedkv_testbed_3x.hpp b/test/unit/flash_attention/flash_attention_prefill_cachedkv/flash_prefill_cachedkv_testbed_3x.hpp index b758d1b8fd..9a70e379bc 100644 --- a/test/unit/flash_attention/flash_attention_prefill_cachedkv/flash_prefill_cachedkv_testbed_3x.hpp +++ b/test/unit/flash_attention/flash_attention_prefill_cachedkv/flash_prefill_cachedkv_testbed_3x.hpp @@ -362,9 +362,9 @@ struct TestbedImpl { int max_seq_len_q = static_cast(cute::get<3>(problem_size)); int max_seq_len_kv = static_cast(cute::get<4>(problem_size)); int max_seq_len_kv_cache = static_cast(cute::get<5>(problem_size)); - cute::get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, cumulative_seqlen_q.data()}; - cute::get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, cumulative_seqlen_kv.data()}; - cute::get<5>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv_cache, cumulative_seqlen_kv_cache.data()}; + cute::get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, 0, cumulative_seqlen_q.data()}; + cute::get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, 0, cumulative_seqlen_kv.data()}; + cute::get<5>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv_cache, 0, cumulative_seqlen_kv_cache.data()}; } auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = cute::select<0,1,2,6,7>(problem_size); diff --git a/tools/util/include/compat.hpp b/tools/util/include/compat.hpp new file mode 100644 index 0000000000..d02362a611 --- /dev/null +++ b/tools/util/include/compat.hpp @@ -0,0 +1,26 @@ +/*************************************************************************** + * + * Copyright (C) Codeplay Software Ltd. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * + * Part of the LLVM Project, under the Apache License v2.0 with LLVM + * Exceptions. See https://llvm.org/LICENSE.txt for license information. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Compat + * + * compat.hpp + * + * Description: + * Main include header for Compat + **************************************************************************/ + +#pragma once + +#include diff --git a/tools/util/include/compat/atomic.hpp b/tools/util/include/compat/atomic.hpp new file mode 100644 index 0000000000..8d64dafc08 --- /dev/null +++ b/tools/util/include/compat/atomic.hpp @@ -0,0 +1,474 @@ +/*************************************************************************** + * + * Copyright (C) Codeplay Software Ltd. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * + * Part of the LLVM Project, under the Apache License v2.0 with LLVM + * Exceptions. See https://llvm.org/LICENSE.txt for license information. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SYCL compatibility extension + * + * atomic.hpp + * + * Description: + * Atomic functionality for the SYCL compatibility extension + **************************************************************************/ + +// The original source was under the license below: +//==---- atomic.hpp -------------------------------*- C++ -*----------------==// +// +// Copyright (C) Intel Corporation +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// See https://llvm.org/LICENSE.txt for license information. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +#include +#include +#include +#include + +#include + +namespace compat { + +/// Atomically add the value operand to the value at the addr and assign the +/// result to the value at addr. +/// \param [in, out] addr The pointer to the data. +/// \param operand The value to add to the value at \p addr. +/// \param memoryOrder The memory ordering used. +/// \returns The value at the \p addr before the call. +template +inline T atomic_fetch_add(T *addr, arith_t operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.fetch_add(operand); +} + +/// Atomically subtract the value operand from the value at the addr and +/// assign the result to the value at addr. +/// \param [in, out] addr The pointer to the data. +/// \param operand The value to subtract from the value at \p addr. +/// \param memoryOrder The memory ordering used. +/// \returns The value at the \p addr before the call. +template +inline T atomic_fetch_sub(T *addr, arith_t operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.fetch_sub(operand); +} + +/// Atomically perform a bitwise AND between the value operand and the value +/// at the addr and assign the result to the value at addr. +/// \param [in, out] addr The pointer to the data. +/// \param operand The value to use in bitwise AND operation with the value at +/// the \p addr. +/// \param memoryOrder The memory ordering used. +/// \returns The value at the \p addr before the call. +template +inline T atomic_fetch_and(T *addr, type_identity_t operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.fetch_and(operand); +} + +/// Atomically or the value at the addr with the value operand, and assign +/// the result to the value at addr. +/// \param [in, out] addr The pointer to the data. +/// \param operand The value to use in bitwise OR operation with the value at +/// the \p addr. +/// \param memoryOrder The memory ordering used. +/// \returns The value at the \p addr before the call. +template +inline T atomic_fetch_or(T *addr, type_identity_t operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.fetch_or(operand); +} + +/// Atomically xor the value at the addr with the value operand, and assign +/// the result to the value at addr. +/// \param [in, out] addr The pointer to the data. +/// \param operand The value to use in bitwise XOR operation with the value at +/// the \p addr. +/// \param memoryOrder The memory ordering used. +/// \returns The value at the \p addr before the call. +template +inline T atomic_fetch_xor(T *addr, type_identity_t operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.fetch_xor(operand); +} + +/// Atomically calculate the minimum of the value at addr and the value +/// operand and assign the result to the value at addr. +/// \param [in, out] addr The pointer to the data. +/// \param operand. \param memoryOrder The memory ordering used. +/// \returns The value at the \p addr before the call. +template +inline T atomic_fetch_min(T *addr, type_identity_t operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.fetch_min(operand); +} + +/// Atomically calculate the maximum of the value at addr and the value +/// operand and assign the result to the value at addr. +/// \param [in, out] addr The pointer to the data. +/// \param operand. +/// \param memoryOrder The memory ordering used. +/// \returns The value at the \p addr before the call. +template +inline T atomic_fetch_max(T *addr, type_identity_t operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.fetch_max(operand); +} + +/// Atomically set \p operand to the value stored in \p addr, if old value +/// stored in \p addr is equal to zero or greater than \p operand, else decrease +/// the value stored in \p addr. \param [in, out] addr The pointer to the data. +/// \param operand The threshold value. +/// \param memoryOrder The memory ordering used. +/// \returns The old value stored in \p addr. +template +unsigned int atomic_fetch_compare_dec(unsigned int *addr, + unsigned int operand) { + auto atm = + sycl::atomic_ref( + addr[0]); + unsigned int old; + + while (true) { + old = atm.load(); + if (old == 0 || old > operand) { + if (atm.compare_exchange_strong(old, operand)) + break; + } else if (atm.compare_exchange_strong(old, old - 1)) + break; + } + + return old; +} + +/// Atomically increment the value stored in \p addr if old value stored in \p +/// addr is less than \p operand, else set 0 to the value stored in \p addr. +/// \param [in, out] addr The pointer to the data. +/// \param operand The threshold value. +/// \param memoryOrder The memory ordering used. +/// \returns The old value stored in \p addr. +template +inline unsigned int atomic_fetch_compare_inc(unsigned int *addr, + unsigned int operand) { + auto atm = + sycl::atomic_ref( + addr[0]); + unsigned int old; + while (true) { + old = atm.load(); + if (old >= operand) { + if (atm.compare_exchange_strong(old, 0)) + break; + } else if (atm.compare_exchange_strong(old, old + 1)) + break; + } + return old; +} + +/// Atomically exchange the value at the address addr with the value operand. +/// \param [in, out] addr The pointer to the data. +/// \param operand The value to be exchanged with the value pointed by \p addr. +/// \param memoryOrder The memory ordering used. +/// \returns The value at the \p addr before the call. +template +inline T atomic_exchange(T *addr, type_identity_t operand) { + auto atm = + sycl::atomic_ref(addr[0]); + return atm.exchange(operand); +} + +/// Atomically compare the value at \p addr to the value expected and exchange +/// with the value desired if the value at \p addr is equal to the value +/// expected. Returns the value at the \p addr before the call. +/// \param [in, out] addr Multi_ptr. +/// \param expected The value to compare against the value at \p addr. +/// \param desired The value to assign to \p addr if the value at \p addr +/// is expected. +/// \param success The memory ordering used when comparison succeeds. +/// \param fail The memory ordering used when comparison fails. +/// \returns The value at the \p addr before the call. +template +T atomic_compare_exchange_strong( + sycl::multi_ptr addr, type_identity_t expected, + type_identity_t desired, + sycl::memory_order success = sycl::memory_order::relaxed, + sycl::memory_order fail = sycl::memory_order::relaxed) { + auto atm = sycl::atomic_ref(*addr); + + atm.compare_exchange_strong(expected, desired, success, fail); + return expected; +} + +/// Atomically compare the value at \p addr to the value expected and exchange +/// with the value desired if the value at \p addr is equal to the value +/// expected. Returns the value at the \p addr before the call. +/// \param [in] addr The pointer to the data. +/// \param expected The value to compare against the value at \p addr. +/// \param desired The value to assign to \p addr if the value at \p addr is +/// expected. +/// \param success The memory ordering used when comparison succeeds. +/// \param fail The memory ordering used when comparison fails. +/// \returns The value at the \p addr before the call. +template +T atomic_compare_exchange_strong( + T *addr, type_identity_t expected, type_identity_t desired, + sycl::memory_order success = sycl::memory_order::relaxed, + sycl::memory_order fail = sycl::memory_order::relaxed) { + auto atm = + sycl::atomic_ref(addr[0]); + atm.compare_exchange_strong(expected, desired, success, fail); + return expected; +} + +/// Atomic extension to implement standard APIs in std::atomic +namespace detail { +template struct IsValidAtomicType { + static constexpr bool value = + (std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_pointer::value); +}; +} // namespace detail + +template +class atomic { + static_assert( + detail::IsValidAtomicType::value, + "Invalid atomic type. Valid types are int, unsigned int, long, " + "unsigned long, long long, unsigned long long, float, double " + "and pointer types"); + T __d; + +public: + /// default memory synchronization order + static constexpr sycl::memory_order default_read_order = + sycl::atomic_ref::default_read_order; + static constexpr sycl::memory_order default_write_order = + sycl::atomic_ref::default_write_order; + static constexpr sycl::memory_scope default_scope = DefaultScope; + static constexpr sycl::memory_order default_read_modify_write_order = + DefaultOrder; + + /// Default constructor. + constexpr atomic() noexcept = default; + /// Constructor with initialize value. + constexpr atomic(T d) noexcept : __d(d){}; + + /// atomically replaces the value of the referenced object with a non-atomic + /// argument + /// \param operand The value to replace the pointed value. + /// \param memoryOrder The memory ordering used. + /// \param memoryScope The memory scope used. + void store(T operand, sycl::memory_order memoryOrder = default_write_order, + sycl::memory_scope memoryScope = default_scope) noexcept { + sycl::atomic_ref atm(__d); + atm.store(operand, memoryOrder, memoryScope); + } + + /// atomically obtains the value of the referenced object + /// \param memoryOrder The memory ordering used. + /// \param memoryScope The memory scope used. + /// \returns The value of the referenced object + T load(sycl::memory_order memoryOrder = default_read_order, + sycl::memory_scope memoryScope = default_scope) const noexcept { + sycl::atomic_ref atm( + const_cast(__d)); + return atm.load(memoryOrder, memoryScope); + } + + /// atomically replaces the value of the referenced object and obtains the + /// value held previously + /// \param operand The value to replace the pointed value. + /// \param memoryOrder The memory ordering used. + /// \param memoryScope The memory scope used. + /// \returns The value of the referenced object before the call. + T exchange(T operand, + sycl::memory_order memoryOrder = default_read_modify_write_order, + sycl::memory_scope memoryScope = default_scope) noexcept { + + sycl::atomic_ref atm(__d); + return atm.exchange(operand, memoryOrder, memoryScope); + } + + /// atomically compares the value of the referenced object with non-atomic + /// argument and performs atomic exchange if equal or atomic load if not + /// \param expected The value expected to be found in the object referenced by + /// the atomic_ref object + /// \param desired The value to store in the referenced object if it is as + /// expected + /// \param success The memory models for the read-modify-write + /// \param failure The memory models for load operations + /// \param memoryScope The memory scope used. + /// \returns true if the referenced object was successfully changed, false + /// otherwise. + bool compare_exchange_weak( + T &expected, T desired, sycl::memory_order success, + sycl::memory_order failure, + sycl::memory_scope memoryScope = default_scope) noexcept { + sycl::atomic_ref atm(__d); + return atm.compare_exchange_weak(expected, desired, success, failure, + memoryScope); + } + /// \param expected The value expected to be found in the object referenced by + /// the atomic_ref object + /// \param desired The value to store in the referenced + /// object if it is as expected + /// \param memoryOrder The memory synchronization ordering for + /// operations + /// \param memoryScope The memory scope used. + /// \returns true if the referenced object was successfully + /// changed, false otherwise. + bool compare_exchange_weak( + T &expected, T desired, + sycl::memory_order memoryOrder = default_read_modify_write_order, + sycl::memory_scope memoryScope = default_scope) noexcept { + sycl::atomic_ref atm(__d); + return atm.compare_exchange_weak(expected, desired, memoryOrder, + memoryScope); + } + + /// atomically compares the value of the referenced object with non-atomic + /// argument and performs atomic exchange if equal or atomic load if not + /// \param expected The value expected to be found in the object referenced by + /// the atomic_ref object + /// \param desired The value to store in the referenced + /// object if it is as expected + /// \param success The memory models for the + /// read-modify-write + /// \param failure The memory models for load operations + /// \param memoryScope The memory scope used. + /// \returns true if the referenced object was successfully changed, false + /// otherwise. + bool compare_exchange_strong( + T &expected, T desired, sycl::memory_order success, + sycl::memory_order failure, + sycl::memory_scope memoryScope = default_scope) noexcept { + + sycl::atomic_ref atm(__d); + return atm.compare_exchange_strong(expected, desired, success, failure, + memoryScope); + } + /// \param expected The value expected to be found in the object referenced by + /// the atomic_ref object + /// \param desired The value to store in the referenced + /// object if it is as expected + /// \param memoryOrder The memory synchronization ordering for + /// operations + /// \param memoryScope The memory scope used. + /// \returns true if the referenced object was successfully changed, false + /// otherwise. + bool compare_exchange_strong( + T &expected, T desired, + sycl::memory_order memoryOrder = default_read_modify_write_order, + sycl::memory_scope memoryScope = default_scope) noexcept { + sycl::atomic_ref atm(__d); + return atm.compare_exchange_strong(expected, desired, memoryOrder, + memoryScope); + } + + /// atomically adds the argument to the value stored in the atomic object and + /// obtains the value held previously + /// \param operand The other argument of arithmetic addition + /// \param memoryOrder The memory ordering used. + /// \param memoryScope The memory scope used. + /// \returns The value of the referenced object before the call. + T fetch_add(arith_t operand, + sycl::memory_order memoryOrder = default_read_modify_write_order, + sycl::memory_scope memoryScope = default_scope) noexcept { + + auto atm = sycl::atomic_ref(__d); + return atm.fetch_add(operand, memoryOrder, memoryScope); + } + + /// atomically subtracts the argument from the value stored in the atomic + /// object and obtains the value held previously + /// \param operand The other argument of arithmetic subtraction + /// \param memoryOrder The memory ordering used. + /// \param memoryScope The memory scope used. + /// \returns The value of the referenced object before the call. + T fetch_sub(arith_t operand, + sycl::memory_order memoryOrder = default_read_modify_write_order, + sycl::memory_scope memoryScope = default_scope) noexcept { + + auto atm = sycl::atomic_ref(__d); + return atm.fetch_sub(operand, memoryOrder, memoryScope); + } +}; + +} // namespace compat diff --git a/tools/util/include/compat/compat.hpp b/tools/util/include/compat/compat.hpp new file mode 100644 index 0000000000..a2eed0fa55 --- /dev/null +++ b/tools/util/include/compat/compat.hpp @@ -0,0 +1,36 @@ +/*************************************************************************** + * + * Copyright (C) Codeplay Software Ltd. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * + * Part of the LLVM Project, under the Apache License v2.0 with LLVM + * Exceptions. See https://llvm.org/LICENSE.txt for license information. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Compat + * + * compat.hpp + * + * Description: + * Main include internal header for Compat + **************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include diff --git a/tools/util/include/compat/defs.hpp b/tools/util/include/compat/defs.hpp new file mode 100644 index 0000000000..5c578fbab2 --- /dev/null +++ b/tools/util/include/compat/defs.hpp @@ -0,0 +1,94 @@ +/*************************************************************************** + * + * Copyright (C) Codeplay Software Ltd. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * + * Part of the LLVM Project, under the Apache License v2.0 with LLVM + * Exceptions. See https://llvm.org/LICENSE.txt for license information. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Compat + * + * defs.hpp + * + * Description: + * helper aliases and definitions for Compat + * + **************************************************************************/ + +// The original source was under the license below: +//==---- defs.hpp ---------------------------------*- C++ -*----------------==// +// +// Copyright (C) Intel Corporation +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// See https://llvm.org/LICENSE.txt for license information. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +template class compat_kernel_name; +template class compat_kernel_scalar; + +#if defined(_MSC_VER) +#define __compat_align__(n) __declspec(align(n)) +#define __compat_inline__ __forceinline +#define __compat_noinline__ __declspec(noinline) +#else +#define __compat_align__(n) __attribute__((aligned(n))) +#define __compat_inline__ __inline__ __attribute__((always_inline)) +#define __compat_noinline__ __attribute__((noinline)) +#endif + +#define COMPAT_COMPATIBILITY_TEMP (900) + +#ifdef _WIN32 +#define COMPAT_EXPORT __declspec(dllexport) +#else +#define COMPAT_EXPORT +#endif + +#define COMPAT_MAJOR_VERSION 0 +#define COMPAT_MINOR_VERSION 2 +#define COMPAT_PATCH_VERSION 0 + +#define COMPAT_MAKE_VERSION(_major, _minor, _patch) \ + ((1E6 * _major) + (1E3 * _minor) + _patch) + +#define COMPAT_VERSION \ + COMPAT_MAKE_VERSION(COMPAT_MAJOR_VERSION, COMPAT_MINOR_VERSION, \ + COMPAT_PATCH_VERSION) + +namespace compat { +enum error_code { success = 0, backend_error = 1, default_error = 999 }; +/// A dummy function introduced to assist auto migration. +/// The SYCLomatic user should replace it with a real error-handling function. +/// SYCL reports errors using exceptions and does not use error codes. +inline const char *get_error_string_dummy(int ec) { + (void)ec; + return ""; // Return the error string for the error code + // ec. +} +} // namespace compat + +#define COMPAT_CHECK_ERROR(expr) \ + [&]() { \ + try { \ + expr; \ + return compat::error_code::success; \ + } catch (sycl::exception const &e) { \ + std::cerr << e.what() << std::endl; \ + return compat::error_code::backend_error; \ + } catch (std::runtime_error const &e) { \ + std::cerr << e.what() << std::endl; \ + return compat::error_code::default_error; \ + } \ + }() diff --git a/tools/util/include/compat/device.hpp b/tools/util/include/compat/device.hpp new file mode 100644 index 0000000000..25e096e281 --- /dev/null +++ b/tools/util/include/compat/device.hpp @@ -0,0 +1,967 @@ +/*************************************************************************** + * + * Copyright (C) Codeplay Software Ltd. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * + * Part of the LLVM Project, under the Apache License v2.0 with LLVM + * Exceptions. See https://llvm.org/LICENSE.txt for license information. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SYCL compatibility extension + * + * device.hpp + * + * Description: + * Device functionality for the SYCL compatibility extension + **************************************************************************/ +// +// Modifications, Copyright (C) 2025 Intel Corporation +// +// This software and the related documents are Intel copyrighted materials, and +// your use of them is governed by the express license under which they were +// provided to you ("License"). Unless the License provides otherwise, you may +// not use, modify, copy, publish, distribute, disclose or transmit this +// software or the related documents without Intel's prior written permission. +// +// This software and the related documents are provided as is, with no express +// or implied warranties, other than those that are expressly stated in the +// License. +// +// The original source was under the license below: +//==---- device.hpp -------------------------------*- C++ -*----------------==// +// +// Copyright (C) Intel Corporation +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// See https://llvm.org/LICENSE.txt for license information. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if defined(__linux__) +#include +#include +#endif +#if defined(_WIN64) +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#endif + +#include +#include +#include +#include + +namespace compat { + +namespace detail { +static void parse_version_string(const std::string &ver, int &major, + int &minor) { + // Version string has the following format: + // a. OpenCL + // b. + // c. e.g gfx1030 + std::string::size_type i = 0; + while (i < ver.size()) { + if (isdigit(ver[i])) + break; + i++; + } + if (i < ver.size()) + major = std::stoi(&(ver[i])); + else + major = 0; + while (i < ver.size()) { + if (ver[i] == '.') + break; + i++; + } + i++; + if (i < ver.size()) + minor = std::stoi(&(ver[i])); + else + minor = 0; +} + +static void get_version(const sycl::device &dev, int &major, int &minor) { + std::string ver = dev.get_info(); + parse_version_string(ver, major, minor); +} + +/// SYCL default exception handler +inline auto exception_handler = [](sycl::exception_list exceptions) { + for (std::exception_ptr const &e : exceptions) { + try { + std::rethrow_exception(e); + } catch (sycl::exception const &e) { + std::cerr << "[Compat] Caught asynchronous SYCL exception:" + << std::endl + << e.what() << std::endl + << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + } + } +}; + +} // namespace detail + +using event_ptr = sycl::event *; + +using queue_ptr = sycl::queue *; + +using device_ptr = char *; + +/// Destroy \p event pointed memory. +/// +/// \param event Pointer to the sycl::event address. +static void destroy_event(event_ptr event) { delete event; } + +class device_info { +public: + // get interface + const char *get_name() const { return _name; } + char *get_name() { return _name; } + template , + std::enable_if_t> || + std::is_same_v, + int> = 0> + auto get_max_work_item_sizes() const { + if constexpr (std::is_same_v>) + return _max_work_item_sizes; + else + return _max_work_item_sizes_i; + } + template , + std::enable_if_t> || + std::is_same_v, + int> = 0> + auto get_max_work_item_sizes() { + if constexpr (std::is_same_v>) + return _max_work_item_sizes; + else + return _max_work_item_sizes_i; + } + bool get_host_unified_memory() const { return _host_unified_memory; } + int get_major_version() const { return _major; } + int get_minor_version() const { return _minor; } + int get_integrated() const { return _integrated; } + int get_max_clock_frequency() const { return _frequency; } + int get_max_compute_units() const { return _max_compute_units; } + int get_max_work_group_size() const { return _max_work_group_size; } + int get_max_sub_group_size() const { return _max_sub_group_size; } + int get_max_work_items_per_compute_unit() const { + return _max_work_items_per_compute_unit; + } + int get_max_register_size_per_work_group() const { + return _max_register_size_per_work_group; + } + template || + std::is_same_v, + int> = 0> + auto get_max_nd_range_size() const { + if constexpr (std::is_same_v) + return _max_nd_range_size; + else + return _max_nd_range_size_i; + } + template || + std::is_same_v, + int> = 0> + auto get_max_nd_range_size() { + if constexpr (std::is_same_v) + return _max_nd_range_size; + else + return _max_nd_range_size_i; + } + size_t get_global_mem_size() const { return _global_mem_size; } + size_t get_local_mem_size() const { return _local_mem_size; } + /// Returns the maximum clock rate of device's global memory in kHz. If + /// compiler does not support this API then returns default value 3200000 kHz. + unsigned int get_memory_clock_rate() const { return _memory_clock_rate; } + /// Returns the maximum bus width between device and memory in bits. If + /// compiler does not support this API then returns default value 64 bits. + unsigned int get_memory_bus_width() const { return _memory_bus_width; } + uint32_t get_device_id() const { return _device_id; } + std::array get_uuid() const { return _uuid; } + /// Returns global memory cache size in bytes. + unsigned int get_global_mem_cache_size() const { + return _global_mem_cache_size; + } + int get_image1d_max() const { return _image1d_max; } + auto get_image2d_max() const { return _image2d_max; } + auto get_image2d_max() { return _image2d_max; } + auto get_image3d_max() const { return _image3d_max; } + auto get_image3d_max() { return _image3d_max; } + + // set interface + void set_name(const char *name) { + size_t length = strlen(name); + if (length < device_info::NAME_BUFFER_SIZE) { + std::memcpy(_name, name, length + 1); + } else { + std::memcpy(_name, name, device_info::NAME_BUFFER_SIZE - 1); + _name[255] = '\0'; + } + } + void set_max_work_item_sizes(const sycl::range<3> max_work_item_sizes) { + _max_work_item_sizes = max_work_item_sizes; + for (int i = 0; i < 3; ++i) + _max_work_item_sizes_i[i] = max_work_item_sizes[i]; + } + [[deprecated]] void + set_max_work_item_sizes(const sycl::id<3> max_work_item_sizes) { + for (int i = 0; i < 3; ++i) { + _max_work_item_sizes[i] = max_work_item_sizes[i]; + _max_work_item_sizes_i[i] = max_work_item_sizes[i]; + } + } + void set_host_unified_memory(bool host_unified_memory) { + _host_unified_memory = host_unified_memory; + } + void set_major_version(int major) { _major = major; } + void set_minor_version(int minor) { _minor = minor; } + void set_integrated(int integrated) { _integrated = integrated; } + void set_max_clock_frequency(int frequency) { _frequency = frequency; } + void set_max_compute_units(int max_compute_units) { + _max_compute_units = max_compute_units; + } + void set_global_mem_size(size_t global_mem_size) { + _global_mem_size = global_mem_size; + } + void set_local_mem_size(size_t local_mem_size) { + _local_mem_size = local_mem_size; + } + void set_max_work_group_size(int max_work_group_size) { + _max_work_group_size = max_work_group_size; + } + void set_max_sub_group_size(int max_sub_group_size) { + _max_sub_group_size = max_sub_group_size; + } + void + set_max_work_items_per_compute_unit(int max_work_items_per_compute_unit) { + _max_work_items_per_compute_unit = max_work_items_per_compute_unit; + } + void set_max_nd_range_size(int max_nd_range_size[]) { + for (int i = 0; i < 3; i++) { + _max_nd_range_size[i] = max_nd_range_size[i]; + _max_nd_range_size_i[i] = max_nd_range_size[i]; + } + } + void set_max_nd_range_size(sycl::id<3> max_nd_range_size) { + for (int i = 0; i < 3; i++) { + _max_nd_range_size[i] = max_nd_range_size[i]; + _max_nd_range_size_i[i] = max_nd_range_size[i]; + } + } + void set_memory_clock_rate(unsigned int memory_clock_rate) { + _memory_clock_rate = memory_clock_rate; + } + void set_memory_bus_width(unsigned int memory_bus_width) { + _memory_bus_width = memory_bus_width; + } + void + set_max_register_size_per_work_group(int max_register_size_per_work_group) { + _max_register_size_per_work_group = max_register_size_per_work_group; + } + void set_device_id(uint32_t device_id) { _device_id = device_id; } + void set_uuid(std::array uuid) { _uuid = std::move(uuid); } + void set_global_mem_cache_size(unsigned int global_mem_cache_size) { + _global_mem_cache_size = global_mem_cache_size; + } + void set_image1d_max(size_t image_max_buffer_size) { + _image1d_max = image_max_buffer_size; + } + void set_image2d_max(size_t image_max_width_buffer_size, + size_t image_max_height_buffer_size) { + _image2d_max[0] = image_max_width_buffer_size; + _image2d_max[1] = image_max_height_buffer_size; + } + void set_image3d_max(size_t image_max_width_buffer_size, + size_t image_max_height_buffer_size, + size_t image_max_depth_buffer_size) { + _image3d_max[0] = image_max_width_buffer_size; + _image3d_max[1] = image_max_height_buffer_size; + _image3d_max[2] = image_max_depth_buffer_size; + } + +private: + constexpr static size_t NAME_BUFFER_SIZE = 256; + + char _name[device_info::NAME_BUFFER_SIZE]; + sycl::range<3> _max_work_item_sizes; + int _max_work_item_sizes_i[3]; + bool _host_unified_memory = false; + int _major; + int _minor; + int _integrated = 0; + int _frequency; + // Set estimated value 3200000 kHz as default value. + unsigned int _memory_clock_rate = 3200000; + // Set estimated value 64 bits as default value. + unsigned int _memory_bus_width = 64; + unsigned int _global_mem_cache_size; + int _max_compute_units; + int _max_work_group_size; + int _max_sub_group_size; + int _max_work_items_per_compute_unit; + int _max_register_size_per_work_group; + size_t _global_mem_size; + size_t _local_mem_size; + size_t _max_nd_range_size[3]; + int _max_nd_range_size_i[3]; + uint32_t _device_id; + std::array _uuid; + int _image1d_max; + int _image2d_max[2]; + int _image3d_max[3]; +}; + +static int get_major_version(const sycl::device &dev) { + int major, minor; + detail::get_version(dev, major, minor); + return major; +} + +static int get_minor_version(const sycl::device &dev) { + int major, minor; + detail::get_version(dev, major, minor); + return minor; +} + +static inline void +has_capability_or_fail(const sycl::device &dev, + const std::initializer_list &props) { + for (const auto &it : props) { + if (dev.has(it)) + continue; + switch (it) { + case sycl::aspect::fp64: + throw sycl::exception(sycl::make_error_code(sycl::errc::runtime), + "[Compat] 'double' is not supported in '" + + dev.get_info() + + "' device"); + break; + case sycl::aspect::fp16: + throw sycl::exception(sycl::make_error_code(sycl::errc::runtime), + "[Compat] 'half' is not supported in '" + + dev.get_info() + + "' device"); + break; + default: +#define __SYCL_ASPECT(ASPECT, ID) \ + case sycl::aspect::ASPECT: \ + return #ASPECT; +#define __SYCL_ASPECT_DEPRECATED(ASPECT, ID, MESSAGE) __SYCL_ASPECT(ASPECT, ID) +#define __SYCL_ASPECT_DEPRECATED_ALIAS(ASPECT, ID, MESSAGE) + auto getAspectNameStr = [](sycl::aspect AspectNum) -> std::string { + switch (AspectNum) { +#include +#include + default: + return "unknown aspect"; + } + }; +#undef __SYCL_ASPECT_DEPRECATED_ALIAS +#undef __SYCL_ASPECT_DEPRECATED +#undef __SYCL_ASPECT + throw sycl::exception( + sycl::make_error_code(sycl::errc::runtime), + "[Compat] '" + getAspectNameStr(it) + "' is not supported in '" + + dev.get_info() + "' device"); + } + break; + } +} + +/// device extension +class device_ext : public sycl::device { +public: + device_ext() : sycl::device(), _ctx(*this) {} + ~device_ext() { + try { + std::lock_guard lock(m_mutex); + sycl::event::wait(_events); + _queues.clear(); + } catch (std::exception &e) { + __SYCL_REPORT_EXCEPTION_TO_STREAM("exception in ~device_ext", e); + } + } + device_ext(const sycl::device &base, bool print_on_async_exceptions = false, + bool in_order = true) + : sycl::device(base), _ctx(*this) { + if (!this->has(sycl::aspect::usm_device_allocations)) { + throw std::invalid_argument( + "Device does not support device USM allocations"); + } + // calls create_queue since we don't have a locked m_mutex + _default_queue = create_queue(print_on_async_exceptions, in_order); + _saved_queue = _default_queue; + } + + bool is_native_host_atomic_supported() { return false; } + int get_major_version() const { return compat::get_major_version(*this); } + + int get_minor_version() const { return compat::get_minor_version(*this); } + + int get_max_compute_units() const { + return get_device_info().get_max_compute_units(); + } + + /// Return the maximum clock frequency of this device in KHz. + int get_max_clock_frequency() const { + return get_device_info().get_max_clock_frequency(); + } + + int get_integrated() const { return get_device_info().get_integrated(); } + + int get_max_sub_group_size() const { + return get_device_info().get_max_sub_group_size(); + } + + int get_max_register_size_per_work_group() const { + return get_device_info().get_max_register_size_per_work_group(); + } + + int get_max_work_group_size() const { + return get_device_info().get_max_work_group_size(); + } + + int get_mem_base_addr_align() const { + return get_info(); + } + + size_t get_global_mem_size() const { + return get_device_info().get_global_mem_size(); + } + + size_t get_local_mem_size() const { + return get_device_info().get_local_mem_size(); + } + + /// Get the number of bytes of free and total memory on the SYCL device. + /// \param [out] free_memory The number of bytes of free memory on the SYCL + /// device. + /// \param [out] total_memory The number of bytes of total memory on the SYCL + /// device. + void get_memory_info(size_t &free_memory, size_t &total_memory) const { + if (!has(sycl::aspect::ext_intel_free_memory)) { + std::cerr << "[Compat] get_memory_info: ext_intel_free_memory is not " + "supported." + << std::endl; + free_memory = 0; + } else { + free_memory = get_info(); + } + total_memory = get_device_info().get_global_mem_size(); + } + + void get_device_info(device_info &out) const { + if (_dev_info) { + out = *_dev_info; + return; + } + + std::lock_guard lock(m_mutex); + device_info prop; + prop.set_name(get_info().c_str()); + + int major, minor; + get_version(major, minor); + prop.set_major_version(major); + prop.set_minor_version(minor); + + prop.set_max_work_item_sizes( + // SYCL 2020-conformant code, max_work_item_sizes is a struct + // templated by an int + get_info>()); + + prop.set_host_unified_memory(has(sycl::aspect::usm_host_allocations)); + + prop.set_max_clock_frequency( + get_info()); + prop.set_max_compute_units( + get_info()); + prop.set_max_work_group_size( + get_info()); + prop.set_global_mem_size(get_info()); + prop.set_local_mem_size(get_info()); + +#if (defined(SYCL_EXT_INTEL_DEVICE_INFO) && SYCL_EXT_INTEL_DEVICE_INFO >= 6) + if (has(sycl::aspect::ext_intel_memory_clock_rate)) { + unsigned int tmp = + get_info(); + if (tmp != 0) + prop.set_memory_clock_rate(1000 * tmp); + } + if (has(sycl::aspect::ext_intel_memory_bus_width)) { + prop.set_memory_bus_width( + get_info()); + } + if (has(sycl::aspect::ext_intel_device_id)) { + prop.set_device_id(get_info()); + } + if (has(sycl::aspect::ext_intel_device_info_uuid)) { + prop.set_uuid(get_info()); + } +#elif defined(_MSC_VER) && !defined(__clang__) +#pragma message("get_device_info: querying memory_clock_rate and \ +memory_bus_width are not supported by the compiler used. \ +Use 3200000 kHz as memory_clock_rate default value. \ +Use 64 bits as memory_bus_width default value.") +#else +#warning "get_device_info: querying memory_clock_rate and \ +memory_bus_width are not supported by the compiler used. \ +Use 3200000 kHz as memory_clock_rate default value. \ +Use 64 bits as memory_bus_width default value." +#endif + + size_t max_sub_group_size = 1; + std::vector sub_group_sizes = + get_info(); + + for (const auto &sub_group_size : sub_group_sizes) { + if (max_sub_group_size < sub_group_size) + max_sub_group_size = sub_group_size; + } + + prop.set_max_sub_group_size(max_sub_group_size); + + prop.set_max_work_items_per_compute_unit( + get_info()); +#ifdef SYCL_EXT_ONEAPI_MAX_WORK_GROUP_QUERY + prop.set_max_nd_range_size( + get_info>()); +#else +#if defined(_MSC_VER) && !defined(__clang__) +#pragma message("get_device_info: querying the maximum number \ + of work groups is not supported.") +#else +#warning "get_device_info: querying the maximum number of \ + work groups is not supported." +#endif + int max_nd_range_size[] = {0x7FFFFFFF, 0x7FFFFFFF, 0x7FFFFFFF}; + prop.set_max_nd_range_size(max_nd_range_size); +#endif + + // Estimates max register size per work group, feel free to update the + // value according to device properties. + prop.set_max_register_size_per_work_group(65536); + + prop.set_global_mem_cache_size( + get_info()); + + prop.set_image1d_max(get_info()); + prop.set_image1d_max(get_info()); + prop.set_image2d_max(get_info(), + get_info()); + prop.set_image3d_max(get_info(), + get_info(), + get_info()); + + _dev_info = prop; + out = prop; + } + + device_info get_device_info() const { + if (!_dev_info) { + this->get_device_info(*_dev_info); + } + return _dev_info.value(); + } + + void reset(bool print_on_async_exceptions = false, bool in_order = true) { + std::lock_guard lock(m_mutex); + // The queues are shared_ptrs and the ref counts of the shared_ptrs increase + // only in wait_and_throw(). If there is no other thread calling + // wait_and_throw(), the queues will be destructed. The destructor waits for + // all commands executing on the queue to complete. It isn't possible to + // destroy a queue immediately. This is a synchronization point in SYCL. + _queues.clear(); + // create new default queue + // calls create_queue_impl since we already have a locked m_mutex + + _saved_queue = _default_queue = + in_order ? create_queue_impl(print_on_async_exceptions, + sycl::property::queue::in_order()) + : create_queue_impl(print_on_async_exceptions); + } + + void set_default_queue(const sycl::queue &q) { + std::lock_guard lock(m_mutex); + _queues.front().get()->wait_and_throw(); + _queues[0] = std::make_shared(q); + if (_saved_queue == _default_queue) + _saved_queue = _queues.front().get(); + _default_queue = _queues.front().get(); + } + + queue_ptr default_queue() { return _default_queue; } + + void queues_wait_and_throw() { + std::unique_lock lock(m_mutex); + std::vector> current_queues(_queues); + lock.unlock(); + for (const auto &q : current_queues) { + q->wait_and_throw(); + } + // Guard the destruct of current_queues to make sure the ref count is safe. + lock.lock(); + } + queue_ptr create_queue(bool print_on_async_exceptions = false, + bool in_order = true) { + std::lock_guard lock(m_mutex); + return in_order ? create_queue_impl(print_on_async_exceptions, + sycl::property::queue::in_order()) + : create_queue_impl(print_on_async_exceptions); + } + void destroy_queue(queue_ptr &queue) { + std::lock_guard lock(m_mutex); + _queues.erase( + std::remove_if(_queues.begin(), _queues.end(), + [=](const std::shared_ptr &q) -> bool { + return q.get() == queue; + }), + _queues.end()); + queue = nullptr; + } + void set_saved_queue(queue_ptr q) { + std::lock_guard lock(m_mutex); + _saved_queue = q; + } + queue_ptr get_saved_queue() const { + std::lock_guard lock(m_mutex); + return _saved_queue; + } + sycl::context get_context() const { return _ctx; } + + /// Util function to check whether a device supports some kinds of + /// sycl::aspect. + void has_capability_or_fail( + const std::initializer_list &props) const { + ::compat::has_capability_or_fail(*this, props); + } + +private: + /// Caller should only be done from functions where the resource \p m_mutex + /// has been acquired. + template + queue_ptr create_queue_impl(bool print_on_async_exceptions = false, + PropertiesT... properties) { + sycl::property_list prop = sycl::property_list( +#ifdef COMPAT_PROFILING_ENABLED + sycl::property::queue::enable_profiling(), +#endif + properties...); + if (print_on_async_exceptions) { + _queues.push_back(std::make_shared( + _ctx, *this, detail::exception_handler, prop)); + } else { + _queues.push_back(std::make_shared(_ctx, *this, prop)); + } + return _queues.back().get(); + } + + void get_version(int &major, int &minor) const { + detail::get_version(*this, major, minor); + } + void add_event(sycl::event event) { + std::lock_guard lock(m_mutex); + _events.push_back(event); + } + friend sycl::event enqueue_free(const std::vector &, + const std::vector &, + sycl::queue); + queue_ptr _default_queue; + queue_ptr _saved_queue; + sycl::context _ctx; + std::vector> _queues; + mutable std::mutex m_mutex; + std::vector _events; + mutable std::optional _dev_info; +}; + +namespace detail { + +static inline unsigned int get_tid() { +#if defined(__linux__) + return syscall(SYS_gettid); +#elif defined(_WIN64) + return GetCurrentThreadId(); +#else +#error "Only support Windows and Linux." +#endif +} + +/// device manager +class dev_mgr { +public: + device_ext ¤t_device() { + unsigned int dev_id = current_device_id(); + check_id(dev_id); + return *_devs[dev_id]; + } + device_ext &cpu_device() const { + std::lock_guard lock(m_mutex); + if (_cpu_device == -1) { + throw std::runtime_error("[Compat] No valid cpu device"); + } else { + return *_devs[_cpu_device]; + } + } + device_ext &get_device(unsigned int id) const { + std::lock_guard lock(m_mutex); + check_id(id); + return *_devs[id]; + } + unsigned int current_device_id() const { + std::lock_guard lock(m_mutex); + auto it = _thread2dev_map.find(get_tid()); + if (it != _thread2dev_map.end()) + return it->second; + return _default_device_id; + } + + /// Select device with a device ID. + /// \param [in] id The id of the device which can + /// be obtained through get_device_id(const sycl::device). + void select_device(unsigned int id) { + std::lock_guard lock(m_mutex); + check_id(id); + _thread2dev_map[get_tid()] = id; + } + unsigned int device_count() { return _devs.size(); } + + unsigned int get_device_id(const sycl::device &dev) { + if (!_devs.size()) { + throw std::runtime_error( + "[Compat] No SYCL devices found in the device list. Device list " + "may have been filtered by compat::filter_device"); + } + unsigned int id = 0; + for (auto dev_item : _devs) { + if (*dev_item == dev) { + return id; + } + id++; + } + throw std::runtime_error("[Compat] The device[" + + dev.get_info() + + "] is filtered out by compat::filter_device " + "in current device list!"); + } + + /// List all the devices with its id in dev_mgr. + void list_devices() const { + for (size_t i = 0; i < _devs.size(); ++i) { + std::cout << "Device " << i << ": " + << _devs[i]->get_info() << std::endl; + } + } + + /// Filter out devices; only keep the device whose name contains one of the + /// subname in \p dev_subnames. + /// May break device id mapping and change current device. It's better to be + /// called before other Compat/SYCL APIs. + void filter(const std::vector &dev_subnames) { + std::lock_guard lock(m_mutex); + auto iter = _devs.begin(); + while (iter != _devs.end()) { + std::string dev_name = (*iter)->get_info(); + bool matched = false; + for (const auto &name : dev_subnames) { + if (dev_name.find(name) != std::string::npos) { + matched = true; + break; + } + } + if (matched) + ++iter; + else + iter = _devs.erase(iter); + } + _cpu_device = -1; + for (unsigned i = 0; i < _devs.size(); ++i) { + if (_devs[i]->is_cpu()) { + _cpu_device = i; + break; + } + } + _thread2dev_map.clear(); +#ifdef COMPAT_VERBOSE + list_devices(); +#endif + } + + /// Select device with a Device Selector + /// \param selector device selector to get the device id from. Defaults to + /// sycl::gpu_selector_v + template + std::enable_if_t< + std::is_invocable_r_v> + select_device(const DeviceSelector &selector = sycl::gpu_selector_v) { + sycl::device selected_device = sycl::device(selector); + unsigned int selected_device_id = get_device_id(selected_device); + select_device(selected_device_id); + } + + /// Returns the instance of device manager singleton. + static dev_mgr &instance() { + static dev_mgr d_m; + return d_m; + } + dev_mgr(const dev_mgr &) = delete; + dev_mgr &operator=(const dev_mgr &) = delete; + dev_mgr(dev_mgr &&) = delete; + dev_mgr &operator=(dev_mgr &&) = delete; + +private: + mutable std::mutex m_mutex; + + dev_mgr() { + sycl::device default_device = sycl::device(sycl::default_selector_v); + _devs.push_back(std::make_shared(default_device)); + + std::vector sycl_all_devs = + sycl::device::get_devices(sycl::info::device_type::all); + // Collect other devices except for the default device. + if (default_device.is_cpu()) + _cpu_device = 0; + for (auto &dev : sycl_all_devs) { + if (dev == default_device) { + continue; + } + _devs.push_back(std::make_shared(dev)); + if (_cpu_device == -1 && dev.is_cpu()) { + _cpu_device = _devs.size() - 1; + } + } +#ifdef COMPAT_VERBOSE + list_devices(); +#endif + } + void check_id(unsigned int id) const { + if (id >= _devs.size()) { + throw std::runtime_error("invalid device id"); + } + } + std::vector> _devs; + /// DEFAULT_DEVICE_ID is used, if current_device_id() can not find current + /// thread id in _thread2dev_map, which means default device should be used + /// for the current thread. + const unsigned int _default_device_id = 0; + /// thread-id to device-id map. + std::map _thread2dev_map; + int _cpu_device = -1; +}; + +} // namespace detail + +static inline sycl::queue create_queue(bool print_on_async_exceptions = false, + bool in_order = true) { + return *detail::dev_mgr::instance().current_device().create_queue( + print_on_async_exceptions, in_order); +} + +/// Util function to get the default queue of current device in +/// device manager. +static inline sycl::queue get_default_queue() { + return *detail::dev_mgr::instance().current_device().default_queue(); +} + +/// Util function to change the default queue of the current device in the +/// device manager +/// If the device extension saved queue is the default queue, +/// the previous saved queue will be overwritten as well. +/// This function will be blocking if there are submitted kernels in the +/// previous default queue. +/// @param q New user-defined queue +static inline void set_default_queue(const sycl::queue &q) { + detail::dev_mgr::instance().current_device().set_default_queue(q); +} + +static inline void wait(sycl::queue q = get_default_queue()) { q.wait(); } + +static inline void wait_and_throw(sycl::queue q = get_default_queue()) { + q.wait_and_throw(); +} + +/// Util function to get the id of current device in +/// device manager. +static inline unsigned int get_current_device_id() { + return detail::dev_mgr::instance().current_device_id(); +} + +/// Util function to get the current device. +static inline device_ext &get_current_device() { + return detail::dev_mgr::instance().current_device(); +} + +/// Util function to get a device by id. +static inline device_ext &get_device(unsigned int id) { + return detail::dev_mgr::instance().get_device(id); +} + +/// Util function to get the context of the default queue of current +/// device in device manager. +static inline sycl::context get_default_context() { + return get_current_device().get_context(); +} + +/// Util function to get a CPU device. +static inline device_ext &cpu_device() { + return detail::dev_mgr::instance().cpu_device(); +} + +/// Filter out devices; only keep the device whose name contains one of the +/// subname in \p dev_subnames. +/// May break device id mapping and change current device. It's better to be +/// called before other Compat or SYCL APIs. +static inline void filter_device(const std::vector &dev_subnames) { + detail::dev_mgr::instance().filter(dev_subnames); +} + +/// List all the devices with its id in dev_mgr. +static inline void list_devices() { + detail::dev_mgr::instance().list_devices(); +} + +static inline unsigned int select_device(unsigned int id) { + detail::dev_mgr::instance().select_device(id); + return id; +} + +template +static inline std::enable_if_t< + std::is_invocable_r_v> +select_device(const DeviceSelector &selector = sycl::gpu_selector_v) { + detail::dev_mgr::instance().select_device(selector); +} + +static inline unsigned int get_device_id(const sycl::device &dev) { + return detail::dev_mgr::instance().get_device_id(dev); +} + +static inline unsigned int device_count() { + return detail::dev_mgr::instance().device_count(); +} +} // namespace compat diff --git a/tools/util/include/compat/dims.hpp b/tools/util/include/compat/dims.hpp new file mode 100644 index 0000000000..8da01a39e6 --- /dev/null +++ b/tools/util/include/compat/dims.hpp @@ -0,0 +1,74 @@ +/*************************************************************************** + * + * Copyright (C) Codeplay Software Ltd. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * + * Part of the LLVM Project, under the Apache License v2.0 with LLVM + * Exceptions. See https://llvm.org/LICENSE.txt for license information. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Compat + * + * dims.hpp + * + * Description: + * dim3 functionality for Compat + **************************************************************************/ + +#pragma once + +#include + +#include + +namespace compat { + +class dim3 { +public: + unsigned int x, y, z; + + dim3(const sycl::range<3> &r) : x(r[2]), y(r[1]), z(r[0]) {} + + dim3(const sycl::range<2> &r) : x(r[1]), y(r[0]), z(1) {} + + dim3(const sycl::range<1> &r) : x(r[0]), y(1), z(1) {} + + constexpr dim3(unsigned int x = 1, unsigned int y = 1, unsigned int z = 1) + : x(x), y(y), z(z) {} + + constexpr size_t size() const { return x * y * z; } + + operator sycl::range<3>() const { return sycl::range<3>(z, y, x); } + operator sycl::range<2>() const { + if (z != 1) + throw std::invalid_argument( + "Attempting to convert a 3D dim3 into sycl::range<2>"); + return sycl::range<2>(y, x); + } + operator sycl::range<1>() const { + if (z != 1 || y != 1) + throw std::invalid_argument( + "Attempting to convert a 2D or 3D dim3 into sycl::range<1>"); + return sycl::range<1>(x); + } +}; // namespace dim3 + +inline dim3 operator*(const dim3 &a, const dim3 &b) { + return dim3{a.x * b.x, a.y * b.y, a.z * b.z}; +} + +inline dim3 operator+(const dim3 &a, const dim3 &b) { + return dim3{a.x + b.x, a.y + b.y, a.z + b.z}; +} + +inline dim3 operator-(const dim3 &a, const dim3 &b) { + return dim3{a.x - b.x, a.y - b.y, a.z - b.z}; +} + +} // namespace compat diff --git a/tools/util/include/compat/group_utils.hpp b/tools/util/include/compat/group_utils.hpp new file mode 100644 index 0000000000..a473b5f59b --- /dev/null +++ b/tools/util/include/compat/group_utils.hpp @@ -0,0 +1,1270 @@ +/*************************************************************************** + * + * Copyright (C) Codeplay Software Ltd. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * + * Part of the LLVM Project, under the Apache License v2.0 with LLVM + * Exceptions. See https://llvm.org/LICENSE.txt for license information. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SYCL compatibility extension + * + * group_utils.hpp + * + * Description: + * Group util functionality for the SYCL compatibility extension + **************************************************************************/ + +// The original source was under the license below: +//==---- group_utils.hpp ------------------*- C++ -*--------------------==// +// +// Copyright (C) Intel Corporation +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// See https://llvm.org/LICENSE.txt for license information. +// +//===------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include + +#include +#include + +namespace compat { +namespace group { +namespace detail { + +template +constexpr auto __reduce_over_group(_Args... __args) { + return sycl::reduce_over_group(__args...); +} + +template constexpr auto __group_broadcast(_Args... __args) { + return sycl::group_broadcast(__args...); +} + +template +constexpr auto __exclusive_scan_over_group(_Args... __args) { + return sycl::exclusive_scan_over_group(__args...); +} + +template +constexpr auto __inclusive_scan_over_group(_Args... __args) { + return sycl::inclusive_scan_over_group(__args...); +} + +template +__compat_inline__ T +exclusive_scan(const Item &item, T input, BinaryOperation binary_op, + GroupPrefixCallbackOperation &prefix_callback_op) { + T group_aggregate; + + T output = + detail::__exclusive_scan_over_group(item.get_group(), input, binary_op); + if (item.get_local_linear_id() == item.get_local_range().size() - 1) { + group_aggregate = binary_op(output, input); + } + + group_aggregate = detail::__group_broadcast( + item.get_group(), group_aggregate, item.get_local_range().size() - 1); + + T group_prefix = prefix_callback_op(group_aggregate); + if (item.get_local_linear_id() == 0) { + output = group_prefix; + } else { + output = binary_op(group_prefix, output); + } + + return output; +} + +typedef uint16_t digit_counter_type; +typedef uint32_t packed_counter_type; + +template struct log2 { + enum { VALUE = log2> 1), COUNT + 1>::VALUE }; +}; + +template struct log2 { + enum { VALUE = (1 << (COUNT - 1) < N) ? COUNT : COUNT - 1 }; +}; + +template class radix_rank { +public: + static size_t get_local_memory_size(size_t group_threads) { + return group_threads * PADDED_COUNTER_LANES * sizeof(packed_counter_type); + } + + radix_rank(uint8_t *local_memory) : _local_memory(local_memory) {} + + template + __compat_inline__ void + rank_keys(const Item &item, uint32_t (&keys)[VALUES_PER_THREAD], + int (&ranks)[VALUES_PER_THREAD], int current_bit, int num_bits) { + + digit_counter_type thread_prefixes[VALUES_PER_THREAD]; + digit_counter_type *digit_counters[VALUES_PER_THREAD]; + digit_counter_type *buffer = + reinterpret_cast(_local_memory); + auto g = item.get_group(); + reset_local_memory(item); + + sycl::group_barrier(g, sycl::memory_scope::work_group); + +#pragma unroll + for (int i = 0; i < VALUES_PER_THREAD; ++i) { + uint32_t digit = + ::compat::detail::bfe(keys[i], current_bit, num_bits); + uint32_t sub_counter = digit >> LOG_COUNTER_LANES; + uint32_t counter_lane = digit & (COUNTER_LANES - 1); + + if (DESCENDING) { + sub_counter = PACKING_RATIO - 1 - sub_counter; + counter_lane = COUNTER_LANES - 1 - counter_lane; + } + + digit_counters[i] = + &buffer[counter_lane * item.get_local_range().size() * PACKING_RATIO + + item.get_local_linear_id() * PACKING_RATIO + sub_counter]; + thread_prefixes[i] = *digit_counters[i]; + *digit_counters[i] = thread_prefixes[i] + 1; + } + + sycl::group_barrier(g, sycl::memory_scope::work_group); + + scan_counters(item); + + sycl::group_barrier(g, sycl::memory_scope::work_group); + + for (int i = 0; i < VALUES_PER_THREAD; ++i) { + ranks[i] = thread_prefixes[i] + *digit_counters[i]; + } + } + +private: + template + __compat_inline__ void reset_local_memory(const Item &item) { + packed_counter_type *ptr = + reinterpret_cast(_local_memory); + +#pragma unroll + for (int i = 0; i < PADDED_COUNTER_LANES; ++i) { + ptr[i * item.get_local_range().size() + item.get_local_linear_id()] = 0; + } + } + + template + __compat_inline__ packed_counter_type upsweep(const Item &item) { + packed_counter_type sum = 0; + packed_counter_type *ptr = + reinterpret_cast(_local_memory); + +#pragma unroll + for (int i = 0; i < PADDED_COUNTER_LANES; i++) { + cached_segment[i] = + ptr[item.get_local_linear_id() * PADDED_COUNTER_LANES + i]; + } + +#pragma unroll + for (int i = 0; i < PADDED_COUNTER_LANES; ++i) { + sum += cached_segment[i]; + } + + return sum; + } + + template + __compat_inline__ void + exclusive_downsweep(const Item &item, packed_counter_type raking_partial) { + packed_counter_type *ptr = + reinterpret_cast(_local_memory); + packed_counter_type sum = raking_partial; + +#pragma unroll + for (int i = 0; i < PADDED_COUNTER_LANES; ++i) { + packed_counter_type value = cached_segment[i]; + cached_segment[i] = sum; + sum += value; + } + +#pragma unroll + for (int i = 0; i < PADDED_COUNTER_LANES; ++i) { + ptr[item.get_local_linear_id() * PADDED_COUNTER_LANES + i] = + cached_segment[i]; + } + } + + struct prefix_callback { + __compat_inline__ packed_counter_type + operator()(packed_counter_type block_aggregate) { + packed_counter_type block_prefix = 0; + +#pragma unroll + for (int packed = 1; packed < PACKING_RATIO; packed++) { + block_prefix += block_aggregate + << (sizeof(digit_counter_type) * 8 * packed); + } + + return block_prefix; + } + }; + + template + __compat_inline__ void scan_counters(const Item &item) { + packed_counter_type raking_partial = upsweep(item); + + prefix_callback callback; + packed_counter_type exclusive_partial = exclusive_scan( + item, raking_partial, sycl::ext::oneapi::plus(), + callback); + + exclusive_downsweep(item, exclusive_partial); + } + +private: + static constexpr int PACKING_RATIO = + sizeof(packed_counter_type) / sizeof(digit_counter_type); + static constexpr int LOG_PACKING_RATIO = log2::VALUE; + static constexpr int LOG_COUNTER_LANES = RADIX_BITS - LOG_PACKING_RATIO; + static constexpr int COUNTER_LANES = 1 << LOG_COUNTER_LANES; + static constexpr int PADDED_COUNTER_LANES = COUNTER_LANES + 1; + + packed_counter_type cached_segment[PADDED_COUNTER_LANES]; + uint8_t *_local_memory; +}; + +template struct base_traits { + + static __compat_inline__ U twiddle_in(U key) { + throw std::runtime_error("Not implemented"); + } + static __compat_inline__ U twiddle_out(U key) { + throw std::runtime_error("Not implemented"); + } +}; + +template struct base_traits { + static __compat_inline__ U twiddle_in(U key) { return key; } + static __compat_inline__ U twiddle_out(U key) { return key; } +}; + +template struct base_traits { + static constexpr U HIGH_BIT = U(1) << ((sizeof(U) * 8) - 1); + static __compat_inline__ U twiddle_in(U key) { return key ^ HIGH_BIT; } + static __compat_inline__ U twiddle_out(U key) { return key ^ HIGH_BIT; } +}; + +template struct base_traits { + static constexpr U HIGH_BIT = U(1) << ((sizeof(U) * 8) - 1); + static __compat_inline__ U twiddle_in(U key) { + U mask = (key & HIGH_BIT) ? U(-1) : HIGH_BIT; + return key ^ mask; + } + static __compat_inline__ U twiddle_out(U key) { + U mask = (key & HIGH_BIT) ? HIGH_BIT : U(-1); + return key ^ mask; + } +}; + +template struct traits : base_traits {}; +template <> struct traits : base_traits {}; +template <> struct traits : base_traits {}; +template <> struct traits : base_traits {}; + +template struct power_of_two { + enum { VALUE = ((N & (N - 1)) == 0) }; +}; + +__compat_inline__ uint32_t shr_add(uint32_t x, uint32_t shift, + uint32_t addend) { + return (x >> shift) + addend; +} + +} // namespace detail + +/// Rearranging data partitioned across a work-group. +/// +/// \tparam T The type of the data elements. +/// \tparam ElementsPerWorkItem The number of data elements assigned to a +/// work-item. +template class exchange { +public: + static size_t get_local_memory_size(size_t group_threads) { + size_t padding_values = + (INSERT_PADDING) + ? ((group_threads * ElementsPerWorkItem) >> LOG_LOCAL_MEMORY_BANKS) + : 0; + return (group_threads * ElementsPerWorkItem + padding_values) * sizeof(T); + } + + exchange(uint8_t *local_memory) : _local_memory(local_memory) {} + + // TODO: Investigate if padding is required for performance, + // and if specializations are required for specific target hardware. + static size_t adjust_by_padding(size_t offset) { + + if constexpr (INSERT_PADDING) { + offset = detail::shr_add(offset, LOG_LOCAL_MEMORY_BANKS, offset); + } + return offset; + } + + struct blocked_offset { + template size_t operator()(Item item, size_t i) { + size_t offset = item.get_local_linear_id() * ElementsPerWorkItem + i; + return adjust_by_padding(offset); + } + }; + + struct striped_offset { + template size_t operator()(Item item, size_t i) { + size_t offset = i * item.get_local_range(2) * item.get_local_range(1) * + item.get_local_range(0) + + item.get_local_linear_id(); + return adjust_by_padding(offset); + } + }; + + template struct scatter_offset { + Iterator begin; + scatter_offset(const int (&ranks)[ElementsPerWorkItem]) { + begin = std::begin(ranks); + } + template size_t operator()(Item item, size_t i) const { + // iterator i is expected to be within bounds [0,VALUES_PER_THREAD) + return adjust_by_padding(begin[i]); + } + }; + + /// Inplace rearrange elements from blocked order to striped order. + /// + /// Suppose 512 integer data elements partitioned across 128 work-items, where + /// each work-item owns 4 ( \p ElementsPerWorkItem ) data elements and the + /// blocked \p input across the work-group is: + /// + /// {[0, 1, 2, 3], [4, 5, 6, 7], ..., [508, 509, 510, 511]}. + /// + /// The striped order output is: + /// + /// {[0, 128, 256, 384], [1, 129, 257, 385], ..., [127, 255, 383, 511]}. + /// + /// \tparam Item The work-item identifier type. + /// \param item The work-item identifier. + /// \param input The input data of each work-item. + template + __compat_inline__ void + blocked_to_striped(Item item, T (&input)[ElementsPerWorkItem]) { + striped_offset get_striped_offset; + blocked_offset get_blocked_offset; + helper_exchange(item, input, input, get_blocked_offset, get_striped_offset); + } + + /// Inplace rearrange elements from striped order to blocked order. + /// + /// Suppose 512 integer data elements partitioned across 128 work-items, where + /// each work-item owns 4 ( \p ElementsPerWorkItem ) data elements and the + /// striped \p input across the work-group is: + /// + /// { [0, 128, 256, 384], [1, 129, 257, 385], ..., [127, 255, 383, 511] }. + /// + /// The blocked order output is: + /// + /// { [0, 1, 2, 3], [4, 5, 6, 7], ..., [508, 509, 510, 511] }. + /// + /// \tparam Item The work-item identifier type. + /// \param item The work-item identifier. + /// \param input The input data of each work-item. + template + __compat_inline__ void + striped_to_blocked(Item item, T (&input)[ElementsPerWorkItem]) { + blocked_offset get_blocked_offset; + striped_offset get_striped_offset; + helper_exchange(item, input, input, get_striped_offset, get_blocked_offset); + } + + /// Rearrange elements from blocked order to striped order. + /// + /// Suppose 512 integer data elements partitioned across 128 work-items, where + /// each work-item owns 4 ( \p ElementsPerWorkItem ) data elements and the + /// blocked \p input across the work-group is: + /// + /// { [0, 1, 2, 3], [4, 5, 6, 7], ..., [508, 509, 510, 511] }. + /// + /// The striped order output is: + /// + /// { [0, 128, 256, 384], [1, 129, 257, 385], ..., [127, 255, 383, 511] }. + /// + /// \tparam Item The work-item identifier type. + /// \param item The work-item identifier. + /// \param input The input data of each work-item. + /// \param output The corresponding output data of each work-item. + template + __compat_inline__ void + blocked_to_striped(Item item, T (&input)[ElementsPerWorkItem], + T (&output)[ElementsPerWorkItem]) { + striped_offset get_striped_offset; + blocked_offset get_blocked_offset; + helper_exchange(item, input, output, get_blocked_offset, + get_striped_offset); + } + + /// Rearrange elements from striped order to blocked order. + /// + /// Suppose 512 integer data elements partitioned across 128 work-items, where + /// each work-item owns 4 ( \p ElementsPerWorkItem ) data elements and the + /// striped \p input across the work-group is: + /// + /// { [0, 128, 256, 384], [1, 129, 257, 385], ..., [127, 255, 383, 511] }. + /// + /// The blocked order output is: + /// + /// { [0, 1, 2, 3], [4, 5, 6, 7], ..., [508, 509, 510, 511] }. + /// + /// \tparam Item The work-item identifier type. + /// \param item The work-item identifier. + /// \param input The input data of each work-item. + /// \param output The corresponding output data of each work-item. + template + __compat_inline__ void + striped_to_blocked(Item item, T (&input)[ElementsPerWorkItem], + T (&output)[ElementsPerWorkItem]) { + blocked_offset get_blocked_offset; + striped_offset get_striped_offset; + helper_exchange(item, input, output, get_striped_offset, + get_blocked_offset); + } + + /// Inplace exchanges data items annotated by rank into blocked arrangement. + /// + /// Suppose 512 integer data elements partitioned across 128 work-items, where + /// each work-item owns 4 ( \p ElementsPerWorkItem ) data elements and the + /// striped \p input across the work-group is: + /// + /// { [0, 128, 256, 384], [1, 129, 257, 385], ..., [127, 255, 383, 511] }. + /// + /// The rank across the work-group is: + /// + /// { [0, 1, 2, 3], [4, 5, 6, 7], ..., [508, 509, 510, 511] }. + /// + /// The blocked order output is: + /// + /// { [0, 1, 2, 3], [4, 5, 6, 7], ..., [508, 509, 510, 511] }. + /// + /// \tparam Item The work-item identifier type. + /// \param item The work-item identifier. + /// \param input The input data of each work-item. + /// \param ranks The corresponding rank annotation of each work-item. + template + __compat_inline__ void + scatter_to_blocked(Item item, T (&input)[ElementsPerWorkItem], + int (&ranks)[ElementsPerWorkItem]) { + scatter_offset get_scatter_offset(ranks); + blocked_offset get_blocked_offset; + helper_exchange(item, input, input, get_scatter_offset, get_blocked_offset); + } + + /// Inplace exchanges data items annotated by rank into striped arrangement. + /// + /// Suppose 512 integer data elements partitioned across 128 work-items, where + /// each work-item owns 4 ( \p ElementsPerWorkItem ) data elements and the + /// blocked \p input across the work-group is: + /// + /// { [0, 1, 2, 3], [4, 5, 6, 7], ..., [508, 509, 510, 511] }. + /// + /// The rank across the work-group is: + /// + /// { [16, 20, 24, 28], [32, 36, 40, 44], ..., [499, 503, 507, 511] }. + /// + /// The striped order output of each work-item will be: + /// + /// { [0, 128, 256, 384], [1, 129, 257, 385], ..., [127, 255, 383, 511] }. + /// + /// \tparam Item The work-item identifier type. + /// \param item The work-item identifier. + /// \param input The input data of each work-item. + /// \param ranks The corresponding rank annotation of each work-item. + template + __compat_inline__ void + scatter_to_striped(Item item, T (&input)[ElementsPerWorkItem], + int (&ranks)[ElementsPerWorkItem]) { + scatter_offset get_scatter_offset(ranks); + striped_offset get_striped_offset; + helper_exchange(item, input, input, get_scatter_offset, get_striped_offset); + } + +private: + template + __compat_inline__ void + helper_exchange(Item item, T (&input)[ElementsPerWorkItem], + T (&output)[ElementsPerWorkItem], + offsetFunctorTypeFW &offset_functor_fw, + offsetFunctorTypeRV &offset_functor_rv) { + T *buffer = reinterpret_cast(_local_memory); +#pragma unroll + for (size_t i = 0; i < ElementsPerWorkItem; i++) { + size_t offset = offset_functor_fw(item, i); + buffer[offset] = input[i]; + } + sycl::group_barrier(item.get_group()); +#pragma unroll + for (size_t i = 0; i < ElementsPerWorkItem; i++) { + size_t offset = offset_functor_rv(item, i); + output[i] = buffer[offset]; + } + } + + static constexpr int LOG_LOCAL_MEMORY_BANKS = 4; + static constexpr bool INSERT_PADDING = + (ElementsPerWorkItem > 4) && + (detail::power_of_two::VALUE); + + uint8_t *_local_memory; +}; + +/// The work-group wide radix sort to sort integer data elements +/// assigned to all work-items in the work-group. +/// +/// \tparam T The type of the data elements. +/// \tparam ElementsPerWorkItem The number of data elements assigned to +/// a work-item. +/// \tparam RADIX_BITS The number of radix bits per digit place. +template +class group_radix_sort { + uint8_t *_local_memory; + +public: + group_radix_sort(uint8_t *local_memory) : _local_memory(local_memory) {} + + static size_t get_local_memory_size(size_t group_threads) { + size_t ranks_size = + detail::radix_rank::get_local_memory_size(group_threads); + size_t exchange_size = + exchange::get_local_memory_size(group_threads); + return sycl::max(ranks_size, exchange_size); + } + +private: + template + __compat_inline__ void + helper_sort(const Item &item, T (&keys)[ElementsPerWorkItem], + int begin_bit = 0, int end_bit = 8 * sizeof(T), + bool is_striped = false) { + + uint32_t(&unsigned_keys)[ElementsPerWorkItem] = + reinterpret_cast(keys); + +#pragma unroll + for (int i = 0; i < ElementsPerWorkItem; ++i) { + unsigned_keys[i] = detail::traits::twiddle_in(unsigned_keys[i]); + } + + for (int i = begin_bit; i < end_bit; i += RADIX_BITS) { + int pass_bits = sycl::min(RADIX_BITS, end_bit - begin_bit); + + int ranks[ElementsPerWorkItem]; + detail::radix_rank(_local_memory) + .template rank_keys(item, unsigned_keys, + ranks, i, pass_bits); + + sycl::group_barrier(item.get_group()); + + bool last_iter = i + RADIX_BITS >= end_bit; + if (last_iter && is_striped) { + exchange(_local_memory) + .scatter_to_striped(item, keys, ranks); + + } else { + exchange(_local_memory) + .scatter_to_blocked(item, keys, ranks); + } + + sycl::group_barrier(item.get_group()); + } + +#pragma unroll + for (int i = 0; i < ElementsPerWorkItem; ++i) { + unsigned_keys[i] = detail::traits::twiddle_out(unsigned_keys[i]); + } + } + +public: + /// Performs an ascending work-group wide radix sort over a blocked + /// arrangement of input elements. + /// + /// Suppose 512 integer data elements partitioned across 128 work-items, where + /// each work-item owns 4 ( \p ElementsPerWorkItem ) data elements and the + /// \p input across the work-group is: + /// + /// { [0,511,1,510], [2,509,3,508], [4,507,5,506], ..., [254,257,255,256] }. + /// + /// The ascending order output is: + /// + /// { [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }. + /// + /// \tparam Item The work-item identifier type. + /// \param item The work-item identifier. + /// \param input The input data of each work-item. + /// \param begin_bit The beginning (least-significant) bit index needed for + /// key comparison. + /// \param end_bit The past-the-end (most-significant) bit + /// index needed for key comparison. + template + __compat_inline__ void + sort(const Item &item, T (&input)[ElementsPerWorkItem], int begin_bit = 0, + int end_bit = 8 * sizeof(T)) { + helper_sort(item, input, begin_bit, end_bit); + } + + /// Performs an descending work-group wide radix sort over a blocked + /// arrangement of input elements. + /// + /// Suppose 512 integer data elements partitioned across 128 work-items, where + /// each work-item owns 4 ( \p ElementsPerWorkItem ) data elements and the + /// \p input across the work-group is: + /// + /// { [0,511,1,510], [2,509,3,508], [4,507,5,506], ..., [254,257,255,256] }. + /// + /// The descending order output is: + /// + /// { [511,510,509,508], [11,10,9,8], [7,6,5,4], ..., [3,2,1,0] }. + /// + /// \tparam Item The work-item identifier type. + /// \param item The work-item identifier. + /// \param input The input data of each work-item. + /// \param begin_bit The beginning (least-significant) bit index needed for + /// key comparison. + /// \param end_bit The past-the-end (most-significant) bit + /// index needed for key comparison. + template + __compat_inline__ void + sort_descending(const Item &item, T (&input)[ElementsPerWorkItem], + int begin_bit = 0, int end_bit = 8 * sizeof(T)) { + helper_sort(item, input, begin_bit, end_bit); + } + + /// Performs an ascending radix sort across a blocked arrangement of input + /// elements, leaving them in a striped arrangement. + /// + /// Suppose 512 integer data elements partitioned across 128 work-items, where + /// each work-item owns 4 ( \p ElementsPerWorkItem ) data elements and the + /// \p input across the work-group is: + /// + /// { [0,511,1,510], [2,509,3,508], [4,507,5,506], ..., [254,257,255,256] }. + /// + /// The corresponding output of each work-item will be: + /// + /// { [0,128,256,384], [1,129,257,385], [2,130,258,386], ..., + /// [127,255,383,511] }. + /// + /// \tparam Item The work-item identifier type. + /// \param item The work-item identifier. + /// \param input The input data of each work-item. + /// \param begin_bit The beginning (least-significant) bit index needed for + /// key comparison. + /// \param end_bit The past-the-end (most-significant) bit + /// index needed for key comparison. + template + __compat_inline__ void + sort_blocked_to_striped(const Item &item, T (&input)[ElementsPerWorkItem], + int begin_bit = 0, int end_bit = 8 * sizeof(T)) { + helper_sort(item, input, begin_bit, end_bit, + /*is_striped=*/true); + } + + /// Performs an descending radix sort across a blocked arrangement of input + /// elements, leaving them in a striped arrangement. + /// + /// Suppose 512 integer data elements partitioned across 128 work-items, where + /// each work-item owns 4 ( \p ElementsPerWorkItem ) data elements and the + /// \p input across the work-group is: + /// + /// { [0,511,1,510], [2,509,3,508], [4,507,5,506], ..., [254,257,255,256] }. + /// + /// The descending striped order output is: + /// + /// { [0,128,256,384], [1,129,257,385], [2,130,258,386], ..., + /// [127,255,383,511] }. + /// + /// \tparam Item The work-item identifier type. + /// \param item The work-item identifier. + /// \param input The input data of each work-item. + /// \param begin_bit The beginning (least-significant) bit index needed for + /// key comparison. + /// \param end_bit The past-the-end (most-significant) bit + /// index needed for key comparison. + template + __compat_inline__ void sort_descending_blocked_to_striped( + const Item &item, T (&input)[ElementsPerWorkItem], int begin_bit = 0, + int end_bit = 8 * sizeof(T)) { + helper_sort(item, input, begin_bit, end_bit, + /*is_striped=*/true); + } +}; + +/// Load linear segment items into block format across threads +/// Helper for Block Load +enum load_algorithm { + BLOCK_LOAD_DIRECT, + BLOCK_LOAD_STRIPED, +}; + +/// Load a linear segment of elements into a blocked arrangement across the +/// work-group. +/// +/// \tparam T The data type to load. +/// \tparam ElementsPerWorkItem The number of consecutive elements partitioned +/// onto each work-item. +/// \tparam InputIteratorT The random-access iterator type for input \iterator. +/// \tparam ItemT The sycl::nd_item index space class. +/// \param item The calling work-item. +/// \param input_iter The work-group's base input iterator for loading from. +/// \param data Data to load. +template +__compat_inline__ void load_direct_blocked(const ItemT &item, + InputIteratorT input_iter, + T (&data)[ElementsPerWorkItem]) { + size_t work_item_id = item.get_local_linear_id(); +#pragma unroll + for (size_t i = 0; i < ElementsPerWorkItem; i++) + data[i] = input_iter[(work_item_id * ElementsPerWorkItem) + i]; +} + +/// Load a linear segment of elements into a striped arrangement across the +/// work-group. +/// +/// \tparam T The data type to load. +/// \tparam ElementsPerWorkItem The number of consecutive elements partitioned +/// onto each work-item. +/// \tparam InputIteratorT The random-access iterator type for input \iterator. +/// \tparam ItemT The sycl::nd_item index space class. +/// \param item The calling work-item. +/// \param input_iter The work-group's base input iterator for loading from. +/// \param data Data to load. +template +__compat_inline__ void load_direct_striped(const ItemT &item, + InputIteratorT input_iter, + T (&data)[ElementsPerWorkItem]) { + size_t work_group_size = item.get_group().get_local_linear_range(); + size_t work_item_id = item.get_local_linear_id(); +#pragma unroll + for (size_t i = 0; i < ElementsPerWorkItem; i++) + data[i] = input_iter[work_item_id + i * work_group_size]; +} + +/// Load a linear segment of elements into a blocked arrangement across the +/// work-group, guarded by range. +/// +/// \tparam T The data type to load. +/// \tparam ElementsPerWorkItem The number of consecutive elements partitioned +/// onto each work-item. +/// \tparam InputIteratorT The random-access iterator type for input \iterator. +/// \tparam ItemT The sycl::nd_item index space class. +/// \param item The calling work-item. +/// \param input_iter The work-group's base input iterator for loading from. +/// \param data Data to load. +/// \param valid_items Number of valid items to load +template +__compat_inline__ void +load_direct_blocked(const ItemT &item, InputIteratorT input_iter, + T (&data)[ElementsPerWorkItem], int valid_items) { + size_t work_item_id = item.get_local_linear_id(); +#pragma unroll + for (size_t i = 0; i < ElementsPerWorkItem; i++) + if ((work_item_id * ElementsPerWorkItem) + i < valid_items) + data[i] = input_iter[(work_item_id * ElementsPerWorkItem) + i]; +} + +/// Load a linear segment of elements into a striped arrangement across the +/// work-group, guarded by range. +/// +/// \tparam T The data type to load. +/// \tparam ElementsPerWorkItem The number of consecutive elements partitioned +/// onto each work-item. +/// \tparam InputIteratorT The random-access iterator type for input \iterator. +/// \tparam ItemT The sycl::nd_item index space class. +/// \param item The calling work-item. +/// \param input_iter The work-group's base input iterator for loading from. +/// \param data Data to load. +/// \param valid_items Number of valid items to load +template +__compat_inline__ void +load_direct_striped(const ItemT &item, InputIteratorT input_iter, + T (&data)[ElementsPerWorkItem], int valid_items) { + size_t work_group_size = item.get_group().get_local_linear_range(); + size_t work_item_id = item.get_local_linear_id(); +#pragma unroll + for (size_t i = 0; i < ElementsPerWorkItem; i++) + if (work_item_id + (i * work_group_size) < valid_items) + data[i] = input_iter[work_item_id + i * work_group_size]; +} + +/// Store a blocked arrangement of items across a work-group into a linear +/// segment of items. +/// +/// \tparam T The data type to store. +/// \tparam ElementsPerWorkItem The number of consecutive elements partitioned +/// onto each work-item. +/// \tparam OutputIteratorT The random-access iterator type for output. +/// \iterator. +/// \tparam ItemT The sycl::nd_item index space class. +/// \param item The calling work-item. +/// \param output_iter The work-group's base output iterator for writing. +/// \param data Data to store. +template +__compat_inline__ void +store_direct_blocked(const ItemT &item, OutputIteratorT output_iter, + T (&data)[ElementsPerWorkItem]) { + size_t work_item_id = item.get_local_linear_id(); + OutputIteratorT work_item_iter = + output_iter + (work_item_id * ElementsPerWorkItem); +#pragma unroll + for (size_t i = 0; i < ElementsPerWorkItem; i++) + work_item_iter[i] = data[i]; +} + +/// Store a striped arrangement of items across a work-group into a linear +/// segment of items. +/// +/// \tparam T The data type to store. +/// \tparam ElementsPerWorkItem The number of consecutive elements partitioned +/// onto each work-item. +/// \tparam OutputIteratorT The random-access iterator type for output. +/// \iterator. +/// \tparam ItemT The sycl::nd_item index space class. +/// \param item The calling work-item. +/// \param output_iter The work-group's base output iterator for writing. +/// \param items Data to store. +template +__compat_inline__ void +store_direct_striped(const ItemT &item, OutputIteratorT output_iter, + T (&data)[ElementsPerWorkItem]) { + size_t work_group_size = item.get_group().get_local_linear_range(); + size_t work_item_id = item.get_local_linear_id(); + OutputIteratorT work_item_iter = output_iter + work_item_id; +#pragma unroll + for (size_t i = 0; i < ElementsPerWorkItem; i++) + work_item_iter[i * work_group_size] = data[i]; +} + +/// Store a blocked arrangement of items across a work-group into a linear +/// segment of items, guarded by range. +/// +/// \tparam T The data type to store. +/// \tparam ElementsPerWorkItem The number of consecutive elements partitioned +/// onto each work-item. +/// \tparam OutputIteratorT The random-access iterator type for output. +/// \iterator. +/// \tparam ItemT The sycl::nd_item index space class. +/// \param item The calling work-item. +/// \param output_iter The work-group's base output iterator for writing. +/// \param data Data to store. +/// \param valid_items Number of valid items to load +template +__compat_inline__ void +store_direct_blocked(const ItemT &item, OutputIteratorT output_iter, + T (&data)[ElementsPerWorkItem], size_t valid_items) { + size_t work_item_id = item.get_local_linear_id(); + OutputIteratorT work_item_iter = + output_iter + (work_item_id * ElementsPerWorkItem); +#pragma unroll + for (size_t i = 0; i < ElementsPerWorkItem; i++) + if (i + (work_item_id * ElementsPerWorkItem) < valid_items) + work_item_iter[i] = data[i]; +} + +/// Store a striped arrangement of items across a work-group into a linear +/// segment of items, guarded by range. +/// +/// \tparam T The data type to store. +/// \tparam ElementsPerWorkItem The number of consecutive elements partitioned +/// onto each work-item. +/// \tparam OutputIteratorT The random-access iterator type for output. +/// \iterator. +/// \tparam ItemT The sycl::nd_item index space class. +/// \param item The calling work-item. +/// \param output_iter The work-group's base output iterator for writing. +/// \param items Data to store. +/// \param valid_items Number of valid items to load +template +__compat_inline__ void +store_direct_striped(const ItemT &item, OutputIteratorT output_iter, + T (&data)[ElementsPerWorkItem], size_t valid_items) { + size_t work_group_size = item.get_group().get_local_linear_range(); + size_t work_item_id = item.get_local_linear_id(); + OutputIteratorT work_item_iter = output_iter + work_item_id; +#pragma unroll + for (size_t i = 0; i < ElementsPerWorkItem; i++) + if ((i * work_group_size) + work_item_id < valid_items) + work_item_iter[i * work_group_size] = data[i]; +} + +/// Enumerates alternative algorithms for compat::group::group_load to read +/// a linear segment of data from memory into a blocked arrangement across a +/// work-group. +enum class group_load_algorithm { + /// A blocked arrangement of data is read directly from memory. + blocked, + + /// A striped arrangement of data is read directly from memory. + striped +}; + +/// Provide methods for loading a linear segment of items from memory into a +/// blocked arrangement across a work-group. +/// +/// \tparam T The input data type. +/// \tparam ElementsPerWorkItem The number of data elements assigned to a +/// work-item. +/// \tparam LoadAlgorithm The data movement strategy, default is blocked. +template +class group_load { +public: + static size_t get_local_memory_size([[maybe_unused]] size_t work_group_size) { + return 0; + } + group_load(uint8_t *) {} + + /// Load a linear segment of items from memory. + /// + /// Suppose 512 integer data elements partitioned across 128 work-items, where + /// each work-item owns 4 ( \p ElementsPerWorkItem ) data elements and the + /// \p input across the work-group is: + /// + /// 1, 2, 3, 4, 5, 6, 7, ..., 508, 509, 510, 511. + /// + /// The blocked order \p data of each work-item will be: + /// + /// {[0,1,2,3], [4,5,6,7], ..., [508,509,510,511]}. + /// + /// The striped order \p output of each work-item will be: + /// + /// {[0,128,256,384], [1,129,257,385], ..., [127,255,383,511]}. + /// + /// \tparam ItemT The sycl::nd_item index space class. + /// \tparam InputIteratorT The random-access iterator type for input + /// \iterator. + /// \param item The work-item identifier. + /// \param input_iter The work-group's base input iterator for loading from. + /// \param data The data to load. + template + __compat_inline__ void load(const ItemT &item, InputIteratorT input_iter, + T (&data)[ElementsPerWorkItem]) { + if constexpr (LoadAlgorithm == group_load_algorithm::blocked) { + load_direct_blocked( + item, input_iter, data); + } else if constexpr (LoadAlgorithm == group_load_algorithm::striped) { + load_direct_striped( + item, input_iter, data); + } + } + + /// Load a linear segment of items from memory, guarded by range. + /// + /// Suppose 512 integer data elements partitioned across 128 work-items, where + /// each work-item owns 4 ( \p ElementsPerWorkItem ) data elements and + /// valid_items is 5, the \p input across the work-group is: + /// + /// 0, 1, 2, 3, 4, 5, 6, 7, ..., 508, 509, 510, 511. + /// + /// The blocked order \p data of each work-item will be: + /// + /// {[0,1,2,3], [4,?,?,?], ..., [?,?,?,?]}. + /// + /// The striped order \p output of each work-item will be: + /// + /// {[0,?,?,?], [1,?,?,?], [2,?,?,?], [3,?,?,?] ..., [?,?,?,?]}. + /// + /// \tparam ItemT The sycl::nd_item index space class. + /// \tparam InputIteratorT The random-access iterator type for input + /// \iterator. + /// \param item The work-item identifier. + /// \param input_iter The work-group's base input iterator for loading from. + /// \param data The data to load. + /// \param valid_items Number of valid items to load + template + __compat_inline__ void load(const ItemT &item, InputIteratorT input_iter, + T (&data)[ElementsPerWorkItem], + int valid_items) { + if constexpr (LoadAlgorithm == group_load_algorithm::blocked) { + load_direct_blocked( + item, input_iter, data, valid_items); + } else if constexpr (LoadAlgorithm == group_load_algorithm::striped) { + load_direct_striped( + item, input_iter, data, valid_items); + } + } +}; + +/// Enumerates alternative algorithms for compat::group::group_load to write +/// a blocked arrangement of items across a work-group to a linear segment of +/// memory. +enum class group_store_algorithm { + /// A blocked arrangement of data is written directly to memory. + blocked, + + /// A striped arrangement of data is written directly to memory. + striped, +}; + +/// Provide methods for writing a blocked arrangement of elements partitioned +/// across a work-group to a linear segment of memory. +/// +/// \tparam T The output data type. +/// \tparam ElementsPerWorkItem The number of data elements assigned to a +/// work-item. +/// \tparam StoreAlgorithm The data movement strategy, default is blocked. +template +class group_store { +public: + static size_t get_local_memory_size([[maybe_unused]] size_t work_group_size) { + return 0; + } + group_store(uint8_t *) {} + + /// Store items into a linear segment of memory. + /// + /// Suppose 512 integer data elements partitioned across 128 work-items, where + /// each work-item owns 4 ( \p ElementsPerWorkItem ) data elements and the + /// \p input across the work-group is: + /// + /// {[0,1,2,3], [4,5,6,7], ..., [508,509,510,511]}. + /// + /// The blocked order \p output will be: + /// + /// 1, 2, 3, 4, 5, 6, 7, ..., 508, 509, 510, 511. + /// + /// The striped order \p output will be: + /// + /// 0, 128, 256, 384, 1, 129, 257, 385, ..., 127, 255, 383, 511. + /// + /// \tparam ItemT The sycl::nd_item index space class. + /// \tparam OutputIteratorT The random-access iterator type for \p output + /// iterator. + /// \param item The work-item identifier. + /// \param input The input data of each work-item. + /// \param data The data to store. + template + __compat_inline__ void store(const ItemT &item, + OutputIteratorT output_iter, + T (&data)[ElementsPerWorkItem]) { + if constexpr (StoreAlgorithm == group_store_algorithm::blocked) { + store_direct_blocked( + item, output_iter, data); + } else if constexpr (StoreAlgorithm == group_store_algorithm::striped) { + store_direct_striped( + item, output_iter, data); + } + } + + /// Store items into a linear segment of memory, guarded by range. + /// + /// Suppose 512 integer data elements partitioned across 128 work-items, where + /// each work-item owns 4 ( \p ElementsPerWorkItem ) data elements and + /// \p valid_items is 5, the \p output across the work-group is: + /// + /// {[0,0,0,0], [0,0,0,0], ..., [0,0,0,0]}. + /// + /// The blocked order \p output will be: + /// + /// 0, 1, 2, 3, 4, 5, 0, 0, ..., 0, 0, 0, 0. + /// + /// The striped order \p output will be: + /// + /// 0, 4, 8, 12, 16, 0, 0, 0, ..., 0, 0, 0, 0. + /// + /// \tparam ItemT The sycl::nd_item index space class. + /// \tparam OutputIteratorT The random-access iterator type for \p output + /// iterator. + /// \param item The work-item identifier. + /// \param input The input data of each work-item. + /// \param data The data to store. + /// \param valid_items Number of valid items to load + template + __compat_inline__ void + store(const ItemT &item, OutputIteratorT output_iter, + T (&data)[ElementsPerWorkItem], size_t valid_items) { + if constexpr (StoreAlgorithm == group_store_algorithm::blocked) { + store_direct_blocked( + item, output_iter, data, valid_items); + } else if constexpr (StoreAlgorithm == group_store_algorithm::striped) { + store_direct_striped( + item, output_iter, data, valid_items); + } + } +}; + +/// The work-group wide shuffle operations that allow work-items to exchange +/// data elements with other work-items within the same work-group. +/// +/// \tparam T The type of the data elements. +/// \tparam group_dim_0 The first dimension size of the work-group. +/// \tparam group_dim_1 The second dimension size of the work-group. +/// \tparam group_dim_2 The third dimension size of the work-group. +template +class group_shuffle { + T *_local_memory = nullptr; + static constexpr size_t group_work_items = + group_dim_0 * group_dim_1 * group_dim_2; + +public: + static constexpr size_t get_local_memory_size(size_t work_group_size) { + return sizeof(T) * work_group_size; + } + group_shuffle(uint8_t *local_memory) : _local_memory((T *)local_memory) {} + + /// Selects a value from a work-item at a given distance in the work-group + /// and stores the value in the output. + /// + /// \tparam ItemT The work-item identifier type. + /// \param item The work-item identifier. + /// \param input The input from the calling work-item. + /// \param output The output where the selected data will be stored. + /// \param distance The distance of work-items to look ahead or behind in the + /// work-group. + template + __compat_inline__ void select(const ItemT &item, T input, T &output, + int distance = 1) { + auto g = item.get_group(); + size_t id = g.get_local_linear_id(); + _local_memory[id] = input; + + sycl::group_barrier(g, sycl::memory_scope::work_group); + + const int target_id = static_cast(id) + distance; + if ((target_id >= 0) && (target_id < group_work_items)) { + output = _local_memory[static_cast(target_id)]; + } + } + /// Selects a value from a work-item at a given distance in the work-group + /// and stores the value in the output, using a wrapped index to handle + /// overflow. + /// + /// \tparam ItemT The work-item identifier type. + /// \param item The work-item identifier. + /// \param input The input data to be selected. + /// \param output The output where the selected data will be stored. + /// \param distance The number of work-items to look ahead in the + /// work-group. + template + __compat_inline__ void select2(const ItemT &item, T input, T &output, + unsigned int distance = 1) { + auto g = item.get_group(); + size_t id = g.get_local_linear_id(); + _local_memory[id] = input; + + sycl::group_barrier(g, sycl::memory_scope::work_group); + + unsigned int offset = id + distance; + if (offset >= group_work_items) + offset -= group_work_items; + + output = _local_memory[offset]; + } + /// Performs a shuffle operation to move data to the right across the + /// work-items, shifting elements in a work-item array by one position to the + /// right. + /// + /// \tparam ElementsPerWorkItem The number of data elements per work-item. + /// \tparam ItemT The work-item identifier type. + /// \param item The work-item identifier. + /// \param input The input data to be shuffled. + /// \param output The array that will store the shuffle result. + template + __compat_inline__ void shuffle_right(const ItemT &item, + T (&input)[ElementsPerWorkItem], + T (&output)[ElementsPerWorkItem]) { + auto g = item.get_group(); + size_t id = g.get_local_linear_id(); + _local_memory[id] = input[ElementsPerWorkItem - 1]; + + sycl::group_barrier(g, sycl::memory_scope::work_group); + +#pragma unroll + for (int index = ElementsPerWorkItem - 1; index > 0; --index) + output[index] = input[index - 1]; + + if (id > 0) + output[0] = _local_memory[id - 1]; + } + /// Performs a shuffle operation to move data to the right across the + /// work-items, storing the suffix of the group after the shuffle operation. + /// + /// \tparam ElementsPerWorkItem The number of data elements per work-item. + /// \tparam ItemT The work-item identifier type. + /// \param item The work-item identifier. + /// \param input The input data to be shuffled. + /// \param output The array that will store the shuffle result. + /// \param group_suffix The suffix of the group after the shuffle. + template + __compat_inline__ void + shuffle_right(const ItemT &item, T (&input)[ElementsPerWorkItem], + T (&output)[ElementsPerWorkItem], T &group_suffix) { + shuffle_right(item, input, output); + group_suffix = _local_memory[group_work_items - 1]; + } + /// Performs a shuffle operation to move data to the left across the + /// work-items, shifting elements in a work-item array by one position to the + /// left. + /// + /// \tparam ElementsPerWorkItem The number of data elements per work-item. + /// \tparam ItemT The work-item identifier type. + /// \param item The work-item identifier. + /// \param input The input data to be shuffled. + /// \param output The array that will store the shuffle result. + template + __compat_inline__ void shuffle_left(const ItemT &item, + T (&input)[ElementsPerWorkItem], + T (&output)[ElementsPerWorkItem]) { + auto g = item.get_group(); + size_t id = g.get_local_linear_id(); + _local_memory[id] = input[0]; + + sycl::group_barrier(g, sycl::memory_scope::work_group); + +#pragma unroll + for (int index = 0; index < ElementsPerWorkItem - 1; index++) + output[index] = input[index + 1]; + + if (id < group_work_items - 1) + output[ElementsPerWorkItem - 1] = _local_memory[id + 1]; + } + /// Performs a shuffle operation to move data to the left across the + /// work-items, storing the prefix of the group before the shuffle operation. + /// + /// \tparam ElementsPerWorkItem The number of data elements per work-item. + /// \tparam ItemT The work-item identifier type. + /// \param item The work-item identifier. + /// \param input The input data to be shuffled. + /// \param output The array that will store the shuffle result. + /// \param group_prefix The prefix of the group before the shuffle. + template + __compat_inline__ void + shuffle_left(const ItemT &item, T (&input)[ElementsPerWorkItem], + T (&output)[ElementsPerWorkItem], T &group_prefix) { + shuffle_left(item, input, output); + group_prefix = _local_memory[0]; + } +}; +} // namespace group +} // namespace compat diff --git a/tools/util/include/compat/id_query.hpp b/tools/util/include/compat/id_query.hpp new file mode 100644 index 0000000000..120b1a5b29 --- /dev/null +++ b/tools/util/include/compat/id_query.hpp @@ -0,0 +1,71 @@ +/*************************************************************************** + * + * Copyright (C) Codeplay Software Ltd. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * + * Part of the LLVM Project, under the Apache License v2.0 with LLVM + * Exceptions. See https://llvm.org/LICENSE.txt for license information. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SYCL compatibility extension + * + * id_query.hpp + * + * Description: + * id_query functionality for the SYCL compatibility extension + **************************************************************************/ + +#pragma once + +#include +#include + +namespace compat { + +using sycl::ext::oneapi::this_work_item::get_nd_item; + +inline void wg_barrier() { get_nd_item<3>().barrier(); } + +namespace local_id { +inline size_t x() { return get_nd_item<3>().get_local_id(2); } +inline size_t y() { return get_nd_item<3>().get_local_id(1); } +inline size_t z() { return get_nd_item<3>().get_local_id(0); } +} // namespace local_id + +namespace local_range { +inline size_t x() { return get_nd_item<3>().get_local_range(2); } +inline size_t y() { return get_nd_item<3>().get_local_range(1); } +inline size_t z() { return get_nd_item<3>().get_local_range(0); } +} // namespace local_range + +namespace work_group_id { +inline size_t x() { return get_nd_item<3>().get_group(2); } +inline size_t y() { return get_nd_item<3>().get_group(1); } +inline size_t z() { return get_nd_item<3>().get_group(0); } +} // namespace work_group_id + +namespace work_group_range { +inline size_t x() { return get_nd_item<3>().get_group_range(2); } +inline size_t y() { return get_nd_item<3>().get_group_range(1); } +inline size_t z() { return get_nd_item<3>().get_group_range(0); } +} // namespace work_group_range + +namespace global_range { +inline size_t x() { return get_nd_item<3>().get_global_range(2); } +inline size_t y() { return get_nd_item<3>().get_global_range(1); } +inline size_t z() { return get_nd_item<3>().get_global_range(0); } +} // namespace global_range + +namespace global_id { +inline size_t x() { return get_nd_item<3>().get_global_id(2); } +inline size_t y() { return get_nd_item<3>().get_global_id(1); } +inline size_t z() { return get_nd_item<3>().get_global_id(0); } +} // namespace global_id + +} // namespace compat diff --git a/tools/util/include/compat/kernel.hpp b/tools/util/include/compat/kernel.hpp new file mode 100644 index 0000000000..b9851f1d10 --- /dev/null +++ b/tools/util/include/compat/kernel.hpp @@ -0,0 +1,470 @@ +/*************************************************************************** + * + * Copyright (C) Codeplay Software Ltd. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * + * Part of the LLVM Project, under the Apache License v2.0 with LLVM + * Exceptions. See https://llvm.org/LICENSE.txt for license information. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SYCL compatibility extension + * + * kernel.hpp + * + * Description: + * kernel functionality for the SYCL compatibility extension. + **************************************************************************/ + +// The original source was under the license below: +//==---- kernel.hpp -------------------------------*- C++ -*----------------==// +// +// Copyright (C) Intel Corporation +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// See https://llvm.org/LICENSE.txt for license information. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#ifdef _WIN32 +#include +#include +#else +#include +#endif + +#if defined(__has_include) && __has_include() +#include +#elif defined(__has_include) && __has_include() +#include +#else +#error "SYCLomatic runtime requires C++ filesystem support" +#endif + +#include +#include + +#include +#include +#include +#include + +namespace compat { + +typedef void (*kernel_functor)(sycl::queue &, const sycl::nd_range<3> &, + unsigned int, void **, void **); + +struct kernel_function_info { + int max_work_group_size = 0; +}; + +static inline void get_kernel_function_info(kernel_function_info *kernel_info, + const void *function) { + kernel_info->max_work_group_size = + detail::dev_mgr::instance() + .current_device() + .get_info(); +} + +static inline kernel_function_info +get_kernel_function_info(const void *function) { + kernel_function_info kernel_info; + kernel_info.max_work_group_size = + detail::dev_mgr::instance() + .current_device() + .get_info(); + return kernel_info; +} + +namespace detail { + +#if defined(__has_include) && __has_include() +namespace fs = std::filesystem; +#else +namespace fs = std::experimental::filesystem; +#endif + +/// Write data to temporary file and return absolute path to temporary file. +/// Temporary file is created in a temporary directory both of which have random +/// names with only the user having access permissions. Only one temporary file +/// will be created in the temporary directory. +static inline fs::path write_data_to_file(char const *const data, size_t size) { + std::error_code ec; + + if (sizeof(size_t) >= sizeof(std::streamsize) && + size > (std::numeric_limits::max)()) + throw std::runtime_error("[Compat] data file too large"); + + // random number generator + std::random_device dev; + std::mt19937 prng(dev()); + std::uniform_int_distribution rand(0); + + // find temporary directory + auto tmp_dir = fs::temp_directory_path(ec); + if (ec) + throw std::runtime_error("[Compat] could not find temporary directory"); + + // create private directory + std::stringstream directory; + fs::path directory_path; + constexpr int max_attempts = 5; + int i; + + for (i = 0; i < max_attempts; i++) { + directory << std::hex << rand(prng); + directory_path = tmp_dir / directory.str(); + if (fs::create_directory(directory_path)) { + break; + } + } + if (i == max_attempts) + throw std::runtime_error("[Compat] could not create directory"); + + // only allow owner permissions to private directory + fs::permissions(directory_path, fs::perms::owner_all, ec); + if (ec) + throw std::runtime_error( + "[Compat] could not set directory permissions"); + + // random filename in private directory + std::stringstream filename; + filename << std::hex << rand(prng); +#ifdef _WIN32 + auto filepath = directory_path / (filename.str() + ".dll"); +#else + auto filepath = directory_path / filename.str(); +#endif + + // write data to temporary file + auto outfile = std::ofstream(filepath, std::ios::out | std::ios::binary); + if (outfile) { + // only allow program to write file + fs::permissions(filepath, fs::perms::owner_write, ec); + if (ec) + throw std::runtime_error("[Compat] could not set permissions"); + + outfile.write(data, size); + if (!outfile.good()) + throw std::runtime_error("[Compat] could not write data"); + outfile.close(); + + // only allow program to read/execute file + fs::permissions(filepath, fs::perms::owner_read | fs::perms::owner_exec, + ec); + if (ec) + throw std::runtime_error("[Compat] could not set permissions"); + } else + throw std::runtime_error("[Compat] could not write data"); + + // check temporary file contents + auto infile = std::ifstream(filepath, std::ios::in | std::ios::binary); + if (infile) { + bool mismatch = false; + size_t cnt = 0; + + while (1) { + char c; + infile.get(c); + if (infile.eof()) + break; + if (c != data[cnt++]) + mismatch = true; + } + if (cnt != size || mismatch) + throw std::runtime_error( + "[Compat] file contents not written correctly"); + } else + throw std::runtime_error("[Compat] could not validate file"); + + if (!filepath.is_absolute()) + throw std::runtime_error("[Compat] temporary filepath is not absolute"); + + return filepath; +} + +static inline uint16_t extract16(unsigned char const *const ptr) { + uint16_t ret = 0; + + ret |= static_cast(ptr[0]) << 0; + ret |= static_cast(ptr[1]) << 8; + + return (ret); +} + +static inline uint32_t extract32(unsigned char const *const ptr) { + uint32_t ret = 0; + + ret |= static_cast(ptr[0]) << 0; + ret |= static_cast(ptr[1]) << 8; + ret |= static_cast(ptr[2]) << 16; + ret |= static_cast(ptr[3]) << 24; + + return (ret); +} + +static inline uint64_t extract64(unsigned char const *const ptr) { + uint64_t ret = 0; + + ret |= static_cast(ptr[0]) << 0; + ret |= static_cast(ptr[1]) << 8; + ret |= static_cast(ptr[2]) << 16; + ret |= static_cast(ptr[3]) << 24; + ret |= static_cast(ptr[4]) << 32; + ret |= static_cast(ptr[5]) << 40; + ret |= static_cast(ptr[6]) << 48; + ret |= static_cast(ptr[7]) << 56; + + return (ret); +} + +static inline uint64_t get_lib_size(char const *const blob) { +#ifdef _WIN32 + /////////////////////////////////////////////////////////////////////// + // Analyze DOS stub + unsigned char const *const ublob = + reinterpret_cast(blob); + if (ublob[0] != 0x4d || ublob[1] != 0x5a) { + throw std::runtime_error("[Compat] blob is not a Windows DLL."); + } + uint32_t pe_header_offset = extract32(ublob + 0x3c); + + /////////////////////////////////////////////////////////////////////// + // Ananlyze PE-header + unsigned char const *const pe_header = ublob + pe_header_offset; + + // signature + uint32_t pe_signature = extract32(pe_header + 0); + if (pe_signature != 0x00004550) { + throw std::runtime_error( + "[Compat] PE-header signature is not 0x00004550"); + } + + // machine + uint16_t machine = extract16(pe_header + 4); + if (machine != 0x8664) { + throw std::runtime_error("[Compat] only DLLs for x64 supported"); + } + + // number of sections + uint16_t number_of_sections = extract16(pe_header + 6); + + // sizeof optional header + uint16_t sizeof_optional_header = extract16(pe_header + 20); + + // magic + uint16_t magic = extract16(pe_header + 24); + if (magic != 0x10b && magic != 0x20b) { + throw std::runtime_error("[Compat] MAGIC is not 0x010b or 0x020b"); + } + + /////////////////////////////////////////////////////////////////////// + // Analyze tail of optional header + constexpr int coff_header_size = 24; + + unsigned char const *const tail_of_optional_header = + pe_header + coff_header_size + sizeof_optional_header; + if (extract64(tail_of_optional_header - 8) != 0) { + throw std::runtime_error("Optional header not zero-padded"); + } + + /////////////////////////////////////////////////////////////////////// + // Analyze last section header + constexpr int section_header_size = 40; + unsigned char const *const last_section_header = + tail_of_optional_header + section_header_size * (number_of_sections - 1); + + uint32_t sizeof_raw_data = extract32(last_section_header + 16); + uint32_t pointer_to_raw_data = extract32(last_section_header + 20); + + return sizeof_raw_data + pointer_to_raw_data; +#else + if (blob[0] != 0x7F || blob[1] != 'E' || blob[2] != 'L' || blob[3] != 'F') + throw std::runtime_error("[Compat] blob is not in ELF format"); + + if (blob[4] != 0x02) + throw std::runtime_error("[Compat] only 64-bit headers are supported"); + + if (blob[5] != 0x01) + throw std::runtime_error( + "[Compat] only little-endian headers are supported"); + + unsigned char const *const ublob = + reinterpret_cast(blob); + uint64_t e_shoff = extract64(ublob + 0x28); + uint16_t e_shentsize = extract16(ublob + 0x3A); + uint16_t e_shnum = extract16(ublob + 0x3C); + + return e_shoff + (e_shentsize * e_shnum); +#endif +} + +#ifdef _WIN32 +class path_lib_record { +public: + void operator=(const path_lib_record &) = delete; + ~path_lib_record() { + for (auto entry : lib_to_path) { + FreeLibrary(static_cast(entry.first)); + fs::permissions(entry.second, fs::perms::owner_all); + fs::remove_all(entry.second.remove_filename()); + } + } + static void record_lib_path(fs::path path, void *library) { + lib_to_path[library] = path; + } + static void remove_lib(void *library) { + auto path = lib_to_path[library]; + std::error_code ec; + + FreeLibrary(static_cast(library)); + fs::permissions(path, fs::perms::owner_all); + if (fs::remove_all(path.remove_filename(), ec) != 2 || ec) + // one directory and one temporary file should have been deleted + throw std::runtime_error("[Compat] directory delete failed"); + + lib_to_path.erase(library); + } + +private: + static inline std::unordered_map lib_to_path; +}; +#endif + +} // namespace detail + +class kernel_library { +public: + constexpr kernel_library() : ptr{nullptr} {} + constexpr kernel_library(void *ptr) : ptr{ptr} {} + + operator void *() const { return ptr; } + +private: + void *ptr; +#ifdef _WIN32 + static inline detail::path_lib_record single_instance_to_trigger_destructor; +#endif +}; + +namespace detail { + +static inline kernel_library load_dl_from_data(char const *const data, + size_t size) { + fs::path filename = write_data_to_file(data, size); +#ifdef _WIN32 + void *so = LoadLibraryW(filename.wstring().c_str()); +#else + void *so = dlopen(filename.c_str(), RTLD_LAZY); +#endif + if (so == nullptr) + throw std::runtime_error("[Compat] failed to load kernel library"); + +#ifdef _WIN32 + detail::path_lib_record::record_lib_path(filename, so); +#else + std::error_code ec; + + // Windows DLL cannot be deleted while in use + if (fs::remove_all(filename.remove_filename(), ec) != 2 || ec) + // one directory and one temporary file should have been deleted + throw std::runtime_error("[Compat] directory delete failed"); +#endif + + return so; +} + +} // namespace detail + +/// Load kernel library and return a handle to use the library. +/// \param [in] name The name of the library. +static inline kernel_library load_kernel_library(const std::string &name) { + std::ifstream ifs; + ifs.open(name, std::ios::in | std::ios::binary); + + std::stringstream buffer; + buffer << ifs.rdbuf(); + + const std::string buffer_string = buffer.str(); + return detail::load_dl_from_data(buffer_string.c_str(), buffer_string.size()); +} + +/// Load kernel library whose image is alreay in memory and return a handle to +/// use the library. +/// \param [in] image A pointer to the image in memory. +static inline kernel_library load_kernel_library_mem(char const *const image) { + const size_t size = detail::get_lib_size(image); + + return detail::load_dl_from_data(image, size); +} + +/// Unload kernel library. +/// \param [in,out] library Handle to the library to be closed. +static inline void unload_kernel_library(const kernel_library &library) { +#ifdef _WIN32 + detail::path_lib_record::remove_lib(library); +#else + dlclose(library); +#endif +} + +class kernel_function { +public: + constexpr kernel_function() : ptr{nullptr} {} + constexpr kernel_function(kernel_functor ptr) : ptr{ptr} {} + + operator void *() const { return ((void *)ptr); } + + void operator()(sycl::queue &q, const sycl::nd_range<3> &range, + unsigned int local_mem_size, void **args, void **extra) { + ptr(q, range, local_mem_size, args, extra); + } + +private: + kernel_functor ptr; +}; + +/// Find kernel function in a kernel library and return its address. +/// \param [in] library Handle to the kernel library. +/// \param [in] name Name of the kernel function. +static inline kernel_function get_kernel_function(kernel_library &library, + const std::string &name) { +#ifdef _WIN32 + kernel_functor fn = reinterpret_cast( + GetProcAddress(static_cast(static_cast(library)), + (name + std::string("_wrapper")).c_str())); +#else + kernel_functor fn = reinterpret_cast( + dlsym(library, (name + std::string("_wrapper")).c_str())); +#endif + if (fn == nullptr) + throw std::runtime_error("[Compat] failed to get function"); + return fn; +} + +/// Invoke a kernel function. +/// \param [in] function kernel function. +/// \param [in] queue SYCL queue used to execute kernel +/// \param [in] group_range SYCL group range +/// \param [in] local_range SYCL local range +/// \param [in] local_mem_size The size of local memory required by the kernel +/// function. +/// \param [in] kernel_params Array of pointers to kernel arguments. +/// \param [in] extra Extra arguments. +static inline void invoke_kernel_function(kernel_function &function, + sycl::queue &queue, + sycl::range<3> group_range, + sycl::range<3> local_range, + unsigned int local_mem_size, + void **kernel_params, void **extra) { + function(queue, sycl::nd_range<3>(group_range * local_range, local_range), + local_mem_size, kernel_params, extra); +} + +} // namespace compat diff --git a/tools/util/include/compat/launch.hpp b/tools/util/include/compat/launch.hpp new file mode 100644 index 0000000000..0e0d84fa15 --- /dev/null +++ b/tools/util/include/compat/launch.hpp @@ -0,0 +1,165 @@ +/*************************************************************************** + * + * Copyright (C) Codeplay Software Ltd. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * + * Part of the LLVM Project, under the Apache License v2.0 with LLVM + * Exceptions. See https://llvm.org/LICENSE.txt for license information. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SYCL compatibility extension + * + * launch.hpp + * + * Description: + * launch functionality for the SYCL compatibility extension + **************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace compat { + +namespace detail { + +template +constexpr size_t getArgumentCount(R (*f)(Types...)) { + return sizeof...(Types); +} + +template +sycl::nd_range<3> transform_nd_range(const sycl::nd_range &range) { + sycl::range global_range = range.get_global_range(); + sycl::range local_range = range.get_local_range(); + if constexpr (Dim == 3) { + return range; + } else if constexpr (Dim == 2) { + return sycl::nd_range<3>{{1, global_range[0], global_range[1]}, + {1, local_range[0], local_range[1]}}; + } + return sycl::nd_range<3>{{1, 1, global_range[0]}, {1, 1, local_range[0]}}; +} + +template +std::enable_if_t, sycl::event> +launch(const sycl::nd_range<3> &range, sycl::queue q, Args... args) { + static_assert(detail::getArgumentCount(F) == sizeof...(args), + "Wrong number of arguments to SYCL kernel"); + static_assert( + std::is_same, void>::value, + "SYCL kernels should return void"); + + return q.parallel_for( + range, [=](sycl::nd_item<3>) { [[clang::always_inline]] F(args...); }); +} + +} // namespace detail + +template +inline sycl::nd_range compute_nd_range(sycl::range global_size_in, + sycl::range work_group_size) { + + if (global_size_in.size() == 0 || work_group_size.size() == 0) { + throw std::invalid_argument("Global or local size is zero!"); + } + for (size_t i = 0; i < Dim; ++i) { + if (global_size_in[i] < work_group_size[i]) + throw std::invalid_argument("Work group size larger than global size"); + } + + auto global_size = + ((global_size_in + work_group_size - 1) / work_group_size) * + work_group_size; + return {global_size, work_group_size}; +} + +inline sycl::nd_range<1> compute_nd_range(int global_size_in, + int work_group_size) { + return compute_nd_range<1>(global_size_in, work_group_size); +} + +template +std::enable_if_t, sycl::event> +launch(const sycl::nd_range &range, sycl::queue q, Args... args) { + return detail::launch(detail::transform_nd_range(range), q, args...); +} + +template +std::enable_if_t, sycl::event> +launch(const sycl::nd_range &range, Args... args) { + return launch(range, get_default_queue(), args...); +} + +// Alternative launch through dim3 objects +template +std::enable_if_t, sycl::event> +launch(const dim3 &grid, const dim3 &threads, sycl::queue q, Args... args) { + return launch(sycl::nd_range<3>{grid * threads, threads}, q, args...); +} + +template +std::enable_if_t, sycl::event> +launch(const dim3 &grid, const dim3 &threads, Args... args) { + return launch(grid, threads, get_default_queue(), args...); +} + +} // namespace compat + +namespace compat::experimental { + +namespace detail { + +template +sycl::event launch(LaunchPolicy launch_policy, sycl::queue q, Args... args) { + static_assert(compat::args_compatible, + "Mismatch between device function signature and supplied " + "arguments. Have you correctly handled local memory/char*?"); + + sycl_exp::launch_config config(launch_policy.get_range(), + launch_policy.get_launch_properties()); + + return sycl_exp::submit_with_event(q, [&](sycl::handler &cgh) { + auto KernelFunctor = build_kernel_functor(cgh, launch_policy, args...); + if constexpr (compat::detail::is_range_v< + typename LaunchPolicy::RangeT>) { + parallel_for(cgh, config, KernelFunctor); + } else { + static_assert( + compat::detail::is_nd_range_v); + nd_launch(cgh, config, KernelFunctor); + } + }); +} + +} + + +template +sycl::event launch(LaunchPolicy launch_policy, sycl::queue q, Args... args) { + static_assert(detail::is_launch_policy_v); + return detail::launch(launch_policy, q, args...); +} + +template +sycl::event launch(LaunchPolicy launch_policy, Args... args) { + static_assert(detail::is_launch_policy_v); + return launch(launch_policy, get_default_queue(), args...); +} + +} // namespace compat::experimental diff --git a/tools/util/include/compat/launch_policy.hpp b/tools/util/include/compat/launch_policy.hpp new file mode 100644 index 0000000000..b7b7a01da2 --- /dev/null +++ b/tools/util/include/compat/launch_policy.hpp @@ -0,0 +1,273 @@ +/*************************************************************************** + * + * Copyright (C) Codeplay Software Ltd. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * + * Part of the LLVM Project, under the Apache License v2.0 with LLVM + * Exceptions. See https://llvm.org/LICENSE.txt for license information. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SYCL compatibility extension + * + * launch.hpp + * + * Description: + * launch functionality for the SYCL compatibility extension + **************************************************************************/ + +#pragma once + +#include "sycl/ext/oneapi/experimental/enqueue_functions.hpp" +#include "sycl/ext/oneapi/properties/properties.hpp" +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace compat { +namespace experimental { + +namespace sycl_exp = sycl::ext::oneapi::experimental; + +// Wrapper for kernel sycl_exp::properties +template struct kernel_properties { + static_assert(sycl_exp::is_property_list_v); + using Props = Properties; + + template + kernel_properties(Props... properties) : props{properties...} {} + + template + kernel_properties(sycl_exp::properties properties) + : props{properties} {} + + Properties props; +}; + +template ::value, void>> +kernel_properties(Props... props) + -> kernel_properties; + +template +kernel_properties(sycl_exp::properties props) + -> kernel_properties>; + +// Wrapper for launch sycl_exp::properties +template struct launch_properties { + static_assert(sycl_exp::is_property_list_v); + using Props = Properties; + + template + launch_properties(Props... properties) : props{properties...} {} + + template + launch_properties(sycl_exp::properties properties) + : props{properties} {} + + Properties props; +}; + +template ::value, void>> +launch_properties(Props... props) + -> launch_properties; + +template +launch_properties(sycl_exp::properties props) + -> launch_properties>; + +// Wrapper for local memory size +struct local_mem_size { + local_mem_size(size_t size = 0) : size{size} {}; + size_t size; +}; + +// launch_policy is constructed by the user & passed to `compat_exp::launch` +template +class launch_policy { + static_assert(sycl_exp::is_property_list_v); + static_assert(sycl_exp::is_property_list_v); + static_assert(compat::detail::is_range_or_nd_range_v); + static_assert(compat::detail::is_nd_range_v || !LocalMem, + "sycl::range kernel launches are incompatible with local " + "memory usage!"); + +public: + using KPropsT = KProps; + using LPropsT = LProps; + using RangeT = Range; + static constexpr bool HasLocalMem = LocalMem; + +private: + launch_policy() = default; + + template + launch_policy(Ts... ts) + : _kernel_properties{detail::property_getter< + kernel_properties, kernel_properties, std::tuple>()( + std::tuple(ts...))}, + _launch_properties{detail::property_getter< + launch_properties, launch_properties, std::tuple>()( + std::tuple(ts...))}, + _local_mem_size{ + detail::local_mem_getter>()( + std::tuple(ts...))} { + check_variadic_args(ts...); + } + + template void check_variadic_args(Ts...) { + static_assert( + std::conjunction_v, + detail::is_launch_properties, + detail::is_local_mem_size>...>, + "Received an unexpected argument to ctor. Did you forget to wrap " + "in " + "compat::kernel_properties, launch_properties, local_mem_size?"); + } + +public: + template + launch_policy(Range range, Ts... ts) : launch_policy(ts...) { + _range = range; + check_variadic_args(ts...); + } + + template + launch_policy(dim3 global_range, Ts... ts) : launch_policy(ts...) { + _range = Range{global_range}; + check_variadic_args(ts...); + } + + template + launch_policy(dim3 global_range, dim3 local_range, Ts... ts) + : launch_policy(ts...) { + _range = Range{global_range * local_range, local_range}; + check_variadic_args(ts...); + } + + KProps get_kernel_properties() { return _kernel_properties.props; } + LProps get_launch_properties() { return _launch_properties.props; } + size_t get_local_mem_size() { return _local_mem_size.size; } + Range get_range() { return _range; } + +private: + Range _range; + kernel_properties _kernel_properties; + launch_properties _launch_properties; + local_mem_size _local_mem_size; +}; + +// Deduction guides for launch_policy +template +launch_policy(Range, Ts...) -> launch_policy< + Range, detail::properties_or_empty, + detail::properties_or_empty, + detail::has_type>::value>; + +template +launch_policy(sycl::range, sycl::range, Ts...) -> launch_policy< + sycl::nd_range, detail::properties_or_empty, + detail::properties_or_empty, + detail::has_type>::value>; + +template +launch_policy(dim3, Ts...) -> launch_policy< + sycl::range<3>, detail::properties_or_empty, + detail::properties_or_empty, + detail::has_type>::value>; + +template +launch_policy(dim3, dim3, Ts...) -> launch_policy< + sycl::nd_range<3>, detail::properties_or_empty, + detail::properties_or_empty, + detail::has_type>::value>; + +namespace detail { +// Custom std::apply helpers to enable inlining +template +__compat_inline__ constexpr void apply_expand(F &&f, Tuple &&t, + std::index_sequence) { + [[clang::always_inline]] std::forward(f)( + get(std::forward(t))...); +} + +template +__compat_inline__ constexpr void apply_helper(F &&f, Tuple &&t) { + apply_expand( + std::forward(f), std::forward(t), + std::make_index_sequence>>{}); +} + +template +struct KernelFunctor { + KernelFunctor(KProps kernel_props, Args... args) + : _kernel_properties{kernel_props}, + _argument_tuple(std::make_tuple(args...)) {} + + KernelFunctor(KProps kernel_props, sycl::local_accessor local_acc, + Args... args) + : _kernel_properties{kernel_props}, _local_acc{local_acc}, + _argument_tuple(std::make_tuple(args...)) {} + + auto get(sycl_exp::properties_tag) const { return _kernel_properties; } + + __compat_inline__ void + operator()(compat::detail::range_to_item_t) const { + if constexpr (HasLocalMem) { + char *local_mem_ptr = static_cast( + _local_acc.template get_multi_ptr() + .get()); + apply_helper( + [lmem_ptr = local_mem_ptr](auto &&...args) { + [[clang::always_inline]] F(args..., lmem_ptr); + }, + _argument_tuple); + } else { + apply_helper([](auto &&...args) { [[clang::always_inline]] F(args...); }, + _argument_tuple); + } + } + + KProps _kernel_properties; + std::tuple _argument_tuple; + std::conditional_t, std::monostate> + _local_acc; // monostate for empty type +}; + +//==================================================================== +// This helper function avoids 2 nested `if constexpr` in detail::launch +template +auto build_kernel_functor(sycl::handler &cgh, LaunchPolicy launch_policy, + Args... args) + -> KernelFunctor { + if constexpr (LaunchPolicy::HasLocalMem) { + sycl::local_accessor local_memory( + launch_policy.get_local_mem_size(), cgh); + return KernelFunctor( + launch_policy.get_kernel_properties(), local_memory, args...); + } else { + return KernelFunctor( + launch_policy.get_kernel_properties(), args...); + } +} + +} // namespace detail +} // namespace experimental +} // namespace compat diff --git a/tools/util/include/compat/math.hpp b/tools/util/include/compat/math.hpp new file mode 100644 index 0000000000..536a17a005 --- /dev/null +++ b/tools/util/include/compat/math.hpp @@ -0,0 +1,2386 @@ +/*************************************************************************** + * + * Copyright (C) Codeplay Software Ltd. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * + * Part of the LLVM Project, under the Apache License v2.0 with LLVM + * Exceptions. See https://llvm.org/LICENSE.txt for license information. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SYCL compatibility extension + * + * math.hpp + * + * Description: + * math utilities for the SYCL compatibility extension. + **************************************************************************/ + +// The original source was under the license below: +//==---- math.hpp ---------------------------------*- C++ -*----------------==// +// +// Copyright (C) Intel Corporation +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// See https://llvm.org/LICENSE.txt for license information. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include + +// TODO(compat-lib-reviewers): this should not be required +#ifndef SYCL_EXT_ONEAPI_COMPLEX +#define SYCL_EXT_ONEAPI_COMPLEX +#endif + +#ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS +#include +#endif +#include +#include + +namespace compat { +namespace detail { + +namespace complex_namespace = sycl::ext::oneapi::experimental; + +template +using complex_type = detail::complex_namespace::complex; + +template +constexpr bool is_int32_type = std::is_same_v, int32_t> || + std::is_same_v, uint32_t>; + +// Helper constexpr bool to avoid ugly macros where possible +#ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS +constexpr bool support_bfloat16_math = true; +#else +constexpr bool support_bfloat16_math = false; +#endif + +template +inline ValueT clamp(ValueT val, ValueT min_val, ValueT max_val) { + return sycl::clamp(val, min_val, max_val); +} +#ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS +// TODO(compat-lib-reviewers): Follow the process to add this (& other math +// fns) to the bfloat16 math function extension. If added, remove this +// functionality from the header. +template <> +inline sycl::ext::oneapi::bfloat16 clamp(sycl::ext::oneapi::bfloat16 val, + sycl::ext::oneapi::bfloat16 min_val, + sycl::ext::oneapi::bfloat16 max_val) { + if (val < min_val) + return min_val; + if (val > max_val) + return max_val; + return val; +} + +template +inline std::enable_if_t, + sycl::vec> +clamp(sycl::vec val, sycl::vec min_val, + sycl::vec max_val) { + return [&val, &min_val, &max_val](std::integer_sequence) { + return sycl::vec{ + clamp(val[I], min_val[I], max_val[I])...}; + }(std::make_integer_sequence{}); +} + +template +inline std::enable_if_t, + sycl::marray> +clamp(sycl::marray val, sycl::marray min_val, + sycl::marray max_val) { + return [&val, &min_val, &max_val](std::index_sequence) { + return sycl::marray{ + clamp(val[I], min_val[I], max_val[I])...}; + }(std::make_index_sequence{}); +} +#endif + +template +class vectorized_binary { +public: + inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op) { + VecT v4; + for (size_t i = 0; i < v4.size(); ++i) { + v4[i] = binary_op(a[i], b[i]); + } + return v4; + } +}; + +template +class vectorized_binary< + VecT, BinaryOperation, + std::void_t>> { +public: + inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op) { + return binary_op(a, b).template as(); + } +}; + +/// Extend the 'val' to 'bit' size, zero extend for unsigned int and signed +/// extend for signed int. Returns a signed integer type. +template +inline auto zero_or_signed_extend(ValueT val, unsigned bit) { + static_assert(std::is_integral_v); + if constexpr (sizeof(ValueT) == 4) { + assert(bit < 64 && + "When extending int32 value, bit must be smaller than 64."); + if constexpr (std::is_signed_v) + return int64_t(val) << (64 - bit) >> (64 - bit); + else + return int64_t(val); + } else if constexpr (sizeof(ValueT) == 2) { + assert(bit < 32 && + "When extending int16 value, bit must be smaller than 32."); + if constexpr (std::is_signed_v) + return int32_t(val) << (32 - bit) >> (32 - bit); + else + return int32_t(val); + } else if constexpr (sizeof(ValueT) == 1) { + assert(bit < 16 && + "When extending int8 value, bit must be smaller than 16."); + if constexpr (std::is_signed_v) + return int16_t(val) << (16 - bit) >> (16 - bit); + else + return int16_t(val); + } else { + static_assert(sizeof(ValueT) == 8); + assert(bit < 64 && "Cannot extend int64 value."); + return static_cast(val); + } +} + +template +inline constexpr RetT extend_binary(AT a, BT b, BinaryOperation binary_op) { + const int64_t extend_a = zero_or_signed_extend(a, 33); + const int64_t extend_b = zero_or_signed_extend(b, 33); + const int64_t ret = binary_op(extend_a, extend_b); + if constexpr (needSat) + return detail::clamp(ret, std::numeric_limits::min(), + std::numeric_limits::max()); + return ret; +} + +template +inline constexpr RetT extend_binary(AT a, BT b, CT c, + BinaryOperation1 binary_op, + BinaryOperation2 second_op) { + const int64_t extend_a = zero_or_signed_extend(a, 33); + const int64_t extend_b = zero_or_signed_extend(b, 33); + int64_t extend_temp = + zero_or_signed_extend(binary_op(extend_a, extend_b), 34); + if constexpr (needSat) + extend_temp = + detail::clamp(extend_temp, std::numeric_limits::min(), + std::numeric_limits::max()); + const int64_t extend_c = zero_or_signed_extend(c, 33); + return second_op(extend_temp, extend_c); +} + +template sycl::vec extract_and_extend2(T a) { + sycl::vec ret; + sycl::vec va{a}; + using IntT = std::conditional_t, int16_t, uint16_t>; + auto v = va.template as>(); + ret[0] = zero_or_signed_extend(v[0], 17); + ret[1] = zero_or_signed_extend(v[1], 17); + return ret; +} + +template sycl::vec extract_and_extend4(T a) { + sycl::vec ret; + sycl::vec va{a}; + using IntT = std::conditional_t, int8_t, uint8_t>; + auto v = va.template as>(); + ret[0] = zero_or_signed_extend(v[0], 9); + ret[1] = zero_or_signed_extend(v[1], 9); + ret[2] = zero_or_signed_extend(v[2], 9); + ret[3] = zero_or_signed_extend(v[3], 9); + return ret; +} + +template +inline constexpr RetT extend_vbinary2(AT a, BT b, RetT c, + BinaryOperation binary_op) { + static_assert(is_int32_type && is_int32_type && is_int32_type); + sycl::vec extend_a = extract_and_extend2(a); + sycl::vec extend_b = extract_and_extend2(b); + sycl::vec temp{binary_op(extend_a[0], extend_b[0]), + binary_op(extend_a[1], extend_b[1])}; + using IntT = std::conditional_t, int16_t, uint16_t>; + + if constexpr (NeedSat) { + int32_t min_val = 0, max_val = 0; + min_val = std::numeric_limits::min(); + max_val = std::numeric_limits::max(); + temp = detail::clamp(temp, sycl::vec(min_val), + sycl::vec(max_val)); + } + if constexpr (NeedAdd) { + return temp[0] + temp[1] + c; + } + return sycl::vec{temp[0], temp[1]}.template as>(); +} + +template +inline constexpr RetT extend_vbinary4(AT a, BT b, RetT c, + BinaryOperation binary_op) { + static_assert(is_int32_type && is_int32_type && is_int32_type); + sycl::vec extend_a = extract_and_extend4(a); + sycl::vec extend_b = extract_and_extend4(b); + sycl::vec temp{ + binary_op(extend_a[0], extend_b[0]), binary_op(extend_a[1], extend_b[1]), + binary_op(extend_a[2], extend_b[2]), binary_op(extend_a[3], extend_b[3])}; + using IntT = std::conditional_t, int8_t, uint8_t>; + + if constexpr (NeedSat) { + int16_t min_val = 0, max_val = 0; + min_val = std::numeric_limits::min(); + max_val = std::numeric_limits::max(); + temp = detail::clamp(temp, sycl::vec(min_val), + sycl::vec(max_val)); + } + if constexpr (NeedAdd) { + return temp[0] + temp[1] + temp[2] + temp[3] + c; + } + + return sycl::vec{temp[0], temp[1], temp[2], temp[3]} + .template as>(); +} + +template inline bool isnan(const ValueT a) { + if constexpr (std::is_same_v) { + static_assert(detail::support_bfloat16_math); + return sycl::ext::oneapi::experimental::isnan(a); + } else { + return sycl::isnan(a); + } +} + +// FIXME(compat-lib-reviewers): move bfe outside detail once perf is +// improved & semantics understood +/// Bitfield-extract. +/// +/// \tparam T The type of \param source value, must be an integer. +/// \param source The source value to extracting. +/// \param bit_start The position to start extracting. +/// \param num_bits The number of bits to extracting. +template +inline T bfe(const T source, const uint32_t bit_start, + const uint32_t num_bits) { + static_assert(std::is_unsigned_v); + // FIXME(compat-lib-reviewers): This ternary was added to catch a case + // which may be undefined anyway. Consider that we are losing perf here. + const T mask = + num_bits >= std::numeric_limits::digits * sizeof(T) + ? static_cast(-1) + : ((static_cast(1) << num_bits) - 1); + return (source >> bit_start) & mask; +} + +} // namespace detail + +/// Bitfield-extract with boundary checking. +/// +/// Extract bit field from \param source and return the zero or sign-extended +/// result. Source \param bit_start gives the bit field starting bit position, +/// and source \param num_bits gives the bit field length in bits. +/// +/// The result is padded with the sign bit of the extracted field. If `num_bits` +/// is zero, the result is zero. If the start position is beyond the msb of the +/// input, the result is filled with the replicated sign bit of the extracted +/// field. +/// +/// \tparam T The type of \param source value, must be an integer. +/// \param source The source value to extracting. +/// \param bit_start The position to start extracting. +/// \param num_bits The number of bits to extracting. +template +inline T bfe_safe(const T source, const uint32_t bit_start, + const uint32_t num_bits) { + static_assert(std::is_integral_v); +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + int32_t res{}; + asm volatile("bfe.s32 %0, %1, %2, %3;" + : "=r"(res) + : "r"((int32_t)source), "r"(bit_start), "r"(num_bits)); + return res; + } else if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v) { + uint32_t res{}; + asm volatile("bfe.u32 %0, %1, %2, %3;" + : "=r"(res) + : "r"((uint32_t)source), "r"(bit_start), "r"(num_bits)); + return res; + } else if constexpr (std::is_same_v) { + T res{}; + asm volatile("bfe.s64 %0, %1, %2, %3;" + : "=l"(res) + : "l"(source), "r"(bit_start), "r"(num_bits)); + return res; + } else if constexpr (std::is_same_v) { + T res{}; + asm volatile("bfe.u64 %0, %1, %2, %3;" + : "=l"(res) + : "l"(source), "r"(bit_start), "r"(num_bits)); + return res; + } +#endif + const uint32_t bit_width = + std::numeric_limits::digits * sizeof(T); + const uint32_t pos = std::min(bit_start, bit_width); + const uint32_t len = std::min(pos + num_bits, bit_width) - pos; + if constexpr (std::is_signed_v) { + // FIXME(compat-lib-reviewers): As above, catching a case whose result + // is undefined and likely losing perf. + const T mask = len >= bit_width ? T{-1} : static_cast((T{1} << len) - 1); + + // Find the sign-bit, the result is padded with the sign bit of the + // extracted field. + // Note if requested num_bits==0, we return zero via sign_bit=0 + const uint32_t sign_bit_pos = std::min(pos + len - 1, bit_width - 1); + const T sign_bit = num_bits != 0 && ((source >> sign_bit_pos) & 1); + const T sign_bit_padding = (-sign_bit & ~mask); + return ((source >> pos) & mask) | sign_bit_padding; + } else { + return compat::detail::bfe(source, pos, len); + } +} + +namespace detail { +// FIXME(compat-lib-reviewers): move bfi outside detail once perf is +// improved & semantics understood +/// Bitfield-insert. +/// +/// \tparam T The type of \param x and \param y , must be an unsigned integer. +/// \param x The source of the bitfield. +/// \param y The source where bitfield is inserted. +/// \param bit_start The position to start insertion. +/// \param num_bits The number of bits to insertion. +template +inline T bfi(const T x, const T y, const uint32_t bit_start, + const uint32_t num_bits) { + static_assert(std::is_unsigned_v); + constexpr unsigned bit_width = + std::numeric_limits::digits * sizeof(T); + + // if bit_start > bit_width || len == 0, should return y. + const T ignore_bfi = static_cast(bit_start > bit_width || num_bits == 0); + T extract_bitfield_mask = (static_cast(~T{0}) >> (bit_width - num_bits)) + << bit_start; + T clean_bitfield_mask = ~extract_bitfield_mask; + return (y & (-ignore_bfi | clean_bitfield_mask)) | + (~-ignore_bfi & ((x << bit_start) & extract_bitfield_mask)); +} +} // namespace detail + +/// Bitfield-insert with boundary checking. +/// +/// Align and insert a bit field from \param x into \param y . Source \param +/// bit_start gives the starting bit position for the insertion, and source +/// \param num_bits gives the bit field length in bits. +/// +/// \tparam T The type of \param x and \param y , must be an unsigned integer. +/// \param x The source of the bitfield. +/// \param y The source where bitfield is inserted. +/// \param bit_start The position to start insertion. +/// \param num_bits The number of bits to insertion. +template +inline T bfi_safe(const T x, const T y, const uint32_t bit_start, + const uint32_t num_bits) { + static_assert(std::is_unsigned_v); +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + uint32_t res{}; + asm volatile("bfi.b32 %0, %1, %2, %3, %4;" + : "=r"(res) + : "r"((uint32_t)x), "r"((uint32_t)y), "r"(bit_start), + "r"(num_bits)); + return res; + } else if constexpr (std::is_same_v) { + uint64_t res{}; + asm volatile("bfi.b64 %0, %1, %2, %3, %4;" + : "=l"(res) + : "l"(x), "l"(y), "r"(bit_start), "r"(num_bits)); + return res; + } +#endif + constexpr unsigned bit_width = + std::numeric_limits::digits * sizeof(T); + const uint32_t pos = std::min(bit_start, bit_width); + const uint32_t len = std::min(pos + num_bits, bit_width) - pos; + return compat::detail::bfi(x, y, pos, len); +} + +/// Emulated function for __funnelshift_l +inline unsigned int funnelshift_l(unsigned int low, unsigned int high, + unsigned int shift) { + return (sycl::upsample(high, low) << (shift & 31U)) >> 32; +} + +/// Emulated function for __funnelshift_lc +inline unsigned int funnelshift_lc(unsigned int low, unsigned int high, + unsigned int shift) { + return (sycl::upsample(high, low) << sycl::min(shift, 32U)) >> 32; +} + +/// Emulated function for __funnelshift_r +inline unsigned int funnelshift_r(unsigned int low, unsigned int high, + unsigned int shift) { + return (sycl::upsample(high, low) >> (shift & 31U)) & 0xFFFFFFFF; +} + +/// Emulated function for __funnelshift_rc +inline unsigned int funnelshift_rc(unsigned int low, unsigned int high, + unsigned int shift) { + return (sycl::upsample(high, low) >> sycl::min(shift, 32U)) & 0xFFFFFFFF; +} + +/// Compute fast_length for variable-length array +/// \param [in] a The array +/// \param [in] len Length of the array +/// \returns The computed fast_length +inline float fast_length(const float *a, int len) { + switch (len) { + case 1: + return sycl::fast_length(a[0]); + case 2: + return sycl::fast_length(sycl::float2(a[0], a[1])); + case 3: + return sycl::fast_length(sycl::float3(a[0], a[1], a[2])); + case 4: + return sycl::fast_length(sycl::float4(a[0], a[1], a[2], a[3])); + case 0: + return 0; + default: + float f = 0; + for (int i = 0; i < len; ++i) + f += a[i] * a[i]; + return sycl::sqrt(f); + } +} + +/// Calculate the square root of the input array. +/// \param [in] a The array pointer +/// \param [in] len Length of the array +/// \returns The square root +template +inline ValueT length(const ValueT *a, const int len) { + switch (len) { + case 1: + return a[0]; + case 2: + return sycl::length(sycl::vec(a[0], a[1])); + case 3: + return sycl::length(sycl::vec(a[0], a[1], a[2])); + case 4: + return sycl::length(sycl::vec(a[0], a[1], a[2], a[3])); + default: + ValueT ret = 0; + for (int i = 0; i < len; ++i) + ret += a[i] * a[i]; + return sycl::sqrt(ret); + } +} + +/// Performs comparison. +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op functor that implements the binary operation +/// \returns the comparison result +template +inline std::enable_if_t< + std::is_same_v, bool>, + bool> +compare(const ValueT a, const ValueT b, const BinaryOperation binary_op) { + return binary_op(a, b); +} +template +inline std::enable_if_t< + std::is_same_v, ValueT, ValueT>, + bool>, + bool> +compare(const ValueT a, const ValueT b, const std::not_equal_to<> binary_op) { + return !detail::isnan(a) && !detail::isnan(b) && binary_op(a, b); +} + +/// Performs 2 element comparison. +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op functor that implements the binary operation +/// \returns the comparison result +template +inline std::enable_if_t +compare(const ValueT a, const ValueT b, const BinaryOperation binary_op) { + return {compare(a[0], b[0], binary_op), compare(a[1], b[1], binary_op)}; +} + +/// Performs unordered comparison. +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op functor that implements the binary operation +/// \returns the comparison result +template +inline std::enable_if_t< + std::is_same_v, bool>, + bool> +unordered_compare(const ValueT a, const ValueT b, + const BinaryOperation binary_op) { + return detail::isnan(a) || detail::isnan(b) || binary_op(a, b); +} + +/// Performs 2 element unordered comparison. +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op functor that implements the binary operation +/// \returns the comparison result +template +inline std::enable_if_t +unordered_compare(const ValueT a, const ValueT b, + const BinaryOperation binary_op) { + return {unordered_compare(a[0], b[0], binary_op), + unordered_compare(a[1], b[1], binary_op)}; +} + +/// Performs 2 element comparison and return true if both results are true. +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op functor that implements the binary operation +/// \returns the comparison result +template +inline std::enable_if_t +compare_both(const ValueT a, const ValueT b, const BinaryOperation binary_op) { + return compare(a[0], b[0], binary_op) && compare(a[1], b[1], binary_op); +} + +/// Performs 2 element unordered comparison and return true if both results are +/// true. +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op functor that implements the binary operation +/// \returns the comparison result +template +inline std::enable_if_t +unordered_compare_both(const ValueT a, const ValueT b, + const BinaryOperation binary_op) { + return unordered_compare(a[0], b[0], binary_op) && + unordered_compare(a[1], b[1], binary_op); +} + +/// Performs 2 elements comparison, compare result of each element is 0 (false) +/// or 0xffff (true), returns an unsigned int by composing compare result of two +/// elements. +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op functor that implements the binary operation +/// \returns the comparison result +template +inline std::enable_if_t +compare_mask(const ValueT a, const ValueT b, const BinaryOperation binary_op) { + // Since compare returns 0 or 1, -compare will be 0x00000000 or 0xFFFFFFFF + return ((-compare(a[0], b[0], binary_op)) & 0xFFFF) | + ((-compare(a[1], b[1], binary_op)) << 16u); +} + +/// Performs 2 elements unordered comparison, compare result of each element is +/// 0 (false) or 0xffff (true), returns an unsigned int by composing compare +/// result of two elements. +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op functor that implements the binary operation +/// \returns the comparison result +template +inline std::enable_if_t +unordered_compare_mask(const ValueT a, const ValueT b, + const BinaryOperation binary_op) { + return ((-unordered_compare(a[0], b[0], binary_op)) & 0xFFFF) | + ((-unordered_compare(a[1], b[1], binary_op)) << 16); +} + +/// Compute vectorized max for two values, with each value treated as a vector +/// type \p S +/// \param [in] S The type of the vector +/// \param [in] T The type of the original values +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The vectorized max of the two values +template inline T vectorized_max(T a, T b) { + sycl::vec v0{a}, v1{b}; + auto v2 = v0.template as(); + auto v3 = v1.template as(); + v2 = sycl::max(v2, v3); + v0 = v2.template as>(); + return v0; +} + +/// Compute vectorized min for two values, with each value treated as a vector +/// type \p S +/// \param [in] S The type of the vector +/// \param [in] T The type of the original values +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The vectorized min of the two values +template inline T vectorized_min(T a, T b) { + sycl::vec v0{a}, v1{b}; + auto v2 = v0.template as(); + auto v3 = v1.template as(); + v2 = sycl::min(v2, v3); + v0 = v2.template as>(); + return v0; +} + +/// Compute vectorized unary operation for a value, with the value treated as a +/// vector type \p VecT. +/// \tparam [in] VecT The type of the vector +/// \tparam [in] UnaryOperation The unary operation class +/// \param [in] a The input value +/// \returns The vectorized unary operation value of the input value +template +inline unsigned vectorized_unary(unsigned a, const UnaryOperation unary_op) { + sycl::vec v0{a}; + auto v1 = v0.as(); + auto v2 = unary_op(v1); + v0 = v2.template as>(); + return v0; +} + +/// Compute vectorized absolute difference for two values without modulo +/// overflow, with each value treated as a vector type \p VecT. +/// \tparam [in] VecT The type of the vector +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The vectorized absolute difference of the two values +template +inline unsigned vectorized_sum_abs_diff(unsigned a, unsigned b) { + sycl::vec v0{a}, v1{b}; + // Need convert element type to wider signed type to avoid overflow. + auto v2 = v0.as().template convert(); + auto v3 = v1.as().template convert(); + auto v4 = sycl::abs_diff(v2, v3); + unsigned sum = 0; + for (size_t i = 0; i < v4.size(); ++i) { + sum += v4[i]; + } + return sum; +} + +/// Compute vectorized isgreater for two values, with each value treated as a +/// vector type \p S +/// \param [in] S The type of the vector +/// \param [in] T The type of the original values +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The vectorized greater than of the two values +template inline T vectorized_isgreater(T a, T b) { + sycl::vec v0{a}, v1{b}; + auto v2 = v0.template as(); + auto v3 = v1.template as(); + auto v4 = sycl::isgreater(v2, v3); + v0 = v4.template as>(); + return v0; +} + +/// Compute vectorized isgreater for two unsigned int values, with each value +/// treated as a vector of two unsigned short +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The vectorized greater than of the two values +template <> +inline unsigned vectorized_isgreater(unsigned a, + unsigned b) { + sycl::vec v0{a}, v1{b}; + auto v2 = v0.template as(); + auto v3 = v1.template as(); + sycl::ushort2 v4; + v4[0] = v2[0] > v3[0]; + v4[1] = v2[1] > v3[1]; + v0 = v4.template as>(); + return v0; +} + +/// Returns min(max(val, min_val), max_val) +/// \param [in] val The input value +/// \param [in] min_val The minimum value +/// \param [in] max_val The maximum value +/// \returns the value between min_val and max_val +template +inline ValueT clamp(ValueT val, ValueT min_val, ValueT max_val) { + return detail::clamp(val, min_val, max_val); +} + +/// Determine whether 2 element value is NaN. +/// \param [in] a The input value +/// \returns the comparison result +template +inline std::enable_if_t isnan(const ValueT a) { + return {detail::isnan(a[0]), detail::isnan(a[1])}; +} + +/// cbrt function wrapper. +template +inline std::enable_if_t || + std::is_same_v, + ValueT> +cbrt(ValueT val) { + return sycl::cbrt(static_cast(val)); +} + +// min/max function overloads. +// For floating-point types, `float` or `double` arguments are acceptable. +// For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or +// `std::int64_t` type arguments are acceptable. +// sycl::half supported as well, and sycl::ext::oneapi::bfloat16 if available. +template +inline std::enable_if_t && + std::is_integral_v, + std::common_type_t> +min(ValueT a, ValueU b) { + return sycl::min(static_cast>(a), + static_cast>(b)); +} + +template +inline std::enable_if_t && + compat::is_floating_point_v, + std::common_type_t> +min(ValueT a, ValueU b) { + if constexpr (std::is_same_v, + sycl::ext::oneapi::bfloat16>) { + static_assert(detail::support_bfloat16_math); + return sycl::ext::oneapi::experimental::fmin( + static_cast>(a), + static_cast>(b)); + } else { + return sycl::fmin(static_cast>(a), + static_cast>(b)); + } +} + +template +inline std::enable_if_t && + std::is_integral_v, + std::common_type_t> +max(ValueT a, ValueU b) { + return sycl::max(static_cast>(a), + static_cast>(b)); +} +template +inline std::enable_if_t && + compat::is_floating_point_v, + std::common_type_t> +max(ValueT a, ValueU b) { + if constexpr (std::is_same_v, + sycl::ext::oneapi::bfloat16>) { + static_assert(detail::support_bfloat16_math); + return sycl::ext::oneapi::experimental::fmax( + static_cast>(a), + static_cast>(b)); + } else { + return sycl::fmax(static_cast>(a), + static_cast>(b)); + } +} + +/// Performs 2 elements comparison and returns the bigger one. If either of +/// inputs is NaN, then return NaN. +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns the bigger value +template +inline std::common_type_t fmax_nan(const ValueT a, + const ValueU b) { + if (detail::isnan(a) || detail::isnan(b)) + return NAN; + return compat::max(a, b); +} + +template +inline sycl::vec, 2> +fmax_nan(const sycl::vec a, const sycl::vec b) { + return {fmax_nan(a[0], b[0]), fmax_nan(a[1], b[1])}; +} + +template +inline sycl::marray, 2> +fmax_nan(const sycl::marray a, const sycl::marray b) { + return {fmax_nan(a[0], b[0]), fmax_nan(a[1], b[1])}; +} + +/// Performs 2 elements comparison and returns the smaller one. If either of +/// inputs is NaN, then return NaN. +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns the smaller value +template +inline std::common_type_t fmin_nan(const ValueT a, + const ValueU b) { + if (detail::isnan(a) || detail::isnan(b)) + return NAN; + return compat::min(a,b); +} + +template +inline sycl::vec, 2> +fmin_nan(const sycl::vec a, const sycl::vec b) { + return {fmin_nan(a[0], b[0]), fmin_nan(a[1], b[1])}; +} + +template +inline sycl::marray, 2> +fmin_nan(const sycl::marray a, const sycl::marray b) { + return {fmin_nan(a[0], b[0]), fmin_nan(a[1], b[1])}; +} + +// pow functions overload. +inline float pow(const float a, const int b) { return sycl::pown(a, b); } +inline double pow(const double a, const int b) { return sycl::pown(a, b); } + +template +inline typename std::enable_if_t, ValueT> +pow(const ValueT a, const ValueU b) { + return sycl::pow(a, static_cast(b)); +} +// TODO(compat-lib-reviewers) calling pow with non-floating point values +// is currently defaulting to double, which fails on devices without +// aspect::fp64. This has to be properly documented, and maybe changed to +// support all devices. +template +inline typename std::enable_if_t, double> +pow(const ValueT a, const ValueU b) { + return sycl::pow(static_cast(a), static_cast(b)); +} + +/// Performs relu saturation. +/// \param [in] a The input value +/// \returns the relu saturation result +template inline ValueT relu(const ValueT a) { + if constexpr (compat::is_floating_point_v) + if (detail::isnan(a)) + return a; + if (a < ValueT(0)) + return ValueT(0); + return a; +} +template +inline sycl::vec +relu(const sycl::vec a) { + sycl::vec ret; + for (int i = 0; i < NumElements; ++i) + ret[i] = relu(a[i]); + return ret; +} +template +inline sycl::marray relu(const sycl::marray a) { + return {relu(a[0]), relu(a[1])}; +} + +/// Computes the multiplication of two complex numbers. +/// \tparam T Complex element type +/// \param [in] x The first input complex number +/// \param [in] y The second input complex number +/// \returns The result +template +sycl::vec cmul(sycl::vec x, sycl::vec y) { + sycl::ext::oneapi::experimental::complex t1(x[0], x[1]), t2(y[0], y[1]); + t1 = t1 * t2; + return sycl::vec(t1.real(), t1.imag()); +} + +/// Computes the division of two complex numbers. +/// \tparam T Complex element type +/// \param [in] x The first input complex number +/// \param [in] y The second input complex number +/// \returns The result +template +sycl::vec cdiv(sycl::vec x, sycl::vec y) { + sycl::ext::oneapi::experimental::complex t1(x[0], x[1]), t2(y[0], y[1]); + t1 = t1 / t2; + return sycl::vec(t1.real(), t1.imag()); +} + +/// Computes the magnitude of a complex number. +/// \tparam T Complex element type +/// \param [in] x The input complex number +/// \returns The result +template T cabs(sycl::vec x) { + sycl::ext::oneapi::experimental::complex t(x[0], x[1]); + return sycl::ext::oneapi::experimental::abs(t); +} + +/// Computes the complex conjugate of a complex number. +/// \tparam T Complex element type +/// \param [in] x The input complex number +/// \returns The result +template sycl::vec conj(sycl::vec x) { + sycl::ext::oneapi::experimental::complex t(x[0], x[1]); + t = conj(t); + return sycl::vec(t.real(), t.imag()); +} + +/// Performs complex number multiply addition. +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns the operation result +template +inline sycl::vec cmul_add(const sycl::vec a, + const sycl::vec b, + const sycl::vec c) { + sycl::ext::oneapi::experimental::complex t(a[0], a[1]); + sycl::ext::oneapi::experimental::complex u(b[0], b[1]); + sycl::ext::oneapi::experimental::complex v(c[0], c[1]); + t = t * u + v; + return sycl::vec{t.real(), t.imag()}; +} +template +inline sycl::marray cmul_add(const sycl::marray a, + const sycl::marray b, + const sycl::marray c) { + sycl::ext::oneapi::experimental::complex t(a[0], a[1]); + sycl::ext::oneapi::experimental::complex u(b[0], b[1]); + sycl::ext::oneapi::experimental::complex v(c[0], c[1]); + t = t * u + v; + return sycl::marray{t.real(), t.imag()}; +} + +/// A sycl::abs wrapper functors. +struct abs { + template auto operator()(const ValueT x) const { + return sycl::abs(x); + } +}; + +/// A sycl::abs_diff wrapper functors. +struct abs_diff { + template + auto operator()(const ValueT x, const ValueT y) const { + return sycl::abs_diff(x, y); + } +}; + +/// A sycl::add_sat wrapper functors. +struct add_sat { + template + auto operator()(const ValueT x, const ValueT y) const { + return sycl::add_sat(x, y); + } +}; + +/// A sycl::rhadd wrapper functors. +struct rhadd { + template + auto operator()(const ValueT x, const ValueT y) const { + return sycl::rhadd(x, y); + } +}; + +/// A sycl::hadd wrapper functors. +struct hadd { + template + auto operator()(const ValueT x, const ValueT y) const { + return sycl::hadd(x, y); + } +}; + +/// A sycl::max wrapper functors. +struct maximum { + template + auto operator()(const ValueT x, const ValueT y) const { + return sycl::max(x, y); + } + template + auto operator()(const ValueT x, const ValueT y, bool *pred) const { + return (x >= y) ? ((*pred = true), x) : ((*pred = false), y); + } +}; + +/// A sycl::min wrapper functors. +struct minimum { + template + auto operator()(const ValueT x, const ValueT y) const { + return sycl::min(x, y); + } + template + auto operator()(const ValueT x, const ValueT y, bool *pred) const { + return (x <= y) ? ((*pred = true), x) : ((*pred = false), y); + } +}; + +/// A sycl::sub_sat wrapper functors. +struct sub_sat { + template + auto operator()(const ValueT x, const ValueT y) const { + return sycl::sub_sat(x, y); + } +}; + +namespace detail { +struct shift_left { + template + auto operator()(const T x, const uint32_t offset) const { + return x << offset; + } +}; + +struct shift_right { + template + auto operator()(const T x, const uint32_t offset) const { + return x >> offset; + } +}; + +struct average { + template auto operator()(const T x, const T y) const { + return (x + y + (x + y >= 0)) >> 1; + } +}; + +} // namespace detail + +/// Compute vectorized binary operation value for two/four values, with each +/// treated as a vector type \p VecT. +/// \tparam [in] VecT The type of the vector +/// \tparam [in] BinaryOperation The binary operation class +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op The operation to do with the two values +/// \param [in] need_relu Whether the result need relu saturation +/// \returns The vectorized binary operation value of the two values +template +inline unsigned vectorized_binary(unsigned a, unsigned b, + const BinaryOperation binary_op, + [[maybe_unused]] bool need_relu = false) { + sycl::vec v0{a}, v1{b}; + auto v2 = v0.as(); + auto v3 = v1.as(); + auto v4 = + detail::vectorized_binary()(v2, v3, binary_op); + if (need_relu) + v4 = relu(v4); + v0 = v4.template as>(); + return v0; +} + +/// Compute two vectorized binary operation value with pred for three values, +/// with each value treated as a 2 \p T type elements vector type. +/// +/// \tparam [in] VecT The type of the vector +/// \tparam [in] BinaryOperation1 The first binary operation class +/// \tparam [in] BinaryOperation2 The second binary operation class +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \param [in] binary_op1 The first operation to do with the first two values +/// \param [in] binary_op2 The second operation to do with the third values +/// \param [in] need_relu Whether the result need relu saturation +/// \returns The two vectorized binary operation value of the three values +template +inline unsigned vectorized_ternary(unsigned a, unsigned b, unsigned c, + const BinaryOperation1 binary_op1, + const BinaryOperation2 binary_op2, + bool need_relu = false) { + const auto v1 = sycl::vec(a).as(); + const auto v2 = sycl::vec(b).as(); + const auto v3 = sycl::vec(c).as(); + auto v4 = + detail::vectorized_binary()(v1, v2, binary_op1); + v4 = detail::vectorized_binary()(v4, v3, binary_op2); + if (need_relu) + v4 = relu(v4); + return v4.template as>(); +} + +/// Compute vectorized binary operation value with pred for two values, with +/// each value treated as a 2 \p T type elements vector type. +/// +/// \tparam [in] VecT The type of the vector +/// \tparam [in] BinaryOperation The binary operation class +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op The operation with pred to do with the two values +/// \param [out] pred_hi The pred pointer that pass into high halfword operation +/// \param [out] pred_lo The pred pointer that pass into low halfword operation +/// \returns The vectorized binary operation value of the two values +template +inline unsigned vectorized_binary_with_pred(unsigned a, unsigned b, + const BinaryOperation binary_op, + bool *pred_hi, bool *pred_lo) { + auto v1 = sycl::vec(a).as(); + auto v2 = sycl::vec(b).as(); + VecT ret; + ret[0] = binary_op(v1[0], v2[0], pred_lo); + ret[1] = binary_op(v1[1], v2[1], pred_hi); + return ret.template as>(); +} + +template +using dot_product_acc_t = + std::conditional_t && std::is_unsigned_v, + uint32_t, int32_t>; + +namespace detail { + +template sycl::vec extract_and_sign_or_zero_extend4(T val) { + return sycl::vec(val) + .template as, int8_t, uint8_t>, 4>>() + .template convert(); +} + +template sycl::vec extract_and_sign_or_zero_extend2(T val) { + return sycl::vec(val) + .template as, int16_t, uint16_t>, 2>>() + .template convert(); +} + +} // namespace detail + +/// Two-way dot product-accumulate. Calculate and return integer_vector2( +/// \param a) dot product integer_vector2(low16_bit( \param b)) + \param c +/// +/// \tparam [in] T1 The type of first value. +/// \tparam [in] T2 The type of second value. +/// \param [in] a The first value. +/// \param [in] b The second value. +/// \param [in] c The third value. It has type uint32_t if both T1 and T1 are +/// uint32_t else has type int32_t. +/// \return Two-way 16-bit to 8-bit dot product which is accumulated in 32-bit +/// result. +template +inline dot_product_acc_t dp2a_lo(T1 a, T2 b, + dot_product_acc_t c) { + static_assert(detail::is_int32_type && detail::is_int32_type, + "[Compat] dp2a_lo expects 32-bit integers as operands."); +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \ + defined(__SYCL_CUDA_ARCH__) && __SYCL_CUDA_ARCH__ >= 610 + dot_product_acc_t res; + if constexpr (std::is_signed_v && std::is_signed_v) { + asm volatile("dp2a.lo.s32.s32 %0, %1, %2, %3;" + : "=r"(res) + : "r"(a), "r"(b), "r"(c)); + } else if constexpr (std::is_signed_v && std::is_unsigned_v) { + asm volatile("dp2a.lo.s32.u32 %0, %1, %2, %3;" + : "=r"(res) + : "r"(a), "r"(b), "r"(c)); + } else if constexpr (std::is_unsigned_v && std::is_signed_v) { + asm volatile("dp2a.lo.u32.s32 %0, %1, %2, %3;" + : "=r"(res) + : "r"(a), "r"(b), "r"(c)); + } else { + asm volatile("dp2a.lo.u32.u32 %0, %1, %2, %3;" + : "=r"(res) + : "r"(a), "r"(b), "r"(c)); + } + return res; +#else + dot_product_acc_t res = c; + auto va = detail::extract_and_sign_or_zero_extend2(a); + auto vb = detail::extract_and_sign_or_zero_extend4(b); + res += va[0] * vb[0]; + res += va[1] * vb[1]; + return res; +#endif +} + +/// Two-way dot product-accumulate. Calculate and return integer_vector2( +/// \param a) dot product integer_vector2(high_16bit( \param b)) + \param c +/// +/// \tparam [in] T1 The type of first value. +/// \tparam [in] T2 The type of second value. +/// \param [in] a The first value. +/// \param [in] b The second value. +/// \param [in] c The third value. uint32_t if both T1 and T1 are +/// uint32_t else has type int32_t. +/// \return Two-way 16-bit to 8-bit dot product which is accumulated in 32-bit +/// result. +template +inline dot_product_acc_t dp2a_hi(T1 a, T2 b, + dot_product_acc_t c) { + static_assert(detail::is_int32_type && detail::is_int32_type, + "[Compat] dp2a_hi expects 32-bit integers as operands."); +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \ + defined(__SYCL_CUDA_ARCH__) && __SYCL_CUDA_ARCH__ >= 610 + dot_product_acc_t res; + if constexpr (std::is_signed_v && std::is_signed_v) { + asm volatile("dp2a.hi.s32.s32 %0, %1, %2, %3;" + : "=r"(res) + : "r"(a), "r"(b), "r"(c)); + } else if constexpr (std::is_signed_v && std::is_unsigned_v) { + asm volatile("dp2a.hi.s32.u32 %0, %1, %2, %3;" + : "=r"(res) + : "r"(a), "r"(b), "r"(c)); + } else if constexpr (std::is_unsigned_v && std::is_signed_v) { + asm volatile("dp2a.hi.u32.s32 %0, %1, %2, %3;" + : "=r"(res) + : "r"(a), "r"(b), "r"(c)); + } else { + asm volatile("dp2a.hi.u32.u32 %0, %1, %2, %3;" + : "=r"(res) + : "r"(a), "r"(b), "r"(c)); + } + return res; +#else + dot_product_acc_t res = c; + auto va = detail::extract_and_sign_or_zero_extend2(a); + auto vb = detail::extract_and_sign_or_zero_extend4(b); + res += va[0] * vb[2]; + res += va[1] * vb[3]; + return res; +#endif +} + +/// Four-way byte dot product-accumulate. Calculate and return integer_vector4( +/// \param a) dot product integer_vector4( \param b) + \param c +/// +/// \tparam [in] T1 The type of first value. +/// \tparam [in] T2 The type of second value. +/// \param [in] a The first value. +/// \param [in] b The second value. +/// \param [in] c The third value. It has type uint32_t if both T1 and T1 are +/// uint32_t else has type int32_t. +/// \return Four-way byte dot product which is accumulated in 32-bit result. +template +inline dot_product_acc_t dp4a(T1 a, T2 b, dot_product_acc_t c) { + static_assert(detail::is_int32_type && detail::is_int32_type, + "[Compat] dp4a expects 32-bit integers as operands."); +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \ + defined(__SYCL_CUDA_ARCH__) && __SYCL_CUDA_ARCH__ >= 610 + dot_product_acc_t res; + if constexpr (std::is_signed_v && std::is_signed_v) { + asm volatile("dp4a.s32.s32 %0, %1, %2, %3;" + : "=r"(res) + : "r"(a), "r"(b), "r"(c)); + } else if constexpr (std::is_signed_v && std::is_unsigned_v) { + asm volatile("dp4a.s32.u32 %0, %1, %2, %3;" + : "=r"(res) + : "r"(a), "r"(b), "r"(c)); + } else if constexpr (std::is_unsigned_v && std::is_signed_v) { + asm volatile("dp4a.u32.s32 %0, %1, %2, %3;" + : "=r"(res) + : "r"(a), "r"(b), "r"(c)); + } else { + asm volatile("dp4a.u32.u32 %0, %1, %2, %3;" + : "=r"(res) + : "r"(a), "r"(b), "r"(c)); + } + return res; +#else + dot_product_acc_t res = c; + auto va = detail::extract_and_sign_or_zero_extend4(a); + auto vb = detail::extract_and_sign_or_zero_extend4(b); + res += va[0] * vb[0]; + res += va[1] * vb[1]; + res += va[2] * vb[2]; + res += va[3] * vb[3]; + return res; +#endif +} + +/// Extend \p a and \p b to 33 bit and add them. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The extend addition of the two values +template +inline constexpr RetT extend_add(AT a, BT b) { + return detail::extend_binary(a, b, std::plus()); +} + +/// Extend Inputs to 33 bit, add \p a, \p b, then do \p second_op with \p c. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \tparam [in] CT The type of the third value +/// \tparam [in] BinaryOperation The type of the second operation +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \param [in] second_op The operation to do with the third value +/// \returns The extend addition of \p a, \p b and \p second_op with \p c +template +inline constexpr RetT extend_add(AT a, BT b, CT c, BinaryOperation second_op) { + return detail::extend_binary(a, b, c, std::plus(), second_op); +} + +/// Extend \p a and \p b to 33 bit and add them with saturation. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The extend addition of the two values with saturation +template +inline constexpr RetT extend_add_sat(AT a, BT b) { + return detail::extend_binary(a, b, std::plus()); +} + +/// Extend Inputs to 33 bit, add \p a, \p b with saturation, then do \p +/// second_op with \p c. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \tparam [in] CT The type of the third value +/// \tparam [in] BinaryOperation The type of the second operation +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \param [in] second_op The operation to do with the third value +/// \returns The extend addition of \p a, \p b with saturation and \p second_op +/// with \p c +template +inline constexpr RetT extend_add_sat(AT a, BT b, CT c, + BinaryOperation second_op) { + return detail::extend_binary(a, b, c, std::plus(), second_op); +} + +/// Extend \p a and \p b to 33 bit and minus them. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The extend subtraction of the two values +template +inline constexpr RetT extend_sub(AT a, BT b) { + return detail::extend_binary(a, b, std::minus()); +} + +/// Extend Inputs to 33 bit, minus \p a, \p b, then do \p second_op with \p c. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \tparam [in] CT The type of the third value +/// \tparam [in] BinaryOperation The type of the second operation +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \param [in] second_op The operation to do with the third value +/// \returns The extend subtraction of \p a, \p b and \p second_op with \p c +template +inline constexpr RetT extend_sub(AT a, BT b, CT c, BinaryOperation second_op) { + return detail::extend_binary(a, b, c, std::minus(), second_op); +} + +/// Extend \p a and \p b to 33 bit and minus them with saturation. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The extend subtraction of the two values with saturation +template +inline constexpr RetT extend_sub_sat(AT a, BT b) { + return detail::extend_binary(a, b, std::minus()); +} + +/// Extend Inputs to 33 bit, minus \p a, \p b with saturation, then do \p +/// second_op with \p c. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \tparam [in] CT The type of the third value +/// \tparam [in] BinaryOperation The type of the second operation +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \param [in] second_op The operation to do with the third value +/// \returns The extend subtraction of \p a, \p b with saturation and \p +/// second_op with \p c +template +inline constexpr RetT extend_sub_sat(AT a, BT b, CT c, + BinaryOperation second_op) { + return detail::extend_binary(a, b, c, std::minus(), second_op); +} + +/// Extend \p a and \p b to 33 bit and do abs_diff. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The extend abs_diff of the two values +template +inline constexpr RetT extend_absdiff(AT a, BT b) { + return detail::extend_binary(a, b, abs_diff()); +} + +/// Extend Inputs to 33 bit, abs_diff \p a, \p b, then do \p second_op with \p +/// c. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \tparam [in] CT The type of the third value +/// \tparam [in] BinaryOperation The type of the second operation +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \param [in] second_op The operation to do with the third value +/// \returns The extend abs_diff of \p a, \p b and \p second_op with \p c +template +inline constexpr RetT extend_absdiff(AT a, BT b, CT c, + BinaryOperation second_op) { + return detail::extend_binary(a, b, c, abs_diff(), second_op); +} + +/// Extend \p a and \p b to 33 bit and do abs_diff with saturation. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The extend abs_diff of the two values with saturation +template +inline constexpr RetT extend_absdiff_sat(AT a, BT b) { + return detail::extend_binary(a, b, abs_diff()); +} + +/// Extend Inputs to 33 bit, abs_diff \p a, \p b with saturation, then do \p +/// second_op with \p c. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \tparam [in] CT The type of the third value +/// \tparam [in] BinaryOperation The type of the second operation +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \param [in] second_op The operation to do with the third value +/// \returns The extend abs_diff of \p a, \p b with saturation and \p +/// second_op with \p c +template +inline constexpr RetT extend_absdiff_sat(AT a, BT b, CT c, + BinaryOperation second_op) { + return detail::extend_binary(a, b, c, abs_diff(), second_op); +} + +/// Extend \p a and \p b to 33 bit and return smaller one. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The smaller one of the two extended values +template +inline constexpr RetT extend_min(AT a, BT b) { + return detail::extend_binary(a, b, minimum()); +} + +/// Extend Inputs to 33 bit, find the smaller one in \p a, \p b, then do \p +/// second_op with \p c. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \tparam [in] CT The type of the third value +/// \tparam [in] BinaryOperation The type of the second operation +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \param [in] second_op The operation to do with the third value +/// \returns The smaller one of \p a, \p b and \p second_op with \p c +template +inline constexpr RetT extend_min(AT a, BT b, CT c, BinaryOperation second_op) { + return detail::extend_binary(a, b, c, minimum(), second_op); +} + +/// Extend \p a and \p b to 33 bit and return smaller one with saturation. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The smaller one of the two extended values with saturation +template +inline constexpr RetT extend_min_sat(AT a, BT b) { + return detail::extend_binary(a, b, minimum()); +} + +/// Extend Inputs to 33 bit, find the smaller one in \p a, \p b with saturation, +/// then do \p second_op with \p c. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \tparam [in] CT The type of the third value +/// \tparam [in] BinaryOperation The type of the second operation +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \param [in] second_op The operation to do with the third value +/// \returns The smaller one of \p a, \p b with saturation and \p +/// second_op with \p c +template +inline constexpr RetT extend_min_sat(AT a, BT b, CT c, + BinaryOperation second_op) { + return detail::extend_binary(a, b, c, minimum(), second_op); +} + +/// Extend \p a and \p b to 33 bit and return bigger one. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The bigger one of the two extended values +template +inline constexpr RetT extend_max(AT a, BT b) { + return detail::extend_binary(a, b, maximum()); +} + +/// Extend Inputs to 33 bit, find the bigger one in \p a, \p b, then do \p +/// second_op with \p c. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \tparam [in] CT The type of the third value +/// \tparam [in] BinaryOperation The type of the second operation +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \param [in] second_op The operation to do with the third value +/// \returns The bigger one of \p a, \p b and \p second_op with \p c +template +inline constexpr RetT extend_max(AT a, BT b, CT c, BinaryOperation second_op) { + return detail::extend_binary(a, b, c, maximum(), second_op); +} + +/// Extend \p a and \p b to 33 bit and return bigger one with saturation. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \param [in] a The first value +/// \param [in] b The second value +/// \returns The bigger one of the two extended values with saturation +template +inline constexpr RetT extend_max_sat(AT a, BT b) { + return detail::extend_binary(a, b, maximum()); +} + +/// Extend Inputs to 33 bit, find the bigger one in \p a, \p b with saturation, +/// then do \p second_op with \p c. +/// \tparam [in] RetT The type of the return value +/// \tparam [in] AT The type of the first value +/// \tparam [in] BT The type of the second value +/// \tparam [in] CT The type of the third value +/// \tparam [in] BinaryOperation The type of the second operation +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \param [in] second_op The operation to do with the third value +/// \returns The bigger one of \p a, \p b with saturation and \p +/// second_op with \p c +template +inline constexpr RetT extend_max_sat(AT a, BT b, CT c, + BinaryOperation second_op) { + return detail::extend_binary(a, b, c, maximum(), second_op); +} + +/// Extend \p a and \p b to 33 bit and return a << clamp(b, 0, 32). +/// \param [in] a The source value +/// \param [in] b The offset to shift +/// \returns a << clamp(b, 0, 32) +template +inline constexpr RetT extend_shl_clamp(T a, uint32_t b) { + return detail::extend_binary(a, sycl::clamp(b, 0u, 32u), + detail::shift_left()); +} + +/// Extend Inputs to 33 bit, and return second_op(a << clamp(b, 0, 32), c). +/// \param [in] a The source value +/// \param [in] b The offset to shift +/// \param [in] c The value to merge +/// \param [in] second_op The operation to do with the third value +/// \returns second_op(a << clamp(b, 0, 32), c) +template +inline constexpr RetT extend_shl_clamp(T a, uint32_t b, uint32_t c, + BinaryOperation second_op) { + return detail::extend_binary(a, sycl::clamp(b, 0u, 32u), c, + detail::shift_left(), second_op); +} + +/// Extend \p a and \p b to 33 bit and return sat(a << clamp(b, 0, 32)). +/// \param [in] a The source value +/// \param [in] b The offset to shift +/// \returns sat(a << clamp(b, 0, 32)) +template +inline constexpr RetT extend_shl_sat_clamp(T a, uint32_t b) { + return detail::extend_binary(a, sycl::clamp(b, 0u, 32u), + detail::shift_left()); +} + +/// Extend Inputs to 33 bit, and return second_op(sat(a << clamp(b, 0, 32)), c). +/// \param [in] a The source value +/// \param [in] b The offset to shift +/// \param [in] c The value to merge +/// \param [in] second_op The operation to do with the third value +/// \returns second_op(sat(a << clamp(b, 0, 32)), c) +template +inline constexpr RetT extend_shl_sat_clamp(T a, uint32_t b, uint32_t c, + BinaryOperation second_op) { + return detail::extend_binary(a, sycl::clamp(b, 0u, 32u), c, + detail::shift_left(), second_op); +} + +/// Extend \p a and \p b to 33 bit and return a << (b & 0x1F). +/// \param [in] a The source value +/// \param [in] b The offset to shift +/// \returns a << (b & 0x1F) +template +inline constexpr RetT extend_shl_wrap(T a, uint32_t b) { + return detail::extend_binary(a, b & 0x1F, detail::shift_left()); +} + +/// Extend Inputs to 33 bit, and return second_op(a << (b & 0x1F), c). +/// \param [in] a The source value +/// \param [in] b The offset to shift +/// \param [in] c The value to merge +/// \param [in] second_op The operation to do with the third value +/// \returns second_op(a << (b & 0x1F), c) +template +inline constexpr RetT extend_shl_wrap(T a, uint32_t b, uint32_t c, + BinaryOperation second_op) { + return detail::extend_binary(a, b & 0x1F, c, + detail::shift_left(), second_op); +} + +/// Extend \p a and \p b to 33 bit and return sat(a << (b & 0x1F)). +/// \param [in] a The source value +/// \param [in] b The offset to shift +/// \returns sat(a << (b & 0x1F)) +template +inline constexpr RetT extend_shl_sat_wrap(T a, uint32_t b) { + return detail::extend_binary(a, b & 0x1F, detail::shift_left()); +} + +/// Extend Inputs to 33 bit, and return second_op(sat(a << (b & 0x1F)), c). +/// \param [in] a The source value +/// \param [in] b The offset to shift +/// \param [in] c The value to merge +/// \param [in] second_op The operation to do with the third value +/// \returns second_op(sat(a << (b & 0x1F)), c) +template +inline constexpr RetT extend_shl_sat_wrap(T a, uint32_t b, uint32_t c, + BinaryOperation second_op) { + return detail::extend_binary(a, b & 0x1F, c, detail::shift_left(), + second_op); +} + +/// Extend \p a and \p b to 33 bit and return a >> clamp(b, 0, 32). +/// \param [in] a The source value +/// \param [in] b The offset to shift +/// \returns a >> clamp(b, 0, 32) +template +inline constexpr RetT extend_shr_clamp(T a, uint32_t b) { + return detail::extend_binary(a, sycl::clamp(b, 0u, 32u), + detail::shift_right()); +} + +/// Extend Inputs to 33 bit, and return second_op(a >> clamp(b, 0, 32), c). +/// \param [in] a The source value +/// \param [in] b The offset to shift +/// \param [in] c The value to merge +/// \param [in] second_op The operation to do with the third value +/// \returns second_op(a >> clamp(b, 0, 32), c) +template +inline constexpr RetT extend_shr_clamp(T a, uint32_t b, uint32_t c, + BinaryOperation second_op) { + return detail::extend_binary(a, sycl::clamp(b, 0u, 32u), c, + detail::shift_right(), second_op); +} + +/// Extend \p a and \p b to 33 bit and return sat(a >> clamp(b, 0, 32)). +/// \param [in] a The source value +/// \param [in] b The offset to shift +/// \returns sat(a >> clamp(b, 0, 32)) +template +inline constexpr RetT extend_shr_sat_clamp(T a, uint32_t b) { + return detail::extend_binary(a, sycl::clamp(b, 0u, 32u), + detail::shift_right()); +} + +/// Extend Inputs to 33 bit, and return second_op(sat(a >> clamp(b, 0, 32)), c). +/// \param [in] a The source value +/// \param [in] b The offset to shift +/// \param [in] c The value to merge +/// \param [in] second_op The operation to do with the third value +/// \returns second_op(sat(a >> clamp(b, 0, 32)), c) +template +inline constexpr RetT extend_shr_sat_clamp(T a, uint32_t b, uint32_t c, + BinaryOperation second_op) { + return detail::extend_binary(a, sycl::clamp(b, 0u, 32u), c, + detail::shift_right(), second_op); +} + +/// Extend \p a and \p b to 33 bit and return a >> (b & 0x1F). +/// \param [in] a The source value +/// \param [in] b The offset to shift +/// \returns a >> (b & 0x1F) +template +inline constexpr RetT extend_shr_wrap(T a, uint32_t b) { + return detail::extend_binary(a, b & 0x1F, detail::shift_right()); +} + +/// Extend Inputs to 33 bit, and return second_op(a >> (b & 0x1F), c). +/// \param [in] a The source value +/// \param [in] b The offset to shift +/// \param [in] c The value to merge +/// \param [in] second_op The operation to do with the third value +/// \returns second_op(a >> (b & 0x1F), c) +template +inline constexpr RetT extend_shr_wrap(T a, uint32_t b, uint32_t c, + BinaryOperation second_op) { + return detail::extend_binary(a, b & 0x1F, c, + detail::shift_right(), second_op); +} + +/// Extend \p a and \p b to 33 bit and return sat(a >> (b & 0x1F)). +/// \param [in] a The source value +/// \param [in] b The offset to shift +/// \returns sat(a >> (b & 0x1F)) +template +inline constexpr RetT extend_shr_sat_wrap(T a, uint32_t b) { + return detail::extend_binary(a, b & 0x1F, detail::shift_right()); +} + +/// Extend Inputs to 33 bit, and return second_op(sat(a >> (b & 0x1F)), c). +/// \param [in] a The source value +/// \param [in] b The offset to shift +/// \param [in] c The value to merge +/// \param [in] second_op The operation to do with the third value +/// \returns second_op(sat(a >> (b & 0x1F)), c) +template +inline constexpr RetT extend_shr_sat_wrap(T a, uint32_t b, uint32_t c, + BinaryOperation second_op) { + return detail::extend_binary(a, b & 0x1F, c, + detail::shift_right(), second_op); +} + +/// Compute vectorized addition of \p a and \p b, with each value treated as a +/// 2 elements vector type and extend each element to 17 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized addition of the two values +template +inline constexpr RetT extend_vadd2(AT a, BT b, RetT c) { + return detail::extend_vbinary2(a, b, c, std::plus()); +} + +/// Compute vectorized addition of \p a and \p b, with each value treated as a 2 +/// elements vector type and extend each element to 17 bit. Then add each half +/// of the result and add with \p c. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The addition of each half of extend vectorized addition of the two +/// values and the third value +template +inline constexpr RetT extend_vadd2_add(AT a, BT b, RetT c) { + return detail::extend_vbinary2(a, b, c, std::plus()); +} + +/// Compute vectorized addition of \p a and \p b with saturation, with each +/// value treated as a 2 elements vector type and extend each element to 17 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized addition of the two values with saturation +template +inline constexpr RetT extend_vadd2_sat(AT a, BT b, RetT c) { + return detail::extend_vbinary2(a, b, c, std::plus()); +} + +/// Compute vectorized subtraction of \p a and \p b, with each value treated as +/// a 2 elements vector type and extend each element to 17 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized subtraction of the two values +template +inline constexpr RetT extend_vsub2(AT a, BT b, RetT c) { + return detail::extend_vbinary2(a, b, c, std::minus()); +} + +/// Compute vectorized subtraction of \p a and \p b, with each value treated as +/// a 2 elements vector type and extend each element to 17 bit. Then add each +/// half of the result and add with \p c. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The addition of each half of extend vectorized subtraction of the +/// two values and the third value +template +inline constexpr RetT extend_vsub2_add(AT a, BT b, RetT c) { + return detail::extend_vbinary2(a, b, c, std::minus()); +} + +/// Compute vectorized subtraction of \p a and \p b with saturation, with each +/// value treated as a 2 elements vector type and extend each element to 17 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized subtraction of the two values with saturation +template +inline constexpr RetT extend_vsub2_sat(AT a, BT b, RetT c) { + return detail::extend_vbinary2(a, b, c, std::minus()); +} + +/// Compute vectorized abs_diff of \p a and \p b, with each value treated as a 2 +/// elements vector type and extend each element to 17 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized abs_diff of the two values +template +inline constexpr RetT extend_vabsdiff2(AT a, BT b, RetT c) { + return detail::extend_vbinary2(a, b, c, abs_diff()); +} + +/// Compute vectorized abs_diff of \p a and \p b, with each value treated as a 2 +/// elements vector type and extend each element to 17 bit. Then add each half +/// of the result and add with \p c. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The addition of each half of extend vectorized abs_diff of the +/// two values and the third value +template +inline constexpr RetT extend_vabsdiff2_add(AT a, BT b, RetT c) { + return detail::extend_vbinary2(a, b, c, abs_diff()); +} + +/// Compute vectorized abs_diff of \p a and \p b with saturation, with each +/// value treated as a 2 elements vector type and extend each element to 17 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized abs_diff of the two values with saturation +template +inline constexpr RetT extend_vabsdiff2_sat(AT a, BT b, RetT c) { + return detail::extend_vbinary2(a, b, c, abs_diff()); +} + +/// Compute vectorized minimum of \p a and \p b, with each value treated as a 2 +/// elements vector type and extend each element to 17 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized minimum of the two values +template +inline constexpr RetT extend_vmin2(AT a, BT b, RetT c) { + return detail::extend_vbinary2(a, b, c, minimum()); +} + +/// Compute vectorized minimum of \p a and \p b, with each value treated as a 2 +/// elements vector type and extend each element to 17 bit. Then add each half +/// of the result and add with \p c. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The addition of each half of extend vectorized minimum of the +/// two values and the third value +template +inline constexpr RetT extend_vmin2_add(AT a, BT b, RetT c) { + return detail::extend_vbinary2(a, b, c, minimum()); +} + +/// Compute vectorized minimum of \p a and \p b with saturation, with each value +/// treated as a 2 elements vector type and extend each element to 17 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized minimum of the two values with saturation +template +inline constexpr RetT extend_vmin2_sat(AT a, BT b, RetT c) { + return detail::extend_vbinary2(a, b, c, minimum()); +} + +/// Compute vectorized maximum of \p a and \p b, with each value treated as a 2 +/// elements vector type and extend each element to 17 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized maximum of the two values +template +inline constexpr RetT extend_vmax2(AT a, BT b, RetT c) { + return detail::extend_vbinary2(a, b, c, maximum()); +} + +/// Compute vectorized maximum of \p a and \p b, with each value treated as a 2 +/// elements vector type and extend each element to 17 bit. Then add each half +/// of the result and add with \p c. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The addition of each half of extend vectorized maximum of the +/// two values and the third value +template +inline constexpr RetT extend_vmax2_add(AT a, BT b, RetT c) { + return detail::extend_vbinary2(a, b, c, maximum()); +} + +/// Compute vectorized maximum of \p a and \p b with saturation, with each value +/// treated as a 2 elements vector type and extend each element to 17 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized maximum of the two values with saturation +template +inline constexpr RetT extend_vmax2_sat(AT a, BT b, RetT c) { + return detail::extend_vbinary2(a, b, c, maximum()); +} + +/// Compute vectorized average of \p a and \p b, with each value treated as a 2 +/// elements vector type and extend each element to 17 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized average of the two values +template +inline constexpr RetT extend_vavrg2(AT a, BT b, RetT c) { + return detail::extend_vbinary2(a, b, c, + detail::average()); +} + +/// Compute vectorized average of \p a and \p b, with each value treated as a 2 +/// elements vector type and extend each element to 17 bit. Then add each half +/// of the result and add with \p c. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The addition of each half of extend average maximum of the +/// two values and the third value +template +inline constexpr RetT extend_vavrg2_add(AT a, BT b, RetT c) { + return detail::extend_vbinary2(a, b, c, detail::average()); +} + +/// Compute vectorized average of \p a and \p b with saturation, with each value +/// treated as a 2 elements vector type and extend each element to 17 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized average of the two values with saturation +template +inline constexpr RetT extend_vavrg2_sat(AT a, BT b, RetT c) { + return detail::extend_vbinary2(a, b, c, detail::average()); +} + +/// Extend \p a and \p b to 33 bit and vectorized compare input values using +/// specified comparison \p cmp . +/// +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \tparam [in] BinaryOperation The type of the compare operation +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] cmp The comparsion operator +/// \returns The comparison result of the two extended values. +template +inline constexpr unsigned extend_vcompare2(AT a, BT b, BinaryOperation cmp) { + return detail::extend_vbinary2(a, b, 0, cmp); +} + +/// Extend Inputs to 33 bit, and vectorized compare input values using specified +/// comparison \p cmp , then add the result with \p c . +/// +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \tparam [in] BinaryOperation The type of the compare operation +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \param [in] cmp The comparsion operator +/// \returns The comparison result of the two extended values, and add the +/// result with \p c . +template +inline constexpr unsigned extend_vcompare2_add(AT a, BT b, unsigned c, + BinaryOperation cmp) { + return detail::extend_vbinary2(a, b, c, cmp); +} + +/// Compute vectorized addition of \p a and \p b, with each value treated as a +/// 4 elements vector type and extend each element to 9 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized addition of the two values +template +inline constexpr RetT extend_vadd4(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, std::plus()); +} + +/// Compute vectorized addition of \p a and \p b, with each value treated as a 4 +/// elements vector type and extend each element to 9 bit. Then add each half +/// of the result and add with \p c. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The addition of each half of extend vectorized addition of the two +/// values and the third value +template +inline constexpr RetT extend_vadd4_add(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, std::plus()); +} + +/// Compute vectorized addition of \p a and \p b with saturation, with each +/// value treated as a 4 elements vector type and extend each element to 9 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized addition of the two values with saturation +template +inline constexpr RetT extend_vadd4_sat(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, std::plus()); +} + +/// Compute vectorized subtraction of \p a and \p b, with each value treated as +/// a 4 elements vector type and extend each element to 9 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized subtraction of the two values +template +inline constexpr RetT extend_vsub4(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, std::minus()); +} + +/// Compute vectorized subtraction of \p a and \p b, with each value treated as +/// a 4 elements vector type and extend each element to 9 bit. Then add each +/// half of the result and add with \p c. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The addition of each half of extend vectorized subtraction of the +/// two values and the third value +template +inline constexpr RetT extend_vsub4_add(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, std::minus()); +} + +/// Compute vectorized subtraction of \p a and \p b with saturation, with each +/// value treated as a 4 elements vector type and extend each element to 9 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized subtraction of the two values with saturation +template +inline constexpr RetT extend_vsub4_sat(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, std::minus()); +} + +/// Compute vectorized abs_diff of \p a and \p b, with each value treated as a 4 +/// elements vector type and extend each element to 9 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized abs_diff of the two values +template +inline constexpr RetT extend_vabsdiff4(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, abs_diff()); +} + +/// Compute vectorized abs_diff of \p a and \p b, with each value treated as a 4 +/// elements vector type and extend each element to 9 bit. Then add each half +/// of the result and add with \p c. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The addition of each half of extend vectorized abs_diff of the +/// two values and the third value +template +inline constexpr RetT extend_vabsdiff4_add(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, abs_diff()); +} + +/// Compute vectorized abs_diff of \p a and \p b with saturation, with each +/// value treated as a 4 elements vector type and extend each element to 9 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized abs_diff of the two values with saturation +template +inline constexpr RetT extend_vabsdiff4_sat(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, abs_diff()); +} + +/// Compute vectorized minimum of \p a and \p b, with each value treated as a 4 +/// elements vector type and extend each element to 9 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized minimum of the two values +template +inline constexpr RetT extend_vmin4(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, minimum()); +} + +/// Compute vectorized minimum of \p a and \p b, with each value treated as a 4 +/// elements vector type and extend each element to 9 bit. Then add each half +/// of the result and add with \p c. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The addition of each half of extend vectorized minimum of the +/// two values and the third value +template +inline constexpr RetT extend_vmin4_add(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, minimum()); +} + +/// Compute vectorized minimum of \p a and \p b with saturation, with each value +/// treated as a 4 elements vector type and extend each element to 9 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized minimum of the two values with saturation +template +inline constexpr RetT extend_vmin4_sat(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, minimum()); +} + +/// Compute vectorized maximum of \p a and \p b, with each value treated as a 4 +/// elements vector type and extend each element to 9 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized maximum of the two values +template +inline constexpr RetT extend_vmax4(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, maximum()); +} + +/// Compute vectorized maximum of \p a and \p b, with each value treated as a 4 +/// elements vector type and extend each element to 9 bit. Then add each half +/// of the result and add with \p c. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The addition of each half of extend vectorized maximum of the +/// two values and the third value +template +inline constexpr RetT extend_vmax4_add(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, maximum()); +} + +/// Compute vectorized maximum of \p a and \p b with saturation, with each value +/// treated as a 4 elements vector type and extend each element to 9 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized maximum of the two values with saturation +template +inline constexpr RetT extend_vmax4_sat(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, maximum()); +} + +/// Compute vectorized average of \p a and \p b, with each value treated as a 4 +/// elements vector type and extend each element to 9 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized average of the two values +template +inline constexpr RetT extend_vavrg4(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, + detail::average()); +} + +/// Compute vectorized average of \p a and \p b, with each value treated as a 4 +/// elements vector type and extend each element to 9 bit. Then add each half +/// of the result and add with \p c. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The addition of each half of extend vectorized average of the +/// two values and the third value +template +inline constexpr RetT extend_vavrg4_add(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, detail::average()); +} + +/// Compute vectorized average of \p a and \p b with saturation, with each value +/// treated as a 4 elements vector type and extend each element to 9 bit. +/// \tparam [in] RetT The type of the return value, can only be 32 bit integer +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \returns The extend vectorized average of the two values with saturation +template +inline constexpr RetT extend_vavrg4_sat(AT a, BT b, RetT c) { + return detail::extend_vbinary4(a, b, c, detail::average()); +} + +/// Extend \p a and \p b to 33 bit and vectorized compare input values using +/// specified comparison \p cmp . +/// +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \tparam [in] BinaryOperation The type of the compare operation +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] cmp The comparsion operator +/// \returns The comparison result of the two extended values. +template +inline constexpr unsigned extend_vcompare4(AT a, BT b, BinaryOperation cmp) { + return detail::extend_vbinary4(a, b, 0, cmp); +} + +/// Extend Inputs to 33 bit, and vectorized compare input values using specified +/// comparison \p cmp , then add the result with \p c . +/// +/// \tparam [in] AT The type of the first value, can only be 32 bit integer +/// \tparam [in] BT The type of the second value, can only be 32 bit integer +/// \tparam [in] BinaryOperation The type of the compare operation +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \param [in] cmp The comparsion operator +/// \returns The comparison result of the two extended values, and add the +/// result with \p c . +template +inline constexpr unsigned extend_vcompare4_add(AT a, BT b, unsigned c, + BinaryOperation cmp) { + return detail::extend_vbinary4(a, b, c, cmp); +} + +} // namespace compat diff --git a/tools/util/include/compat/memory.hpp b/tools/util/include/compat/memory.hpp new file mode 100644 index 0000000000..d50e1ed92c --- /dev/null +++ b/tools/util/include/compat/memory.hpp @@ -0,0 +1,1762 @@ +/*************************************************************************** + * + * Copyright (C) Codeplay Software Ltd. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * + * Part of the LLVM Project, under the Apache License v2.0 with LLVM + * Exceptions. See https://llvm.org/LICENSE.txt for license information. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SYCL compatibility extension + * + * memory.hpp + * + * Description: + * memory functionality for the SYCL compatibility extension + **************************************************************************/ + +// The original source was under the license below: +//==---- memory.hpp -------------------------------*- C++ -*----------------==// +// +// Copyright (C) Intel Corporation +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// See https://llvm.org/LICENSE.txt for license information. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#ifdef SYCL_EXT_ONEAPI_USM_DEVICE_READ_ONLY +#include +#endif + +#include +#include +#include + +#if defined(__linux__) +#include +#elif defined(_WIN64) +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#else +#error "Only support Windows and Linux." +#endif + +namespace compat { + +template +#ifdef __SYCL_DEVICE_ONLY__ +[[__sycl_detail__::add_ir_attributes_function("sycl-forceinline", true)]] +#endif +__SYCL_ALWAYS_INLINE auto *local_mem() { + sycl::multi_ptr + As_multi_ptr = + sycl::ext::oneapi::group_local_memory_for_overwrite( + sycl::ext::oneapi::this_work_item::get_work_group<3>()); + auto *As = *As_multi_ptr; + return As; +} + +namespace detail { +enum memcpy_direction { + host_to_host, + host_to_device, + device_to_host, + device_to_device, + automatic +}; +} // namespace detail + +template +__compat_inline__ + std::enable_if_t || std::is_same_v, + T> + ptr_to_int(void *ptr) { +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + if constexpr (std::is_same_v) { + return (intptr_t)(sycl::decorated_local_ptr::pointer)ptr; + } else { + return (size_t)(sycl::decorated_local_ptr::pointer)ptr; + } +#else + throw sycl::exception(make_error_code(sycl::errc::runtime), + "ptr_to_int is only supported on Nvidia devices."); +#endif +} + +enum class memory_region { + global = 0, // device global memory + constant, // device read-only memory + local, // device local memory + usm_shared, // memory which can be accessed by host and device +}; + +using byte_t = uint8_t; + +/// Buffer type to be used in Memory Management runtime. +typedef sycl::buffer buffer_t; + +/// Pitched 2D/3D memory data. +class pitched_data { +public: + pitched_data() : pitched_data(nullptr, 0, 0, 0) {} + pitched_data(void *data, size_t pitch, size_t x, size_t y) + : _data(data), _pitch(pitch), _x(x), _y(y) {} + + void *get_data_ptr() { return _data; } + void set_data_ptr(void *data) { _data = data; } + + size_t get_pitch() { return _pitch; } + void set_pitch(size_t pitch) { _pitch = pitch; } + + size_t get_x() { return _x; } + void set_x(size_t x) { _x = x; }; + + size_t get_y() { return _y; } + void set_y(size_t y) { _y = y; } + +private: + void *_data; + size_t _pitch, _x, _y; +}; + +namespace experimental { +#ifdef SYCL_EXT_ONEAPI_BINDLESS_IMAGES +class image_mem_wrapper; +namespace detail { +static sycl::event memcpy(const image_mem_wrapper *src, + const sycl::id<3> &src_id, pitched_data &dest, + const sycl::id<3> &dest_id, + const sycl::range<3> ©_extend, sycl::queue q); +static sycl::event memcpy(const pitched_data src, const sycl::id<3> &src_id, + image_mem_wrapper *dest, const sycl::id<3> &dest_id, + const sycl::range<3> ©_extend, sycl::queue q); +} // namespace detail +#endif +class image_matrix; +namespace detail { +static pitched_data to_pitched_data(image_matrix *image); +} + +/// Memory copy parameters for 2D/3D memory data. +struct memcpy_parameter { + struct data_wrapper { + pitched_data pitched{}; + sycl::id<3> pos{}; +#ifdef SYCL_EXT_ONEAPI_BINDLESS_IMAGES + experimental::image_mem_wrapper *image_bindless{nullptr}; +#endif + image_matrix *image{nullptr}; + }; + data_wrapper from{}; + data_wrapper to{}; + sycl::range<3> size{}; +}; +} // namespace experimental + +namespace detail { +class mem_mgr { + mem_mgr() { + // Reserved address space, no real memory allocation happens here. +#if defined(__linux__) + mapped_address_space = + (byte_t *)mmap(nullptr, mapped_region_size, PROT_NONE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); +#elif defined(_WIN64) + mapped_address_space = (byte_t *)VirtualAlloc( + NULL, // NULL specified as the base address parameter + mapped_region_size, // Size of allocation + MEM_RESERVE, // Allocate reserved pages + PAGE_NOACCESS); // Protection = no access +#else +#error "Only support Windows and Linux." +#endif + next_free = mapped_address_space; + }; + +public: + using buffer_id_t = int; + + struct allocation { + buffer_t buffer; + byte_t *alloc_ptr; + size_t size; + }; + + ~mem_mgr() { +#if defined(__linux__) + munmap(mapped_address_space, mapped_region_size); +#elif defined(_WIN64) + VirtualFree(mapped_address_space, 0, MEM_RELEASE); +#else +#error "Only support Windows and Linux." +#endif + }; + + mem_mgr(const mem_mgr &) = delete; + mem_mgr &operator=(const mem_mgr &) = delete; + mem_mgr(mem_mgr &&) = delete; + mem_mgr &operator=(mem_mgr &&) = delete; + + /// Allocate + void *mem_alloc(size_t size) { + if (!size) + return nullptr; + std::lock_guard lock(m_mutex); + if (next_free + size > mapped_address_space + mapped_region_size) { + throw std::runtime_error( + "[Compat] malloc: out of memory for virtual memory pool"); + } + // Allocation + sycl::range<1> buffer_range(size); + buffer_t buf(buffer_range); + allocation alloc{buf, next_free, size}; + // Map allocation to device pointer + void *result = next_free; + m_map.emplace(next_free + size, alloc); + // Update pointer to the next free space. + next_free += (size + extra_padding + alignment - 1) & ~(alignment - 1); + + return result; + } + + /// Deallocate + void mem_free(const void *ptr) { + if (!ptr) + return; + std::lock_guard lock(m_mutex); + auto it = get_map_iterator(ptr); + m_map.erase(it); + } + + /// map: device pointer -> allocation(buffer, alloc_ptr, size) + allocation translate_ptr(const void *ptr) { + std::lock_guard lock(m_mutex); + auto it = get_map_iterator(ptr); + return it->second; + } + + /// Check if the pointer represents device pointer or not. + bool is_device_ptr(const void *ptr) const { + std::lock_guard lock(m_mutex); + return (mapped_address_space <= ptr) && + (ptr < mapped_address_space + mapped_region_size); + } + + /// Returns the instance of memory manager singleton. + static mem_mgr &instance() { + static mem_mgr m; + return m; + } + +private: + std::map m_map; + mutable std::mutex m_mutex; + byte_t *mapped_address_space; + byte_t *next_free; + const size_t mapped_region_size = 128ull * 1024 * 1024 * 1024; + const size_t alignment = 256; + /// This padding may be defined to some positive value to debug + /// out of bound accesses. + const size_t extra_padding = 0; + + std::map::iterator get_map_iterator(const void *ptr) { + auto it = m_map.upper_bound((byte_t *)ptr); + if (it == m_map.end()) { + // Not a virtual pointer. + throw std::runtime_error("[Compat] can not get buffer from non-virtual pointer"); + } + const allocation &alloc = it->second; + if (ptr < alloc.alloc_ptr) { + // Out of bound. + // This may happen if there's a gap between allocations due to alignment + // or extra padding and pointer points to this gap. + throw std::runtime_error("[Compat] invalid virtual pointer"); + } + return it; + } +}; + +template class accessor; +template class memory_traits { +public: + static constexpr sycl::access::address_space asp = + (Memory == memory_region::local) + ? sycl::access::address_space::local_space + : sycl::access::address_space::global_space; + static constexpr sycl::target target = (Memory == memory_region::local) + ? sycl::target::local + : sycl::target::device; + static constexpr sycl::access_mode mode = (Memory == memory_region::constant) + ? sycl::access_mode::read + : sycl::access_mode::read_write; + static constexpr size_t type_size = sizeof(T); + using element_t = + typename std::conditional_t; + using value_t = typename std::remove_cv_t; + template + using accessor_t = + typename std::conditional_t, + sycl::accessor>; + using pointer_t = + typename std::conditional_t; +}; + +static inline void *malloc(size_t size, sycl::queue q) { +#ifdef COMPAT_USM_LEVEL_NONE + return mem_mgr::instance().mem_alloc(size * sizeof(byte_t)); +#else + return sycl::malloc_device(size, q.get_device(), q.get_context()); +#endif // COMPAT_USM_LEVEL_NONE +} + +/// Calculate pitch (padded length of major dimension \p x) by rounding up to +/// multiple of 32. +/// \param x The dimension to be padded (in bytes) +/// \returns size_t representing pitched length of dimension x (in bytes). +static inline constexpr size_t get_pitch(size_t x) { + return ((x) + 31) & ~(0x1F); +} + +/// \brief Malloc pitched 3D data +/// \param [out] pitch returns the calculated pitch (in bytes) +/// \param [in] x width of the allocation (in bytes) +/// \param [in] y height of the allocation +/// \param [in] z depth of the allocation +/// \param [in] q The queue in which the operation is done. +/// \returns A pointer to the allocated memory +static inline void *malloc(size_t &pitch, size_t x, size_t y, size_t z, + sycl::queue q) { + pitch = get_pitch(x); + return malloc(pitch * y * z, q); +} + +/// \brief Set \p pattern to the first \p count elements of type \p T +/// starting from \p dev_ptr. +/// +/// \tparam T Datatype of the pattern to be set. +/// \param q The queue in which the operation is done. +/// \param dev_ptr Pointer to the device memory address. +/// \param pattern Pattern of type T to be set. +/// \param count Number of elements to be set to the patten. +/// \returns An event representing the fill operation. +template +static inline sycl::event fill(sycl::queue q, void *dev_ptr, const T &pattern, + size_t count) { +#ifdef COMPAT_USM_LEVEL_NONE + auto &mm = mem_mgr::instance(); + assert(mm.is_device_ptr(dev_ptr)); + auto alloc = mm.translate_ptr(dev_ptr); + size_t offset = (T *)dev_ptr - (T *)alloc.alloc_ptr; + + return q.submit([&](sycl::handler &cgh) { + auto r = sycl::range<1>(count); + auto o = sycl::id<1>(offset); + auto new_buffer = + alloc.buffer.reinterpret(sycl::range<1>(alloc.size / sizeof(T))); + sycl::accessor + acc(new_buffer, cgh, r, o); + cgh.fill(acc, pattern); + }); +#else + return q.fill(dev_ptr, pattern, count); +#endif +} + +/// Set \p value to the first \p size bytes starting from \p dev_ptr in \p q. +/// +/// \param q The queue in which the operation is done. +/// \param dev_ptr Pointer to the device memory address. +/// \param value Value to be set. +/// \param size Number of bytes to be set to the value. +/// \returns An event representing the memset operation. +static inline sycl::event memset(sycl::queue q, void *dev_ptr, int value, + size_t size) { +#ifdef COMPAT_USM_LEVEL_NONE + auto &mm = mem_mgr::instance(); + assert(mm.is_device_ptr(dev_ptr)); + auto alloc = mm.translate_ptr(dev_ptr); + size_t offset = (byte_t *)dev_ptr - (byte_t *)alloc.alloc_ptr; + + return q.submit([&](sycl::handler &cgh) { + auto r = sycl::range<1>(size); + auto o = sycl::id<1>(offset); + auto new_buffer = alloc.buffer.reinterpret( + sycl::range<1>(alloc.size / sizeof(byte_t))); + sycl::accessor + acc(new_buffer, cgh, r, o); + cgh.fill(acc, static_cast(value)); + }); +#else + return q.memset(dev_ptr, value, size); +#endif // COMPAT_USM_LEVEL_NONE +} + +/// \brief Sets \p value to the 3D memory region pointed by \p data in \p q. +/// \tparam T The type of the element to be set. +/// \param [in] q The queue in which the operation is done. +/// \param [in] data Pointer to the pitched device memory region. +/// \param [in] value The value to be set. +/// \param [in] size 3D memory region by number of elements. +/// \return An event list representing the memset operations. +template +static inline std::vector +memset(sycl::queue q, pitched_data data, const T &value, sycl::range<3> size) { + std::vector event_list; + size_t slice = data.get_pitch() * data.get_y(); + unsigned char *data_surface = (unsigned char *)data.get_data_ptr(); + for (size_t z = 0; z < size.get(2); ++z) { + unsigned char *data_ptr = data_surface; + for (size_t y = 0; y < size.get(1); ++y) { + event_list.push_back(detail::fill(q, data_ptr, value, size.get(0))); + data_ptr += data.get_pitch(); + } + data_surface += slice; + } + return event_list; +} + +/// \brief Sets \p val to the pitched 2D memory region pointed by \p ptr in \p +/// q. +/// \tparam T The type of the element to be set. +/// \param [in] q The queue in which the operation is done. +/// \param [in] ptr Pointer to the virtual device memory. +/// \param [in] pitch The pitch size by number of elements, including padding. +/// \param [in] value The value to be set. +/// \param [in] x The width of memory region by number of elements. +/// \param [in] y The height of memory region by number of elements. +/// \return An event list representing the memset operations. +template +static inline std::vector memset(sycl::queue q, void *ptr, + size_t pitch, const T &value, + size_t x, size_t y) { + return memset(q, pitched_data(ptr, pitch, x, 1), value, + sycl::range<3>(x, y, 1)); +} + +enum class pointer_access_attribute { + host_only = 0, + device_only, + host_device, + end +}; + +static pointer_access_attribute get_pointer_attribute(sycl::queue q, + const void *ptr) { +#ifdef COMPAT_USM_LEVEL_NONE + return mem_mgr::instance().is_device_ptr(ptr) + ? pointer_access_attribute::device_only + : pointer_access_attribute::host_only; +#else + switch (sycl::get_pointer_type(ptr, q.get_context())) { + case sycl::usm::alloc::unknown: + return pointer_access_attribute::host_only; + case sycl::usm::alloc::device: + return pointer_access_attribute::device_only; + case sycl::usm::alloc::shared: + case sycl::usm::alloc::host: + return pointer_access_attribute::host_device; + } +#endif // COMPAT_USM_LEVEL_NONE +} + +static memcpy_direction +deduce_memcpy_direction(sycl::queue q, void *to_ptr, const void *from_ptr) { + // table[to_attribute][from_attribute] + static const memcpy_direction + direction_table[static_cast(pointer_access_attribute::end)] + [static_cast(pointer_access_attribute::end)] = { + {host_to_host, device_to_host, host_to_host}, + {host_to_device, device_to_device, device_to_device}, + {host_to_host, device_to_device, device_to_device}}; + return direction_table[static_cast(get_pointer_attribute( + q, to_ptr))][static_cast(get_pointer_attribute(q, from_ptr))]; +} + +static sycl::event memcpy(sycl::queue q, void *to_ptr, const void *from_ptr, + size_t size, + const std::vector &dep_events = {}) { + if (!size) + return sycl::event{}; +#ifdef COMPAT_USM_LEVEL_NONE + auto &mm = mem_mgr::instance(); + auto real_direction = deduce_memcpy_direction(q, to_ptr, from_ptr); + + switch (real_direction) { + case host_to_host: + return q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dep_events); + cgh.host_task([=] { std::memcpy(to_ptr, from_ptr, size); }); + }); + case host_to_device: { + auto alloc = mm.translate_ptr(to_ptr); + size_t offset = (byte_t *)to_ptr - alloc.alloc_ptr; + return q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dep_events); + auto r = sycl::range<1>(size); + auto o = sycl::id<1>(offset); + sycl::accessor + acc(alloc.buffer, cgh, r, o); + cgh.copy(from_ptr, acc); + }); + } + case device_to_host: { + auto alloc = mm.translate_ptr(from_ptr); + size_t offset = (byte_t *)from_ptr - alloc.alloc_ptr; + return q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dep_events); + auto r = sycl::range<1>(size); + auto o = sycl::id<1>(offset); + sycl::accessor + acc(alloc.buffer, cgh, r, o); + cgh.copy(acc, to_ptr); + }); + } + case device_to_device: { + auto to_alloc = mm.translate_ptr(to_ptr); + auto from_alloc = mm.translate_ptr(from_ptr); + size_t to_offset = (byte_t *)to_ptr - to_alloc.alloc_ptr; + size_t from_offset = (byte_t *)from_ptr - from_alloc.alloc_ptr; + return q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dep_events); + auto r = sycl::range<1>(size); + auto to_o = sycl::id<1>(to_offset); + auto from_o = sycl::id<1>(from_offset); + sycl::accessor + to_acc(to_alloc.buffer, cgh, r, to_o); + sycl::accessor + from_acc(from_alloc.buffer, cgh, r, from_o); + cgh.copy(from_acc, to_acc); + }); + } + default: + throw std::runtime_error("[Compat] memcpy: invalid direction value"); + } +#else + return q.memcpy(to_ptr, from_ptr, size, dep_events); +#endif // COMPAT_USM_LEVEL_NONE +} + +// Get actual copy range and make sure it will not exceed range. +static inline size_t get_copy_range(sycl::range<3> size, size_t slice, + size_t pitch) { + return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0); +} + +static inline size_t get_offset(sycl::id<3> id, size_t slice, size_t pitch) { + return slice * id.get(2) + pitch * id.get(1) + id.get(0); +} + +// RAII for host pointer +class host_buffer { + void *_buf; + size_t _size; + sycl::queue _q; + const std::vector &_deps; // free operation depends + +public: + host_buffer(size_t size, sycl::queue q, const std::vector &deps) + : _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {} + void *get_ptr() const { return _buf; } + size_t get_size() const { return _size; } + ~host_buffer() { + if (_buf) { + _q.submit([&](sycl::handler &cgh) { + cgh.depends_on(_deps); + cgh.host_task([buf = _buf] { std::free(buf); }); + }); + } + } +}; + +/// copy 3D matrix specified by \p size from 3D matrix specified by \p from_ptr +/// and \p from_range to another specified by \p to_ptr and \p to_range. +static inline std::vector +memcpy(sycl::queue q, void *to_ptr, const void *from_ptr, + sycl::range<3> to_range, sycl::range<3> from_range, sycl::id<3> to_id, + sycl::id<3> from_id, sycl::range<3> size, + const std::vector &dep_events = {}) { + + std::vector event_list; + + size_t to_slice = to_range.get(1) * to_range.get(0); + size_t from_slice = from_range.get(1) * from_range.get(0); + unsigned char *to_surface = + (unsigned char *)to_ptr + get_offset(to_id, to_slice, to_range.get(0)); + const unsigned char *from_surface = + (const unsigned char *)from_ptr + + get_offset(from_id, from_slice, from_range.get(0)); + + if (to_slice == from_slice && to_slice == size.get(1) * size.get(0)) { + return {memcpy(q, to_surface, from_surface, to_slice * size.get(2), + dep_events)}; + } + using namespace experimental; // for memcpy_direction + memcpy_direction direction = deduce_memcpy_direction(q, to_ptr, from_ptr); + size_t size_slice = size.get(1) * size.get(0); + switch (direction) { + case host_to_host: + for (size_t z = 0; z < size.get(2); ++z) { + unsigned char *to_ptr = to_surface; + const unsigned char *from_ptr = from_surface; + if (to_range.get(0) == from_range.get(0) && + to_range.get(0) == size.get(0)) { + event_list.push_back( + memcpy(q, to_ptr, from_ptr, size_slice, dep_events)); + } else { + for (size_t y = 0; y < size.get(1); ++y) { + event_list.push_back( + memcpy(q, to_ptr, from_ptr, size.get(0), dep_events)); + to_ptr += to_range.get(0); + from_ptr += from_range.get(0); + } + } + to_surface += to_slice; + from_surface += from_slice; + } + break; + case host_to_device: { + host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q, + event_list); + std::vector host_events; + if (to_slice == size_slice) { + // Copy host data to a temp host buffer with the shape of target. + host_events = + memcpy(q, buf.get_ptr(), from_surface, to_range, from_range, + sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, dep_events); + } else { + // Copy host data to a temp host buffer with the shape of target. + host_events = + memcpy(q, buf.get_ptr(), from_surface, to_range, from_range, + sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, + // If has padding data, not sure whether it is useless. So fill + // temp buffer with it. + std::vector{memcpy(q, buf.get_ptr(), to_surface, + buf.get_size(), dep_events)}); + } + // Copy from temp host buffer to device with only one submit. + event_list.push_back( + memcpy(q, to_surface, buf.get_ptr(), buf.get_size(), host_events)); + break; + } + case device_to_host: { + host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q, + event_list); + // Copy from host temp buffer to host target with reshaping. + event_list = + memcpy(q, to_surface, buf.get_ptr(), to_range, from_range, + sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, + // Copy from device to temp host buffer with only one submit. + std::vector{memcpy(q, buf.get_ptr(), from_surface, + buf.get_size(), dep_events)}); + break; + } + case device_to_device: +#ifdef COMPAT_USM_LEVEL_NONE + { + auto &mm = mem_mgr::instance(); + auto to_alloc = mm.translate_ptr(to_surface); + auto from_alloc = mm.translate_ptr(from_surface); + size_t to_offset = (byte_t *)to_surface - to_alloc.alloc_ptr; + size_t from_offset = (byte_t *)from_surface - from_alloc.alloc_ptr; + event_list.push_back(q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dep_events); + auto to_o = sycl::id<1>(to_offset); + auto from_o = sycl::id<1>(from_offset); + sycl::accessor + to_acc(to_alloc.buffer, cgh, + get_copy_range(size, to_slice, to_range.get(0)), to_o); + sycl::accessor + from_acc(from_alloc.buffer, cgh, + get_copy_range(size, from_slice, from_range.get(0)), from_o); + cgh.parallel_for( + size, [=](sycl::id<3> id) { + to_acc[get_offset(id, to_slice, to_range.get(0))] = + from_acc[get_offset(id, from_slice, from_range.get(0))]; + }); + })); + } +#else + event_list.push_back(q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dep_events); + cgh.parallel_for(size, [=](sycl::id<3> id) { + to_surface[get_offset(id, to_slice, to_range.get(0))] = + from_surface[get_offset(id, from_slice, from_range.get(0))]; + }); + })); +#endif // COMPAT_USM_LEVEL_NONE + break; + default: + throw std::runtime_error("[Compat] memcpy: invalid direction value"); + } + return event_list; +} + +/// memcpy 2D/3D matrix specified by pitched_data. +static inline std::vector +memcpy(sycl::queue q, pitched_data to, sycl::id<3> to_id, pitched_data from, + sycl::id<3> from_id, sycl::range<3> size) { + return memcpy(q, to.get_data_ptr(), from.get_data_ptr(), + sycl::range<3>(to.get_pitch(), to.get_y(), 1), + sycl::range<3>(from.get_pitch(), from.get_y(), 1), to_id, + from_id, size); +} + +/// memcpy 2D matrix with pitch. +static inline std::vector +memcpy(sycl::queue q, void *to_ptr, const void *from_ptr, size_t to_pitch, + size_t from_pitch, size_t x, size_t y) { + return memcpy(q, to_ptr, from_ptr, sycl::range<3>(to_pitch, y, 1), + sycl::range<3>(from_pitch, y, 1), sycl::id<3>(0, 0, 0), + sycl::id<3>(0, 0, 0), sycl::range<3>(x, y, 1)); +} + +// Takes a std::vector & returns a single event +// which simply depends on all of them +static sycl::event combine_events(std::vector &events, + sycl::queue q) { + return q.submit([&events](sycl::handler &cgh) { + cgh.depends_on(events); + cgh.host_task([]() {}); + }); +} + +} // namespace detail + +#ifdef COMPAT_USM_LEVEL_NONE +/// Check if the pointer \p ptr represents device pointer or not. +/// +/// \param ptr The pointer to be checked. +/// \returns true if \p ptr is a device pointer. +template static inline bool is_device_ptr(T ptr) { + if constexpr (std::is_pointer::value) { + return detail::mem_mgr::instance().is_device_ptr(ptr); + } + return false; +} +#endif + +/// Get the buffer and the offset of a piece of memory pointed to by \p ptr. +/// +/// \param ptr Pointer to a piece of memory. +/// If NULL is passed as an argument, an exception will be thrown. +/// \returns a pair containing both the buffer and the offset. +static std::pair get_buffer_and_offset(const void *ptr) { + if (ptr) { + auto alloc = detail::mem_mgr::instance().translate_ptr(ptr); + size_t offset = (byte_t *)ptr - alloc.alloc_ptr; + return std::make_pair(alloc.buffer, offset); + } else { + throw std::runtime_error( + "[Compat] NULL pointer argument in get_buffer_and_offset function is invalid"); + } +} + +/// Get the data pointed from \p ptr as a 1D buffer reinterpreted as type T. +template static sycl::buffer get_buffer(const void *ptr) { + if (!ptr) + return sycl::buffer(sycl::range<1>(0)); + auto alloc = detail::mem_mgr::instance().translate_ptr(ptr); + return alloc.buffer.reinterpret(sycl::range<1>(alloc.size / sizeof(T))); +} + +/// Get the buffer of a piece of memory pointed to by \p ptr. +/// +/// \param ptr Pointer to a piece of memory. +/// \returns the buffer. +static buffer_t get_buffer(const void *ptr) { + return detail::mem_mgr::instance().translate_ptr(ptr).buffer; +} + +/// Get the host pointer from a buffer that is mapped to virtual pointer ptr. +/// \param ptr Virtual Pointer mapped to device buffer +/// \returns A host pointer +template static inline T *get_host_ptr(const void *ptr) { + auto BufferOffset = get_buffer_and_offset(ptr); + auto host_ptr = BufferOffset.first.get_host_access() + .get_multi_ptr(); + return (T *)(host_ptr + BufferOffset.second); +} + +/// A wrapper class contains an accessor and an offset. +template +class access_wrapper { + sycl::accessor accessor; + size_t offset; + +public: + /// Construct the accessor wrapper for memory pointed by \p ptr. + /// + /// \param ptr Pointer to memory. + /// \param cgh The command group handler. + access_wrapper(const void *ptr, sycl::handler &cgh) + : accessor(get_buffer(ptr).get_access(cgh)), offset(0) { + auto alloc = detail::mem_mgr::instance().translate_ptr(ptr); + offset = (byte_t *)ptr - alloc.alloc_ptr; + } + + /// Get the device pointer. + /// + /// \returns a device pointer with offset. + dataT get_raw_pointer() const { return (dataT)(&accessor[0] + offset); } +}; + +/// Get the accessor for memory pointed by \p ptr. +/// +/// \param ptr Pointer to memory. +/// If NULL is passed as an argument, an exception will be thrown. +/// \param cgh The command group handler. +/// \returns an accessor. +template +static sycl::accessor get_access(const void *ptr, + sycl::handler &cgh) { + if (ptr) { + auto alloc = detail::mem_mgr::instance().translate_ptr(ptr); + return alloc.buffer.get_access(cgh); + } else { + throw std::runtime_error( + "[Compat] NULL pointer argument in get_access function is invalid"); + } +} + +namespace experimental { +namespace detail { +static inline std::vector +memcpy(sycl::queue q, const experimental::memcpy_parameter ¶m) { + auto to = param.to.pitched; + auto from = param.from.pitched; +#ifdef SYCL_EXT_ONEAPI_BINDLESS_IMAGES + if (param.to.image_bindless != nullptr && + param.from.image_bindless != nullptr) { + throw std::runtime_error( + "[Compat] memcpy: Unsupported bindless_image API."); + // TODO: Need change logic when sycl support image_mem to image_mem copy. + std::vector event_list; + compat::detail::host_buffer buf(param.size.size(), q, event_list); + to.set_data_ptr(buf.get_ptr()); + experimental::detail::memcpy(param.from.image_bindless, param.from.pos, to, + sycl::id<3>(0, 0, 0), param.size, q); + from.set_data_ptr(buf.get_ptr()); + event_list.push_back(experimental::detail::memcpy( + from, sycl::id<3>(0, 0, 0), param.to.image_bindless, param.to.pos, + param.size, q)); + return event_list; + } else if (param.to.image_bindless != nullptr) { + throw std::runtime_error( + "[Compat] memcpy: Unsupported bindless_image API."); + return {experimental::detail::memcpy(from, param.from.pos, + param.to.image_bindless, param.to.pos, + param.size, q)}; + } else if (param.from.image_bindless != nullptr) { + throw std::runtime_error( + "[Compat] memcpy: Unsupported bindless_image API."); + return {experimental::detail::memcpy(param.from.image_bindless, + param.from.pos, to, param.to.pos, + param.size, q)}; + } +#endif + if (param.to.image != nullptr) { + throw std::runtime_error("[Compat] memcpy: Unsupported image API."); + to = experimental::detail::to_pitched_data(param.to.image); + } + if (param.from.image != nullptr) { + throw std::runtime_error("[Compat] memcpy: Unsupported image API."); + from = experimental::detail::to_pitched_data(param.from.image); + } + return compat::detail::memcpy(q, to, param.to.pos, from, param.from.pos, + param.size); +} +} // namespace detail +} // namespace experimental + +/// Allocate memory block on the device. +/// \param num_bytes Number of bytes to allocate. +/// \param q Queue to execute the allocate task. +/// \returns A pointer to the newly allocated memory. +static inline void *malloc(size_t num_bytes, + sycl::queue q = get_default_queue()) { + return detail::malloc(num_bytes, q); +} + +/// Allocate memory block on the device. +/// \param T Datatype to allocate +/// \param count Number of elements to allocate. +/// \param q Queue to execute the allocate task. +/// \returns A pointer to the newly allocated memory. +template +static inline T *malloc(size_t count, sycl::queue q = get_default_queue()) { + return static_cast(detail::malloc(count * sizeof(T), q)); +} + +/// Allocate memory block on the host. +/// \param num_bytes Number of bytes to allocate. +/// \param q Queue to execute the allocate task. +/// \returns A pointer to the newly allocated memory. +static inline void *malloc_host(size_t num_bytes, + sycl::queue q = get_default_queue()) { + return sycl::malloc_host(num_bytes, q); +} + +/// Allocate memory block on the host. +/// \param T Datatype to allocate +/// \param num_bytes Number of bytes to allocate. +/// \param q Queue to execute the allocate task. +/// \returns A pointer to the newly allocated memory. +template +static inline T *malloc_host(size_t count, + sycl::queue q = get_default_queue()) { + return static_cast(sycl::malloc_host(count * sizeof(T), q)); +} + +/// Allocate memory block of usm_shared memory. +/// \param num_bytes Number of bytes to allocate. +/// \param q Queue to execute the allocate task. +/// \returns A pointer to the newly allocated memory. +static inline void *malloc_shared(size_t num_bytes, + sycl::queue q = get_default_queue()) { + return sycl::malloc_shared(num_bytes, q); +} + +/// Allocate memory block of usm_shared memory. +/// \param num_bytes Number of bytes to allocate. +/// \param q Queue to execute the allocate task. +/// \returns A pointer to the newly allocated memory. +template +static inline T *malloc_shared(size_t count, + sycl::queue q = get_default_queue()) { + return static_cast(sycl::malloc_shared(count * sizeof(T), q)); +} + +/// Allocate memory block for 3D array on the device. +/// \param size Size of the memory block, in bytes. +/// \param q Queue to execute the allocate task. +/// \returns A pitched_data object which stores the memory info. +static inline pitched_data malloc(sycl::range<3> size, + sycl::queue q = get_default_queue()) { + pitched_data pitch(nullptr, 0, size.get(0), size.get(1)); + size_t pitch_size; + pitch.set_data_ptr( + detail::malloc(pitch_size, size.get(0), size.get(1), size.get(2), q)); + pitch.set_pitch(pitch_size); + return pitch; +} + +/// Allocate memory block for 2D array on the device. +/// \param [out] pitch Aligned size of x in bytes. +/// \param x Range in dim x. +/// \param y Range in dim y. +/// \param q Queue to execute the allocate task. +/// \returns A pointer to the newly allocated memory. +static inline void *malloc(size_t &pitch, size_t x, size_t y, + sycl::queue q = get_default_queue()) { + return detail::malloc(pitch, x, y, 1, q); +} + +namespace detail { + +inline void free(void *ptr, const sycl::queue &q) { + if (ptr) { +#ifdef COMPAT_USM_LEVEL_NONE + detail::mem_mgr::instance().mem_free(ptr); +#else + sycl::free(ptr, q.get_context()); +#endif // COMPAT_USM_LEVEL_NONE + } +} +} // namespace detail + +/// Wait on the queue \p q and free the memory \p ptr. +/// \param ptr Point to free. +/// \param q Queue to execute the free task. +/// \returns no return value. +static inline void wait_and_free(void *ptr, + sycl::queue q = get_default_queue()) { + get_current_device().queues_wait_and_throw(); + q.wait(); + if (ptr) { + detail::free(ptr, q); + } +} + +// Anonymous namespace to disable ADL for functions which might clash (memcpy, +// memset, free) +namespace { +/// Free the memory \p ptr on the default queue without synchronizing +/// \param ptr Point to free. +/// \returns no return value. +static inline void free(void *ptr, sycl::queue q = get_default_queue()) { + detail::free(ptr, q); +} +} // namespace + +/// Enqueues the release of all pointers in /p pointers on the /p q. +/// The command waits on all passed /p events and returns an event that +/// track the commands execution on the queue. +/// +/// \param pointers The pointers point to the device memory requested to be +/// freed. +/// \param events The events to be waited on. +/// \param q The sycl::queue the memory relates to. +// Can't be static due to the friend declaration in the memory header. +inline sycl::event enqueue_free(const std::vector &pointers, + const std::vector &events, + sycl::queue q = get_default_queue()) { + auto event = q.submit( + [&pointers, &events, &q](sycl::handler &cgh) { + cgh.depends_on(events); + cgh.host_task([=]() { + for (auto p : pointers) + detail::free(p, q); + }); + }); + get_current_device().add_event(event); + return event; +} + +namespace { +/// Synchronously copies \p size bytes from the address specified by \p from_ptr +/// to the address specified by \p to_ptr. The function will +/// return after the copy is completed. +/// +/// \param to_ptr Pointer to destination memory address. +/// \param from_ptr Pointer to source memory address. +/// \param size Number of bytes to be copied. +/// \param q Queue to execute the copy task. +/// \returns no return value. +static void memcpy(void *to_ptr, const void *from_ptr, size_t size, + sycl::queue q = get_default_queue()) { + detail::memcpy(q, to_ptr, from_ptr, size).wait(); +} + +} // namespace + +/// Asynchronously copies \p size bytes from the address specified by \p +/// from_ptr to the address specified by \p to_ptr. The return of the function +/// does NOT guarantee the copy is completed. +/// +/// \param to_ptr Pointer to destination memory address. +/// \param from_ptr Pointer to source memory address. +/// \param size Number of bytes to be copied. +/// \param q Queue to execute the copy task. +/// \returns no return value. +static sycl::event memcpy_async(void *to_ptr, const void *from_ptr, size_t size, + sycl::queue q = get_default_queue()) { + return detail::memcpy(q, to_ptr, from_ptr, size); +} + +/// Asynchronously copies \p count T's from the address specified by \p +/// from_ptr to the address specified by \p to_ptr. The return of the function +/// does NOT guarantee the copy is completed. +/// +/// \tparam T Datatype to be copied. +/// \param to_ptr Pointer to destination memory address. +/// \param from_ptr Pointer to source memory address. +/// \param count Number of T to be copied. +/// \param q Queue to execute the copy task. +/// \returns no return value. +template +static sycl::event +memcpy_async(type_identity_t *to_ptr, const type_identity_t *from_ptr, + size_t count, sycl::queue q = get_default_queue()) { + return detail::memcpy(q, static_cast(to_ptr), + static_cast(from_ptr), count * sizeof(T)); +} + +namespace { +/// Synchronously copies \p count T's from the address specified by \p from_ptr +/// to the address specified by \p to_ptr. The function will +/// return after the copy is completed. +/// +/// \tparam T Datatype to be copied. +/// \param to_ptr Pointer to destination memory address. +/// \param from_ptr Pointer to source memory address. +/// \param count Number of T to be copied. +/// \param q Queue to execute the copy task. +/// \returns no return value. +template +static void memcpy(type_identity_t *to_ptr, + const type_identity_t *from_ptr, size_t count, + sycl::queue q = get_default_queue()) { + detail::memcpy(q, static_cast(to_ptr), + static_cast(from_ptr), count * sizeof(T)) + .wait(); +} + +/// Synchronously copies 2D matrix specified by \p x and \p y from the address +/// specified by \p from_ptr to the address specified by \p to_ptr, while \p +/// from_pitch and \p to_pitch are the range of dim x in bytes of the matrix +/// specified by \p from_ptr and \p to_ptr. The function will return after the +/// copy is completed. +/// +/// \param to_ptr Pointer to destination memory address. +/// \param to_pitch Range of dim x in bytes of destination matrix. +/// \param from_ptr Pointer to source memory address. +/// \param from_pitch Range of dim x in bytes of source matrix. +/// \param x Range of dim x of matrix to be copied. +/// \param y Range of dim y of matrix to be copied. +/// \param q Queue to execute the copy task. +/// \returns no return value. +static inline void memcpy(void *to_ptr, size_t to_pitch, const void *from_ptr, + size_t from_pitch, size_t x, size_t y, + sycl::queue q = get_default_queue()) { + sycl::event::wait( + detail::memcpy(q, to_ptr, from_ptr, to_pitch, from_pitch, x, y)); +} + +} // namespace + +/// Asynchronously copies 2D matrix specified by \p x and \p y from the address +/// specified by \p from_ptr to the address specified by \p to_ptr, while \p +/// \p from_pitch and \p to_pitch are the range of dim x in bytes of the matrix +/// specified by \p from_ptr and \p to_ptr. The return of the function does NOT +/// guarantee the copy is completed. +/// +/// \param to_ptr Pointer to destination memory address. +/// \param to_pitch Range of dim x in bytes of destination matrix. +/// \param from_ptr Pointer to source memory address. +/// \param from_pitch Range of dim x in bytes of source matrix. +/// \param x Range of dim x of matrix to be copied. +/// \param y Range of dim y of matrix to be copied. +/// \param q Queue to execute the copy task. +/// \returns An event representing the memcpy operation. +static inline sycl::event memcpy_async(void *to_ptr, size_t to_pitch, + const void *from_ptr, size_t from_pitch, + size_t x, size_t y, + sycl::queue q = get_default_queue()) { + auto events = detail::memcpy(q, to_ptr, from_ptr, to_pitch, from_pitch, x, y); + return detail::combine_events(events, q); +} + +namespace { +/// Synchronously copies a subset of a 3D matrix specified by \p to to another +/// 3D matrix specified by \p from. The from and to position info are specified +/// by \p from_pos and \p to_pos The copied matrix size is specified by \p size. +// The function will return after the copy is completed. +/// +/// \param to Destination matrix info. +/// \param to_pos Position of destination. +/// \param from Source matrix info. +/// \param from_pos Position of destination. +/// \param size Range of the submatrix to be copied. +/// \param q Queue to execute the copy task. +/// \returns no return value. +static inline void memcpy(pitched_data to, sycl::id<3> to_pos, + pitched_data from, sycl::id<3> from_pos, + sycl::range<3> size, + sycl::queue q = get_default_queue()) { + sycl::event::wait(detail::memcpy(q, to, to_pos, from, from_pos, size)); +} +} // namespace + +/// Asynchronously copies a subset of a 3D matrix specified by \p to to another +/// 3D matrix specified by \p from. The from and to position info are specified +/// by \p from_pos and \p to_pos The copied matrix size is specified by \p size. +/// The return of the function does NOT guarantee the copy is completed. +/// +/// \param to Destination matrix info. +/// \param to_pos Position of destination. +/// \param from Source matrix info. +/// \param from_pos Position of destination. +/// \param size Range of the submatrix to be copied. +/// \param q Queue to execute the copy task. +/// \returns An event representing the memcpy operation. +static inline sycl::event memcpy_async(pitched_data to, sycl::id<3> to_pos, + pitched_data from, sycl::id<3> from_pos, + sycl::range<3> size, + sycl::queue q = get_default_queue()) { + auto events = detail::memcpy(q, to, to_pos, from, from_pos, size); + return detail::combine_events(events, q); +} + +namespace { +/// Synchronously sets \p pattern to the first \p count elements starting from +/// \p dev_ptr. The function will return after the fill operation is completed. +/// +/// \tparam T Datatype of the value to be set. +/// \param dev_ptr Pointer to the device memory address. +/// \param pattern Pattern of type \p T to be set. +/// \param count Number of elements to be set to the patten. +/// \param q The queue in which the operation is done. +/// \returns no return value. +template +static void inline fill(void *dev_ptr, const T &pattern, size_t count, + sycl::queue q = get_default_queue()) { + detail::fill(q, dev_ptr, pattern, count).wait(); +} +} // namespace + +/// Asynchronously sets \p pattern to the first \p count elements starting from +/// \p dev_ptr. +/// The return of the function does NOT guarantee the fill operation is +/// completed. +/// +/// \tparam T Datatype of the pattern to be set. +/// \param dev_ptr Pointer to the device memory address. +/// \param pattern Pattern of type \p T to be set. +/// \param count Number of elements to be set to the patten. +/// \param q The queue in which the operation is done. +/// \returns An event representing the fill operation. +template +static sycl::event inline fill_async(void *dev_ptr, const T &pattern, + size_t count, + sycl::queue q = get_default_queue()) { + return detail::fill(q, dev_ptr, pattern, count); +} + +namespace experimental { + +/// [UNSUPPORTED] Synchronously copies 2D/3D memory data specified by \p param . +/// The function will return after the copy is completed. +/// +/// \param param Memory copy parameters. +/// \param q Queue to execute the copy task. +/// \returns no return value. +static inline void memcpy(const memcpy_parameter ¶m, + sycl::queue q = get_default_queue()) { + sycl::event::wait(compat::experimental::detail::memcpy(q, param)); +} + +/// [UNSUPPORTED] Asynchronously copies 2D/3D memory data specified by \p param +/// . The return of the function does NOT guarantee the copy is completed. +/// +/// \param param Memory copy parameters. +/// \param q Queue to execute the copy task. +/// \returns no return value. +static inline void memcpy_async(const memcpy_parameter ¶m, + sycl::queue q = get_default_queue()) { + compat::experimental::detail::memcpy(q, param); +} +} // namespace experimental + +namespace { +/// Synchronously sets \p value to the first \p size bytes starting from \p +/// dev_ptr. The function will return after the memset operation is completed. +/// +/// \param dev_ptr Pointer to the device memory address. +/// \param value Value to be set. +/// \param size Number of bytes to be set to the value. +/// \param q The queue in which the operation is done. +/// \returns no return value. +static void memset(void *dev_ptr, int value, size_t size, + sycl::queue q = get_default_queue()) { + detail::memset(q, dev_ptr, value, size).wait(); +} +} // namespace + +/// \brief Sets 2 bytes data \p value to the first \p size elements starting +/// from \p dev_ptr in \p q synchronously. +/// \param [in] dev_ptr Pointer to the virtual device memory address. +/// \param [in] value The value to be set. +/// \param [in] size Number of elements to be set to the value. +/// \param [in] q The queue in which the operation is done. +static inline void memset_d16(void *dev_ptr, unsigned short value, size_t size, + sycl::queue q = get_default_queue()) { + detail::fill(q, dev_ptr, value, size).wait(); +} + +/// \brief Sets 4 bytes data \p value to the first \p size elements starting +/// from \p dev_ptr in \p q synchronously. +/// \param [in] dev_ptr Pointer to the virtual device memory address. +/// \param [in] value The value to be set. +/// \param [in] size Number of elements to be set to the value. +/// \param [in] q The queue in which the operation is done. +static inline void memset_d32(void *dev_ptr, unsigned int value, size_t size, + sycl::queue q = get_default_queue()) { + detail::fill(q, dev_ptr, value, size).wait(); +} + +/// \brief Sets 1 byte data \p value to the first \p size elements starting +/// from \p dev_ptr in \p q asynchronously. +/// \param dev_ptr Pointer to the device memory address. +/// \param value Value to be set. +/// \param size Number of bytes to be set to the value. +/// \returns An event representing the memset operation. +static inline sycl::event memset_async(void *dev_ptr, int value, size_t size, + sycl::queue q = get_default_queue()) { + return detail::memset(q, dev_ptr, value, size); +} + +/// \brief Sets 2 bytes data \p value to the first \p size elements starting +/// from \p dev_ptr in \p q asynchronously. +/// \param [in] dev_ptr Pointer to the virtual device memory address. +/// \param [in] value The value to be set. +/// \param [in] size Number of elements to be set to the value. +/// \param [in] q The queue in which the operation is done. +/// \returns An event representing the memset operation. +static inline sycl::event +memset_d16_async(void *dev_ptr, unsigned short value, size_t size, + sycl::queue q = get_default_queue()) { + return detail::fill(q, dev_ptr, value, size); +} + +/// \brief Sets 4 bytes data \p value to the first \p size elements starting +/// from \p dev_ptr in \p q asynchronously. +/// \param [in] dev_ptr Pointer to the virtual device memory address. +/// \param [in] value The value to be set. +/// \param [in] size Number of elements to be set to the value. +/// \param [in] q The queue in which the operation is done. +/// \returns An event representing the memset operation. +static inline sycl::event +memset_d32_async(void *dev_ptr, unsigned int value, size_t size, + sycl::queue q = get_default_queue()) { + return detail::fill(q, dev_ptr, value, size); +} + +namespace { +/// \brief Sets 1 byte data \p val to the pitched 2D memory region pointed by \p +/// ptr in \p q synchronously. +/// \param [in] ptr Pointer to the virtual device memory. +/// \param [in] pitch The pitch size by number of elements, including padding. +/// \param [in] val The value to be set. +/// \param [in] x The width of memory region by number of elements. +/// \param [in] y The height of memory region by number of elements. +/// \param [in] q The queue in which the operation is done. +static inline void memset(void *ptr, size_t pitch, int val, size_t x, size_t y, + sycl::queue q = get_default_queue()) { + sycl::event::wait(detail::memset(q, ptr, pitch, val, x, y)); +} +} // namespace + +/// \brief Sets 2 bytes data \p val to the pitched 2D memory region pointed by +/// ptr in \p q synchronously. +/// \param [in] ptr Pointer to the virtual device memory. +/// \param [in] pitch The pitch size by number of elements, including padding. +/// \param [in] val The value to be set. +/// \param [in] x The width of memory region by number of elements. +/// \param [in] y The height of memory region by number of elements. +/// \param [in] q The queue in which the operation is done. +static inline void memset_d16(void *ptr, size_t pitch, unsigned short val, + size_t x, size_t y, + sycl::queue q = get_default_queue()) { + sycl::event::wait(detail::memset(q, ptr, pitch, val, x, y)); +} + +/// \brief Sets 4 bytes data \p val to the pitched 2D memory region pointed by +/// ptr in \p q synchronously. +/// \param [in] ptr Pointer to the virtual device memory. +/// \param [in] pitch The pitch size by number of elements, including padding. +/// \param [in] val The value to be set. +/// \param [in] x The width of memory region by number of elements. +/// \param [in] y The height of memory region by number of elements. +/// \param [in] q The queue in which the operation is done. +static inline void memset_d32(void *ptr, size_t pitch, unsigned int val, + size_t x, size_t y, + sycl::queue q = get_default_queue()) { + sycl::event::wait(detail::memset(q, ptr, pitch, val, x, y)); +} + +/// \brief Sets 1 byte data \p val to the pitched 2D memory region pointed by \p +/// ptr in \p q asynchronously. +/// \param [in] ptr Pointer to the virtual device memory. +/// \param [in] pitch The pitch size by number of elements, including padding. +/// \param [in] val The value to be set. +/// \param [in] x The width of memory region by number of elements. +/// \param [in] y The height of memory region by number of elements. +/// \param [in] q The queue in which the operation is done. +/// \returns An event representing the memset operation. +static inline sycl::event memset_async(void *ptr, size_t pitch, int val, + size_t x, size_t y, + sycl::queue q = get_default_queue()) { + + auto events = detail::memset(q, ptr, pitch, val, x, y); + return detail::combine_events(events, q); +} + +/// \brief Sets 2 bytes data \p val to the pitched 2D memory region pointed by +/// \p ptr in \p q asynchronously. +/// \param [in] ptr Pointer to the virtual device memory. +/// \param [in] pitch The pitch size by number of elements, including padding. +/// \param [in] val The value to be set. +/// \param [in] x The width of memory region by number of elements. +/// \param [in] y The height of memory region by number of elements. +/// \param [in] q The queue in which the operation is done. +/// \returns An event representing the memset operation. +static inline sycl::event +memset_d16_async(void *ptr, size_t pitch, unsigned short val, size_t x, + size_t y, sycl::queue q = get_default_queue()) { + auto events = detail::memset(q, ptr, pitch, val, x, y); + return detail::combine_events(events, q); +} + +/// \brief Sets 4 bytes data \p val to the pitched 2D memory region pointed by +/// \p ptr in \p q asynchronously. +/// \param [in] ptr Pointer to the virtual device memory. +/// \param [in] pitch The pitch size by number of elements, including padding. +/// \param [in] val The value to be set. +/// \param [in] x The width of memory region by number of elements. +/// \param [in] y The height of memory region by number of elements. +/// \param [in] q The queue in which the operation is done. +/// \returns An event representing the memset operation. +static inline sycl::event +memset_d32_async(void *ptr, size_t pitch, unsigned int val, size_t x, size_t y, + sycl::queue q = get_default_queue()) { + auto events = detail::memset(q, ptr, pitch, val, x, y); + return detail::combine_events(events, q); +} + +namespace { +/// Sets \p value to the 3D memory region specified by \p pitch in \p q. \p size +/// specify the setted 3D memory size. The function will return after the +/// memset operation is completed. +/// +/// \param pitch Specify the 3D memory region. +/// \param value Value to be set. +/// \param size The setted 3D memory size. +/// \param q The queue in which the operation is done. +/// \returns no return value. +static inline void memset(pitched_data pitch, int val, sycl::range<3> size, + sycl::queue q = get_default_queue()) { + sycl::event::wait(detail::memset(q, pitch, val, size)); +} +} // namespace + +/// Sets \p value to the 3D memory region specified by \p pitch in \p q. \p size +/// specify the setted 3D memory size. The return of the function does NOT +/// guarantee the memset operation is completed. +/// +/// \param pitch Specify the 3D memory region. +/// \param value Value to be set. +/// \param size The setted 3D memory size. +/// \param q The queue in which the operation is done. +/// \returns An event representing the memset operation. +static inline sycl::event memset_async(pitched_data pitch, int val, + sycl::range<3> size, + sycl::queue q = get_default_queue()) { + auto events = detail::memset(q, pitch, val, size); + return detail::combine_events(events, q); +} + +/// accessor used as device function parameter. +template class accessor; +template class accessor { +public: + using memory_t = detail::memory_traits; + using element_t = typename memory_t::element_t; + using pointer_t = typename memory_t::pointer_t; + using accessor_t = typename memory_t::template accessor_t<3>; + accessor(pointer_t data, const sycl::range<3> &in_range) + : _data(data), _range(in_range) {} + template + accessor(typename std::enable_if::type &acc) + : accessor(acc, acc.get_range()) {} + accessor(const accessor_t &acc, const sycl::range<3> &in_range) + : accessor( + acc.template get_multi_ptr().get(), + in_range) {} + accessor operator[](size_t index) const { + sycl::range<2> sub(_range.get(1), _range.get(2)); + return accessor(_data + index * sub.size(), sub); + } + + pointer_t get_ptr() const { return _data; } + +private: + pointer_t _data; + sycl::range<3> _range; +}; +template class accessor { +public: + using memory_t = detail::memory_traits; + using element_t = typename memory_t::element_t; + using pointer_t = typename memory_t::pointer_t; + using accessor_t = typename memory_t::template accessor_t<2>; + accessor(pointer_t data, const sycl::range<2> &in_range) + : _data(data), _range(in_range) {} + template + accessor(typename std::enable_if::type &acc) + : accessor(acc, acc.get_range()) {} + accessor(const accessor_t &acc, const sycl::range<2> &in_range) + : accessor( + acc.template get_multi_ptr().get(), + in_range) {} + + pointer_t operator[](size_t index) const { + return _data + _range.get(1) * index; + } + + pointer_t get_ptr() const { return _data; } + +private: + pointer_t _data; + sycl::range<2> _range; +}; + +/// Device variable with address space of shared or global. +// TODO(compat-lib-reviewers): This doesn't yet support multi-device (ptr +// per device) +template class device_memory { +public: + using accessor_t = + typename detail::memory_traits::template accessor_t; + using value_t = typename detail::memory_traits::value_t; + using compat_accessor_t = compat::accessor; + + device_memory(sycl::queue q = get_default_queue()) + : device_memory(sycl::range(1), q) {} + + /// Constructor of 1-D array with initializer list + device_memory(const sycl::range &in_range, + std::initializer_list &&init_list, + sycl::queue q = get_default_queue()) + : device_memory(in_range, q) { + assert(init_list.size() <= in_range.size()); + _host_ptr = (value_t *)std::malloc(_size); + std::memset(_host_ptr, 0, _size); + std::memcpy(_host_ptr, init_list.begin(), init_list.size() * sizeof(T)); + } + + /// Constructor of 2-D array with initializer list + template + device_memory( + const typename std::enable_if>::type &in_range, + std::initializer_list> &&init_list, + sycl::queue q = get_default_queue()) + : device_memory(in_range, q) { + assert(init_list.size() <= in_range[0]); + _host_ptr = (value_t *)std::malloc(_size); + std::memset(_host_ptr, 0, _size); + auto tmp_data = _host_ptr; + for (auto sub_list : init_list) { + assert(sub_list.size() <= in_range[1]); + std::memcpy(tmp_data, sub_list.begin(), sub_list.size() * sizeof(T)); + tmp_data += in_range[1]; + } + } + + /// Constructor with range + device_memory(const sycl::range &range_in, + sycl::queue q = get_default_queue()) + : _size(range_in.size() * sizeof(T)), _range(range_in), _reference(false), + _host_ptr(nullptr), _device_ptr(nullptr), _q(q) { + static_assert((Memory == memory_region::global) || + (Memory == memory_region::constant) || + (Memory == memory_region::usm_shared), + "device memory region should be global, constant or shared"); + // Make sure that singleton class dev_mgr will destruct later than this. + detail::dev_mgr::instance(); +#ifdef COMPAT_USM_LEVEL_NONE + detail::mem_mgr::instance(); +#endif + } + + /// Constructor with range + // enable_if_t SFINAE to avoid ambiguity with + // device_memory(Args... Arguments, sycl::queue q) + template > + device_memory(Args... Arguments) + : device_memory(sycl::range(Arguments...), + get_default_queue()) {} + + /// Constructor with range and queue + template + device_memory(Args... Arguments, sycl::queue q) + : device_memory(sycl::range(Arguments...), q) {} + + ~device_memory() { + if (_device_ptr && !_reference) + compat::free(_device_ptr, _q); + if (_host_ptr) + std::free(_host_ptr); + } + + /// Allocate memory with the queue specified in the constuctor, and init + /// memory if has initial value + void init() { init(_q); } + /// Allocate memory with specified queue, and init memory if has initial + /// value. + void init(sycl::queue q) { + if (_device_ptr) + return; + if (!_size) + return; + allocate_device(q); + if (_host_ptr) + detail::memcpy(q, _device_ptr, _host_ptr, _size); + } + + /// The variable is assigned to a device pointer. + void assign(value_t *src, size_t size) { + this->~device_memory(); + new (this) device_memory(src, size, _q); + } + + // Get memory pointer of the memory object, a device USM pointer. + value_t *get_ptr() { return get_ptr(_q); } + + // Get memory pointer of the memory object, a device USM pointer. + value_t *get_ptr(sycl::queue q) { + init(q); + return _device_ptr; + } + + /// Get the device memory object size in bytes. + size_t get_size() { return _size; } + + template + typename std::enable_if::type &operator[](size_t index) { + init(); +#ifdef COMPAT_USM_LEVEL_NONE + return compat::get_buffer::type>( + _device_ptr) + .template get_access()[index]; +#else + return _device_ptr[index]; +#endif // COMPAT_USM_LEVEL_NONE + } + +#ifdef COMPAT_USM_LEVEL_NONE + /// Get sycl::accessor for the device memory object when usm is not used. + accessor_t get_access(sycl::handler &cgh) { + return get_buffer(_device_ptr) + .template reinterpret(_range) + .template get_access::mode, + detail::memory_traits::target>(cgh); + } +#else + /// Get compat_accessor with dimension info for the device memory object + /// when usm is used and dimension is greater than 1. + template + typename std::enable_if::type + get_access(sycl::handler &cgh) { + return compat_accessor_t((T *)_device_ptr, _range); + } +#endif // COMPAT_USM_LEVEL_NONE + +private: + device_memory(value_t *memory_ptr, size_t size, + sycl::queue q = get_default_queue()) + : _size(size), _range(size / sizeof(T)), _reference(true), + _device_ptr(memory_ptr), _q(q) {} + + void allocate_device(sycl::queue q) { +#ifndef COMPAT_USM_LEVEL_NONE + if (Memory == memory_region::usm_shared) { + _device_ptr = (value_t *)sycl::malloc_shared(_size, q.get_device(), + q.get_context()); + return; + } +#ifdef SYCL_EXT_ONEAPI_USM_DEVICE_READ_ONLY + if (Memory == memory_region::constant) { + _device_ptr = (value_t *)sycl::malloc_device( + _size, q.get_device(), q.get_context(), + sycl::ext::oneapi::property::usm::device_read_only()); + return; + } +#endif +#endif + _device_ptr = (value_t *)detail::malloc(_size, q); + } + + size_t _size; + sycl::range _range; + bool _reference; + value_t *_host_ptr; + value_t *_device_ptr; + sycl::queue _q; +}; +template +class device_memory : public device_memory { +public: + using base = device_memory; + using value_t = typename base::value_t; + using accessor_t = + typename detail::memory_traits::template accessor_t<0>; + + /// Constructor with initial value. + device_memory(const value_t &val, sycl::queue q = get_default_queue()) + : base(sycl::range<1>(1), {val}, q) {} + + /// Default constructor + device_memory(sycl::queue q = get_default_queue()) : base(1, q) {} +#ifdef COMPAT_USM_LEVEL_NONE + /// Get sycl::accessor for the device memory object when usm is not used. + accessor_t get_access(sycl::handler &cgh) { + auto buf = get_buffer(base::get_ptr()) + .template reinterpret(sycl::range<1>(1)); + return accessor_t(buf, cgh); + } +#endif // COMPAT_USM_LEVEL_NONE +}; + +template +using global_memory = device_memory; +template +using constant_memory = device_memory; +template +using shared_memory = device_memory; + +class pointer_attributes { +public: + void init(const void *ptr, sycl::queue q = get_default_queue()) { +#ifdef COMPAT_USM_LEVEL_NONE + throw std::runtime_error( + "[Compat] pointer_attributes: only works for USM pointer."); +#else + memory_type = sycl::get_pointer_type(ptr, q.get_context()); + device_pointer = (memory_type != sycl::usm::alloc::unknown) ? ptr : nullptr; + host_pointer = (memory_type != sycl::usm::alloc::unknown) && + (memory_type != sycl::usm::alloc::device) + ? ptr + : nullptr; + sycl::device device_obj = sycl::get_pointer_device(ptr, q.get_context()); + device_id = detail::dev_mgr::instance().get_device_id(device_obj); +#endif // COMPAT_USM_LEVEL_NONE + } + + sycl::usm::alloc get_memory_type() { return memory_type; } + + const void *get_device_pointer() { return device_pointer; } + + const void *get_host_pointer() { return host_pointer; } + + bool is_memory_shared() { return memory_type == sycl::usm::alloc::shared; } + + unsigned int get_device_id() { return device_id; } + +private: + sycl::usm::alloc memory_type = sycl::usm::alloc::unknown; + const void *device_pointer = nullptr; + const void *host_pointer = nullptr; + unsigned int device_id = 0; +}; + +} // namespace compat diff --git a/tools/util/include/compat/traits.hpp b/tools/util/include/compat/traits.hpp new file mode 100644 index 0000000000..a4c293822c --- /dev/null +++ b/tools/util/include/compat/traits.hpp @@ -0,0 +1,294 @@ +/*************************************************************************** + * + * Copyright (C) Codeplay Software Ltd. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * + * Part of the LLVM Project, under the Apache License v2.0 with LLVM + * Exceptions. See https://llvm.org/LICENSE.txt for license information. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SYCL compatibility extension + * + * traits.hpp + * + * Description: + * Type traits for the SYCL compatibility extension + **************************************************************************/ + +#pragma once + +#include +#ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS +#include +#endif +#include +#include +#include +#include +#include +#include + +namespace compat { + +// Equivalent to C++20's std::type_identity (used to create non-deduced +// contexts) +template struct type_identity { + using type = T; +}; +template using type_identity_t = typename type_identity::type; + +// Defines the operand type for arithemtic operations on T. This is identity +// for all types except pointers, for which it is std::ptrdiff_t +template struct arith { + using type = std::conditional_t, std::ptrdiff_t, T>; +}; +template using arith_t = typename arith::type; + +// Traits to check device function signature matches args (with or without local +// mem) +template +struct device_fn_invocable : std::is_invocable {}; + +template +struct device_fn_lmem_invocable + : std::is_invocable {}; + +template +constexpr inline bool args_compatible = + std::conditional_t, + device_fn_invocable>::value; + +namespace detail { + +// Trait for identifying sycl::range and sycl::nd_range. +template struct is_range : std::false_type {}; +template struct is_range> : std::true_type {}; + +template constexpr bool is_range_v = is_range::value; + +template struct is_nd_range : std::false_type {}; +template struct is_nd_range> : std::true_type {}; + +template constexpr bool is_nd_range_v = is_nd_range::value; + +template +constexpr bool is_range_or_nd_range_v = + std::disjunction_v, is_nd_range>; + +// Trait range_to_item_t to convert nd_range -> nd_item, range -> item +template struct range_to_item_map; +template struct range_to_item_map> { + using ItemT = sycl::nd_item; +}; +template struct range_to_item_map> { + using ItemT = sycl::item; +}; + +template +using range_to_item_t = typename range_to_item_map::ItemT; + +} // namespace detail + +// Forward decls +namespace experimental { + +template struct kernel_properties; +template struct launch_properties; +struct local_mem_size; + +template +class launch_policy; +} // namespace experimental + +namespace experimental::detail { + +// Helper for tuple_template_index +template