diff --git a/libdevice/nativecpu_utils.cpp b/libdevice/nativecpu_utils.cpp
index 4107663e6d1fd..404a6bff27f85 100644
--- a/libdevice/nativecpu_utils.cpp
+++ b/libdevice/nativecpu_utils.cpp
@@ -296,20 +296,7 @@ DefShuffleINTEL_All(uint8_t, i8, int8_t)
 DefShuffleINTEL_All(double, f64, double)
 DefShuffleINTEL_All(float, f32, float)
 
-#define DefineShuffleVec(T, N, Sfx, MuxType) \
-  using vt##T##N = sycl::vec<T, N>::vector_t; \
-  using vt##MuxType##N = sycl::vec<MuxType, N>::vector_t; \
-  DefShuffleINTEL_All(vt##T##N, v##N##Sfx, vt##MuxType##N)
-
-#define DefineShuffleVec2to16(Type, Sfx, MuxType) \
-  DefineShuffleVec(Type, 2, Sfx, MuxType) \
-  DefineShuffleVec(Type, 4, Sfx, MuxType) \
-  DefineShuffleVec(Type, 8, Sfx, MuxType) \
-  DefineShuffleVec(Type, 16, Sfx, MuxType)
-
-DefineShuffleVec2to16(int32_t, i32, int32_t)
-DefineShuffleVec2to16(uint32_t, i32, int32_t)
-DefineShuffleVec2to16(float, f32, float)
+// Vector versions of shuffle are generated by the FixABIMuxBuiltinsSYCLNativeCPU pass
 
 #define Define2ArgForward(Type, Name, Callee)\
 DEVICE_EXTERNAL Type Name(Type a, Type b) { return Callee(a,b);}
diff --git a/llvm/include/llvm/SYCLLowerIR/FixABIMuxBuiltinsSYCLNativeCPU.h b/llvm/include/llvm/SYCLLowerIR/FixABIMuxBuiltinsSYCLNativeCPU.h
new file mode 100644
index 0000000000000..9eea9a87fced2
--- /dev/null
+++ b/llvm/include/llvm/SYCLLowerIR/FixABIMuxBuiltinsSYCLNativeCPU.h
@@ -0,0 +1,29 @@
+//===- FixABIMuxBuiltinsSYCLNativeCPU.h - Fix ABI of called mux builtins -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Creates calls to the shuffle up/down/xor mux builtins, taking into account
+// the ABI of the SYCL functions. For now this is only used for vector variants.
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+
+
+namespace llvm {
+
+class FixABIMuxBuiltinsPass final
+    : public llvm::PassInfoMixin<FixABIMuxBuiltinsPass> {
+public:
+  llvm::PreservedAnalyses run(llvm::Module &, llvm::ModuleAnalysisManager &);
+};
+
+} // namespace llvm
+
diff --git a/llvm/lib/SYCLNativeCPUUtils/CMakeLists.txt b/llvm/lib/SYCLNativeCPUUtils/CMakeLists.txt
index bea5f1fac7cb1..bdbbff1e96cf0 100644
--- a/llvm/lib/SYCLNativeCPUUtils/CMakeLists.txt
+++ b/llvm/lib/SYCLNativeCPUUtils/CMakeLists.txt
@@ -3,7 +3,7 @@ add_llvm_component_library(LLVMSYCLNativeCPUUtils
   PrepareSYCLNativeCPU.cpp
   RenameKernelSYCLNativeCPU.cpp
   ConvertToMuxBuiltinsSYCLNativeCPU.cpp
-
+  FixABIMuxBuiltinsSYCLNativeCPU.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${LLVM_MAIN_INCLUDE_DIR}/llvm/SYCLLowerIR
diff --git a/llvm/lib/SYCLNativeCPUUtils/FixABIMuxBuiltinsSYCLNativeCPU.cpp b/llvm/lib/SYCLNativeCPUUtils/FixABIMuxBuiltinsSYCLNativeCPU.cpp
new file mode 100644
index 0000000000000..b3ff7372b8d0f
--- /dev/null
+++ b/llvm/lib/SYCLNativeCPUUtils/FixABIMuxBuiltinsSYCLNativeCPU.cpp
@@ -0,0 +1,226 @@
+//===-- FixABIMuxBuiltinsSYCLNativeCPU.cpp - Fixup mux ABI issues ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Creates calls to the shuffle up/down/xor mux builtins, taking into account
+// the ABI of the SYCL functions. For now this is only used for vector variants.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/SYCLLowerIR/FixABIMuxBuiltinsSYCLNativeCPU.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/IR/IRBuilder.h"
+
+#define DEBUG_TYPE "fix-abi-mux-builtins"
+
+using namespace llvm;
+
+PreservedAnalyses FixABIMuxBuiltinsPass::run(Module &M,
+                                             ModuleAnalysisManager &AM) {
+  bool Changed = false;
+
+  // Decide whether a function needs to be updated and, if so, which parameters
+  // need changing, as well as the return value.
+  auto FunctionNeedsFixing =
+      [](Function &F,
+         llvm::SmallVectorImpl<std::pair<unsigned int, llvm::Type *>> &Updates,
+         llvm::Type *&RetVal, std::string &MuxFuncNameToCall) {
+        if (!F.isDeclaration()) {
+          return false;
+        }
+        if (!F.getName().contains("__spirv_SubgroupShuffle")) {
+          return false;
+        }
+        Updates.clear();
+        auto LIDvPos = F.getName().find("ELIDv");
+        llvm::StringRef NameToMatch;
+        if (LIDvPos != llvm::StringRef::npos) {
+          // Add the size of "ELIDv" to get the number of characters to match
+          // against.
+          NameToMatch = F.getName().take_front(LIDvPos + 5);
+        } else {
+          return false;
+        }
+
+        unsigned int StartIdx = 0;
+        unsigned int EndIdx = 1;
+        if (NameToMatch == "_Z32__spirv_SubgroupShuffleDownINTELIDv") {
+          MuxFuncNameToCall = "__mux_sub_group_shuffle_down_";
+        } else if (NameToMatch == "_Z30__spirv_SubgroupShuffleUpINTELIDv") {
+          MuxFuncNameToCall = "__mux_sub_group_shuffle_up_";
+        } else if (NameToMatch == "_Z28__spirv_SubgroupShuffleINTELIDv") {
+          MuxFuncNameToCall = "__mux_sub_group_shuffle_";
+          EndIdx = 0;
+        } else if (NameToMatch == "_Z31__spirv_SubgroupShuffleXorINTELIDv") {
+          MuxFuncNameToCall = "__mux_sub_group_shuffle_xor_";
+          EndIdx = 0;
+        } else {
+          return false;
+        }
+
+        // We need to create the body for this. First we need to find out what
+        // the first arguments should be.
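+        // For example, given
+        // "_Z30__spirv_SubgroupShuffleUpINTELIDv2_iET_S1_S1_j" we have
+        // already matched "_Z30__spirv_SubgroupShuffleUpINTELIDv"; the
+        // remaining "2_i..." encodes a vector width of 2 and element type
+        // 'i' (i32), so the callee becomes "__mux_sub_group_shuffle_up_v2i32".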
+        llvm::StringRef RemainingName =
+            F.getName().drop_front(NameToMatch.size());
+        std::string MuxFuncTypeStr = "UNKNOWN";
+
+        unsigned int VecWidth = 0;
+        if (RemainingName.consumeInteger(10, VecWidth)) {
+          return false;
+        }
+        if (!RemainingName.consume_front("_")) {
+          return false;
+        }
+
+        char TypeCh = RemainingName[0];
+        Type *BaseType = nullptr;
+        switch (TypeCh) {
+        case 'a':
+        case 'h':
+          BaseType = llvm::Type::getInt8Ty(F.getContext());
+          MuxFuncTypeStr = "i8";
+          break;
+        case 's':
+        case 't':
+          BaseType = llvm::Type::getInt16Ty(F.getContext());
+          MuxFuncTypeStr = "i16";
+          break;
+        case 'i':
+        case 'j':
+          BaseType = llvm::Type::getInt32Ty(F.getContext());
+          MuxFuncTypeStr = "i32";
+          break;
+        case 'l':
+        case 'm':
+          BaseType = llvm::Type::getInt64Ty(F.getContext());
+          MuxFuncTypeStr = "i64";
+          break;
+        case 'f':
+          BaseType = llvm::Type::getFloatTy(F.getContext());
+          MuxFuncTypeStr = "f32";
+          break;
+        case 'd':
+          BaseType = llvm::Type::getDoubleTy(F.getContext());
+          MuxFuncTypeStr = "f64";
+          break;
+        default:
+          return false;
+        }
+        auto *VecType = llvm::FixedVectorType::get(BaseType, VecWidth);
+        RetVal = VecType;
+
+        // Work out the type suffix (v<width><element type>) of the mux
+        // function to call.
+        MuxFuncNameToCall += "v";
+        MuxFuncNameToCall += std::to_string(VecWidth);
+        MuxFuncNameToCall += MuxFuncTypeStr;
+
+        unsigned int CurrentIndex = 0;
+        for (auto &Arg : F.args()) {
+          if (Arg.hasStructRetAttr()) {
+            StartIdx++;
+            EndIdx++;
+          } else {
+            if (CurrentIndex >= StartIdx && CurrentIndex <= EndIdx) {
+              if (Arg.getType() != VecType) {
+                Updates.push_back(std::pair<unsigned int, llvm::Type *>(
+                    CurrentIndex, VecType));
+              }
+            }
+          }
+          CurrentIndex++;
+        }
+        return true;
+      };
+
+  llvm::SmallVector<Function *> FuncsToProcess;
+  for (auto &F : M.functions()) {
+    FuncsToProcess.push_back(&F);
+  }
+
+  for (auto *F : FuncsToProcess) {
+    llvm::SmallVector<std::pair<unsigned int, llvm::Type *>, 4> ArgUpdates;
+    llvm::Type *RetType = nullptr;
+    std::string MuxFuncNameToCall;
+    if (!FunctionNeedsFixing(*F, ArgUpdates, RetType, MuxFuncNameToCall)) {
+      continue;
+    }
+    if (!F->isDeclaration()) {
+      continue;
+    }
+    Changed = true;
+    IRBuilder<> IR(BasicBlock::Create(F->getContext(), "", F));
+
+    llvm::SmallVector<Type *> Args;
+    unsigned int ArgIndex = 0;
+    unsigned int UpdateIndex = 0;
+
+    for (auto &Arg : F->args()) {
+      if (!Arg.hasStructRetAttr()) {
+        if (UpdateIndex < ArgUpdates.size() &&
+            std::get<0>(ArgUpdates[UpdateIndex]) == ArgIndex) {
+          Args.push_back(std::get<1>(ArgUpdates[UpdateIndex]));
+          UpdateIndex++;
+        } else {
+          Args.push_back(Arg.getType());
+        }
+      }
+      ArgIndex++;
+    }
+
+    FunctionType *FT = FunctionType::get(RetType, Args, false);
+    Function *NewFunc =
+        Function::Create(FT, F->getLinkage(), MuxFuncNameToCall, M);
+    llvm::SmallVector<Value *> CallArgs;
+    auto NewFuncArgItr = NewFunc->args().begin();
+    Argument *SretPtr = nullptr;
+    for (auto &Arg : F->args()) {
+      if (Arg.hasStructRetAttr()) {
+        SretPtr = &Arg;
+      } else {
+        if (Arg.getType() != (*NewFuncArgItr).getType()) {
+          if (Arg.getType()->isPointerTy()) {
+            Value *ArgLoad = IR.CreateLoad((*NewFuncArgItr).getType(), &Arg);
+            CallArgs.push_back(ArgLoad);
+          } else {
+            Value *ArgCast = IR.CreateBitCast(&Arg, (*NewFuncArgItr).getType());
+            CallArgs.push_back(ArgCast);
+          }
+        } else {
+          CallArgs.push_back(&Arg);
+        }
+        NewFuncArgItr++;
+      }
+    }
+
+    Value *Res = IR.CreateCall(NewFunc, CallArgs);
+    // If the return type differs from that of the initial function, bitcast
+    // the result, unless the function returns void, in which case we expect a
+    // StructRet parameter that the result needs to be stored to.
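+    // For example, on x86_64 a <2 x i32> shuffle is declared as returning
+    // double, so the <2 x i32> result of the mux call is bitcast back to
+    // double (see shuffle_abi.cpp); a declaration returning void instead has
+    // the result stored through its StructRet argument.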
+    if (F->getReturnType() != RetType) {
+      if (F->getReturnType()->isVoidTy()) {
+        // If we don't have a StructRet parameter then something is wrong with
+        // the initial function.
+        if (!SretPtr) {
+          llvm_unreachable(
+              "No struct ret pointer for sub-group shuffle function");
+        }
+
+        IR.CreateStore(Res, SretPtr);
+      } else {
+        Res = IR.CreateBitCast(Res, F->getReturnType());
+      }
+    }
+    if (F->getReturnType()->isVoidTy()) {
+      IR.CreateRetVoid();
+    } else {
+      IR.CreateRet(Res);
+    }
+  }
+
+  return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
+}
diff --git a/llvm/lib/SYCLNativeCPUUtils/PipelineSYCLNativeCPU.cpp b/llvm/lib/SYCLNativeCPUUtils/PipelineSYCLNativeCPU.cpp
index c78e0d9223ef9..52f78a7d494fe 100644
--- a/llvm/lib/SYCLNativeCPUUtils/PipelineSYCLNativeCPU.cpp
+++ b/llvm/lib/SYCLNativeCPUUtils/PipelineSYCLNativeCPU.cpp
@@ -12,6 +12,7 @@
 //
 //===----------------------------------------------------------------------===//
 #include "llvm/SYCLLowerIR/ConvertToMuxBuiltinsSYCLNativeCPU.h"
+#include "llvm/SYCLLowerIR/FixABIMuxBuiltinsSYCLNativeCPU.h"
 #include "llvm/SYCLLowerIR/PrepareSYCLNativeCPU.h"
 #include "llvm/SYCLLowerIR/RenameKernelSYCLNativeCPU.h"
 #include "llvm/SYCLLowerIR/SpecConstants.h"
@@ -65,6 +66,7 @@ void llvm::sycl::utils::addSYCLNativeCPUBackendPasses(
   MPM.addPass(ConvertToMuxBuiltinsSYCLNativeCPUPass());
 #ifdef NATIVECPU_USE_OCK
   MPM.addPass(compiler::utils::TransferKernelMetadataPass());
+  MPM.addPass(FixABIMuxBuiltinsPass());
   // Always enable vectorizer, unless explictly disabled or -O0 is set.
   if (OptLevel != OptimizationLevel::O0 && !SYCLNativeCPUNoVecz) {
     MAM.registerPass([] { return vecz::TargetInfoAnalysis(); });
diff --git a/llvm/lib/SYCLNativeCPUUtils/PrepareSYCLNativeCPU.cpp b/llvm/lib/SYCLNativeCPUUtils/PrepareSYCLNativeCPU.cpp
index b3888db8a7b50..340df9d4e7264 100644
--- a/llvm/lib/SYCLNativeCPUUtils/PrepareSYCLNativeCPU.cpp
+++ b/llvm/lib/SYCLNativeCPUUtils/PrepareSYCLNativeCPU.cpp
@@ -464,13 +464,17 @@ PreservedAnalyses PrepareSYCLNativeCPUPass::run(Module &M,
     F->eraseFromParent();
     ModuleChanged = true;
   }
-  for (auto It = M.begin(); It != M.end();) {
-    auto Curr = It++;
-    Function &F = *Curr;
-    if (F.getNumUses() == 0 && F.isDeclaration() &&
-        F.getName().starts_with("__mux_")) {
-      F.eraseFromParent();
-      ModuleChanged = true;
+
+  // We do this twice because we create ABI wrappers for mux builtins, which
+  // may show up before we have removed their users.
+  for (unsigned int i = 0; i < 2; i++) {
+    for (auto It = M.begin(); It != M.end();) {
+      auto Curr = It++;
+      Function &F = *Curr;
+      if (F.getNumUses() == 0 && F.getName().starts_with("__mux_")) {
+        F.eraseFromParent();
+        ModuleChanged = true;
+      }
     }
   }
 
diff --git a/sycl/test/check_device_code/native_cpu/shuffle_abi.cpp b/sycl/test/check_device_code/native_cpu/shuffle_abi.cpp
new file mode 100644
index 0000000000000..cdbaab90ce65c
--- /dev/null
+++ b/sycl/test/check_device_code/native_cpu/shuffle_abi.cpp
@@ -0,0 +1,95 @@
+// REQUIRES: native_cpu_ock && linux
+
+// This doesn't test every possible case, since it is quite slow to compile.
+// long and double are not tested, as they seem to generate loops in the code
+// rather than vector versions.
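+//
+// Each OPER value selects one sub-group operation, which lowers to the
+// corresponding __mux builtin checked below:
+//   TF_SHIFT_UP   -> shift_group_right    -> __mux_sub_group_shuffle_up_*
+//   TF_SHIFT_DOWN -> shift_group_left     -> __mux_sub_group_shuffle_down_*
+//   TF_SHIFT_XOR  -> permute_group_by_xor -> __mux_sub_group_shuffle_xor_*
+//   TF_SELECT     -> select_from_group    -> __mux_sub_group_shuffle_*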
+
+// RUN: %clangxx -DTYPE=int -DVEC_WIDTH=2 -DOPER=TF_SHIFT_UP -target x86_64-unknown-linux-gnu -fsycl -fsycl-targets=native_cpu -Xclang -sycl-std=2020 -mllvm -sycl-opt -mllvm -inline-threshold=500 -mllvm -sycl-native-cpu-no-vecz -mllvm -sycl-native-dump-device-ir %s | FileCheck --check-prefix UP_V2_INT %s
+// RUN: %clangxx -DTYPE=short -DVEC_WIDTH=4 -DOPER=TF_SHIFT_DOWN -target x86_64-unknown-linux-gnu -fsycl -fsycl-targets=native_cpu -Xclang -sycl-std=2020 -mllvm -sycl-opt -mllvm -inline-threshold=500 -mllvm -sycl-native-cpu-no-vecz -mllvm -sycl-native-dump-device-ir %s | FileCheck --check-prefix DOWN_V4_SHORT %s
+// RUN: %clangxx -DTYPE=char -DVEC_WIDTH=4 -DOPER=TF_SHIFT_XOR -target x86_64-unknown-linux-gnu -fsycl -fsycl-targets=native_cpu -Xclang -sycl-std=2020 -mllvm -sycl-opt -mllvm -inline-threshold=500 -mllvm -sycl-native-cpu-no-vecz -mllvm -sycl-native-dump-device-ir %s | FileCheck --check-prefix XOR_V4_CHAR %s
+// RUN: %clangxx -DTYPE=float -DVEC_WIDTH=8 -DOPER=TF_SHIFT_UP -target x86_64-unknown-linux-gnu -fsycl -fsycl-targets=native_cpu -Xclang -sycl-std=2020 -mllvm -sycl-opt -mllvm -inline-threshold=500 -mllvm -sycl-native-cpu-no-vecz -mllvm -sycl-native-dump-device-ir %s | FileCheck --check-prefix UP_V8_FLOAT %s
+// RUN: %clangxx -DTYPE="unsigned int" -DVEC_WIDTH=8 -DOPER=TF_SELECT -target x86_64-unknown-linux-gnu -fsycl -fsycl-targets=native_cpu -Xclang -sycl-std=2020 -mllvm -sycl-opt -mllvm -inline-threshold=500 -mllvm -sycl-native-cpu-no-vecz -mllvm -sycl-native-dump-device-ir %s | FileCheck --check-prefix SELECT_V8_SELECT_I32 %s
+
+// Tests that sub-group shuffles work even when the ABI differs from what is
+// expected.
+
+#include <cstddef>
+#include <sycl/sycl.hpp>
+
+static constexpr size_t NumElems = VEC_WIDTH;
+static constexpr size_t NumWorkItems = 64;
+
+// UP_V2_INT: double @_Z30__spirv_SubgroupShuffleUpINTELIDv2_iET_S1_S1_j(double noundef %[[ARG0:[0-9]+]], double noundef %[[ARG1:[0-9]+]]
+// UP_V2_INT: %[[UPV2I32_BITCAST_OP0:[0-9]+]] = bitcast double %[[ARG0]] to <2 x i32>
+// UP_V2_INT: %[[UPV2I32_BITCAST_OP1:[0-9]+]] = bitcast double %[[ARG1]] to <2 x i32>
+// UP_V2_INT: %[[UPV2I32_CALL_SHUFFLE:[0-9]+]] = call <2 x i32> @__mux_sub_group_shuffle_up_v2i32(<2 x i32> %[[UPV2I32_BITCAST_OP0]], <2 x i32> %[[UPV2I32_BITCAST_OP1]]
+// UP_V2_INT: %[[UPV2I32_BITCAST_RESULT:[0-9]+]] = bitcast <2 x i32> %[[UPV2I32_CALL_SHUFFLE]] to double
+// UP_V2_INT: ret double %[[UPV2I32_BITCAST_RESULT]]
+
+// DOWN_V4_SHORT: double @_Z32__spirv_SubgroupShuffleDownINTELIDv4_sET_S1_S1_j(double noundef %[[ARG0:[0-9]+]], double noundef %[[ARG1:[0-9]+]]
+// DOWN_V4_SHORT: %[[DOWNV4I16_BITCAST_OP0:[0-9]+]] = bitcast double %[[ARG0]] to <4 x i16>
+// DOWN_V4_SHORT: %[[DOWNV4I16_BITCAST_OP1:[0-9]+]] = bitcast double %[[ARG1]] to <4 x i16>
+// DOWN_V4_SHORT: %[[DOWNV4I16_CALL_SHUFFLE:[0-9]+]] = call <4 x i16> @__mux_sub_group_shuffle_down_v4i16(<4 x i16> %[[DOWNV4I16_BITCAST_OP0]], <4 x i16> %[[DOWNV4I16_BITCAST_OP1]]
+// DOWN_V4_SHORT: %[[DOWNV4I16_BITCAST_RESULT:[0-9]+]] = bitcast <4 x i16> %[[DOWNV4I16_CALL_SHUFFLE]] to double
+// DOWN_V4_SHORT: ret double %[[DOWNV4I16_BITCAST_RESULT]]
+
+// XOR_V4_CHAR: i32 @_Z31__spirv_SubgroupShuffleXorINTELIDv4_aET_S1_j(i32 noundef %[[ARG0:[0-9]+]], i32
+// XOR_V4_CHAR: %[[XORV4I8_BITCAST_OP0:[0-9]+]] = bitcast i32 %[[ARG0]] to <4 x i8>
+// XOR_V4_CHAR: %[[XORV4I8_CALL_SHUFFLE:[0-9]+]] = call <4 x i8> @__mux_sub_group_shuffle_xor_v4i8(<4 x i8> %[[XORV4I8_BITCAST_OP0]], i32
+// XOR_V4_CHAR: %[[XORV4I8_BITCAST_RESULT:[0-9]+]] = bitcast <4 x i8> %[[XORV4I8_CALL_SHUFFLE]] to i32
+// XOR_V4_CHAR: ret i32 %[[XORV4I8_BITCAST_RESULT]]
+
+// UP_V8_FLOAT: <8 x float> @_Z30__spirv_SubgroupShuffleUpINTELIDv8_fET_S1_S1_j(ptr noundef byval(<8 x float>) align 32 %[[ARG0:[0-9]+]], ptr noundef byval(<8 x float>) align 32 %[[ARG1:[0-9]+]]
+// UP_V8_FLOAT: %[[UPV8F32_BYVAL_LOAD_OP0:[0-9]+]] = load <8 x float>, ptr %[[ARG0]], align 32
+// UP_V8_FLOAT: %[[UPV8F32_BYVAL_LOAD_OP1:[0-9]+]] = load <8 x float>, ptr %[[ARG1]], align 32
+// UP_V8_FLOAT: %[[UPV8F32_CALL_SHUFFLE:[0-9]+]] = call <8 x float> @__mux_sub_group_shuffle_up_v8f32(<8 x float> %[[UPV8F32_BYVAL_LOAD_OP0]], <8 x float> %[[UPV8F32_BYVAL_LOAD_OP1]], i32
+// UP_V8_FLOAT: ret <8 x float> %[[UPV8F32_CALL_SHUFFLE]]
+
+// SELECT_V8_SELECT_I32: <8 x i32> @_Z28__spirv_SubgroupShuffleINTELIDv8_jET_S1_j(ptr noundef byval(<8 x i32>) align 32 %[[ARG0:[0-9]+]],
+// SELECT_V8_SELECT_I32: %[[SELV8I32_BYVAL_LOAD_OP0:[0-9]+]] = load <8 x i32>, ptr %[[ARG0]], align 32
+// SELECT_V8_SELECT_I32: %[[SELV8I32_CALL_SHUFFLE:[0-9]+]] = call <8 x i32> @__mux_sub_group_shuffle_v8i32(<8 x i32> %[[SELV8I32_BYVAL_LOAD_OP0]], i32
+// SELECT_V8_SELECT_I32: ret <8 x i32> %[[SELV8I32_CALL_SHUFFLE]]
+
+enum TEST_FUNC_CHOICE { TF_SHIFT_DOWN, TF_SHIFT_UP, TF_SHIFT_XOR, TF_SELECT };
+
+template <typename ShiftType, TEST_FUNC_CHOICE Choice>
+void ShuffleOpTest() {
+  sycl::queue Q;
+
+  ShiftType ShiftRes[NumWorkItems];
+
+  {
+    sycl::buffer<ShiftType, 1> ShuffleResBuf{ShiftRes, NumWorkItems};
+
+    Q.submit([&](sycl::handler &CGH) {
+      sycl::accessor ShuffleRes{ShuffleResBuf, CGH, sycl::write_only};
+
+      CGH.parallel_for(
+          sycl::nd_range<1>{sycl::range<1>{NumWorkItems},
+                            sycl::range<1>{NumWorkItems}},
+          [=](sycl::nd_item<1> It) {
+            int GID = It.get_global_linear_id();
+            ShiftType ItemVal{0};
+            for (int I = 0; I < NumElems; ++I)
+              ItemVal[I] = I;
+
+            sycl::sub_group SG = It.get_sub_group();
+            if (Choice == TF_SHIFT_DOWN) {
+              ShuffleRes[GID] = sycl::shift_group_left(SG, ItemVal);
+            } else if (Choice == TF_SHIFT_UP) {
+              ShuffleRes[GID] = sycl::shift_group_right(SG, ItemVal);
+            } else if (Choice == TF_SHIFT_XOR) {
+              ShuffleRes[GID] = sycl::permute_group_by_xor(SG, ItemVal, 1);
+            } else if (Choice == TF_SELECT) {
+              ShuffleRes[GID] = sycl::select_from_group(SG, ItemVal, 1);
+            }
+          });
+    });
+  }
+}
+
+int main() {
+  ShuffleOpTest<sycl::vec<TYPE, VEC_WIDTH>, OPER>();
+  return 0;
+}