Skip to content

Commit 0bfe1b2

Browse files
author
Colin Davidson
committed
[SYCL][NATIVE_CPU] Fill in any SYCL functions which require mapping to mux
Native cpu can make calls to mux builtins such as shuffle which are ABI compliant but are not what is expected by ock passes. This fixes them up by removing the vector versions from libnativecpu.cpp and using a pass to convert from parameters which relate to the ABI to calls to the mux functions with the set interface unaffected by the ABI. This currently only handles a small number of cases for shuffle, such as when a vector i2 is replaced with double or byval is used. It will be expanded over time as needed.
1 parent e3cfbfe commit 0bfe1b2

File tree

7 files changed

+365
-22
lines changed

7 files changed

+365
-22
lines changed

libdevice/nativecpu_utils.cpp

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -296,20 +296,7 @@ DefShuffleINTEL_All(uint8_t, i8, int8_t)
296296
DefShuffleINTEL_All(double, f64, double)
297297
DefShuffleINTEL_All(float, f32, float)
298298

299-
#define DefineShuffleVec(T, N, Sfx, MuxType) \
300-
using vt##T##N = sycl::vec<T, N>::vector_t; \
301-
using vt##MuxType##N = sycl::vec<MuxType, N>::vector_t; \
302-
DefShuffleINTEL_All(vt##T##N, v##N##Sfx, vt##MuxType##N)
303-
304-
#define DefineShuffleVec2to16(Type, Sfx, MuxType) \
305-
DefineShuffleVec(Type, 2, Sfx, MuxType) \
306-
DefineShuffleVec(Type, 4, Sfx, MuxType) \
307-
DefineShuffleVec(Type, 8, Sfx, MuxType) \
308-
DefineShuffleVec(Type, 16, Sfx, MuxType)
309-
310-
DefineShuffleVec2to16(int32_t, i32, int32_t)
311-
DefineShuffleVec2to16(uint32_t, i32, int32_t)
312-
DefineShuffleVec2to16(float, f32, float)
299+
// Vector versions of shuffle are generated by FixABIMuxBuiltinsPass
// (see FixABIMuxBuiltinsSYCLNativeCPU.cpp).
313300

314301
#define Define2ArgForward(Type, Name, Callee)\
315302
DEVICE_EXTERNAL Type Name(Type a, Type b) { return Callee(a,b);}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//===- FixABIMuxBuiltinsSYCLNativeCPU.h - Fix mux builtin ABI issues ------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Creates calls to shuffle up/down/xor mux builtins taking into account ABI of the
10+
// SYCL functions. For now this is only used for vector variants.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#pragma once
15+
16+
#include "llvm/IR/Module.h"
17+
#include "llvm/IR/PassManager.h"
18+
19+
20+
namespace llvm {
21+
22+
/// Module pass that creates calls to the shuffle up/down/xor mux builtins,
/// taking into account the ABI of the SYCL functions that wrap them. For now
/// it is only used for the vector variants of those functions.
///
/// `run` is the new-pass-manager entry point: it takes the module to process
/// and the module analysis manager, and returns the set of preserved analyses.
class FixABIMuxBuiltinsPass final
23+
: public llvm::PassInfoMixin<FixABIMuxBuiltinsPass> {
24+
public:
25+
llvm::PreservedAnalyses run(llvm::Module &, llvm::ModuleAnalysisManager &);
26+
};
27+
28+
} // namespace llvm
29+

llvm/lib/SYCLNativeCPUUtils/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ add_llvm_component_library(LLVMSYCLNativeCPUUtils
33
PrepareSYCLNativeCPU.cpp
44
RenameKernelSYCLNativeCPU.cpp
55
ConvertToMuxBuiltinsSYCLNativeCPU.cpp
6-
6+
FixABIMuxBuiltinsSYCLNativeCPU.cpp
77

88
ADDITIONAL_HEADER_DIRS
99
${LLVM_MAIN_INCLUDE_DIR}/llvm/SYCLLowerIR
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
//===-- FixABIMuxBuiltinsSYCLNativeCPU.cpp - Fixup mux ABI issues ---===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Creates calls to shuffle up/down/xor mux builtins taking into account ABI of
10+
// the SYCL functions. For now this is only used for vector variants.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include <llvm/IR/IRBuilder.h>
15+
#include <llvm/IR/Module.h>
16+
#include <llvm/IR/Type.h>
17+
#include <llvm/SYCLLowerIR/FixABIMuxBuiltinsSYCLNativeCPU.h>
18+
19+
#define DEBUG_TYPE "fix-abi-mux-builtins"
20+
21+
using namespace llvm;
22+
23+
PreservedAnalyses FixABIMuxBuiltinsPass::run(Module &M,
24+
ModuleAnalysisManager &AM) {
25+
bool Changed = false;
26+
27+
// Decide if a function needs updated and if so what parameters need changing,
28+
// as well as the return value
29+
auto FunctionNeedsFixing =
30+
[](Function &F,
31+
llvm::SmallVectorImpl<std::pair<unsigned int, llvm::Type *>> &Updates,
32+
llvm::Type *&RetVal, std::string &MuxFuncNameToCall) {
33+
if (!F.isDeclaration()) {
34+
return false;
35+
}
36+
if (!F.getName().contains("__spirv_SubgroupShuffle")) {
37+
return false;
38+
}
39+
Updates.clear();
40+
auto LIDvPos = F.getName().find("ELIDv");
41+
llvm::StringRef NameToMatch;
42+
if (LIDvPos != llvm::StringRef::npos) {
43+
// Add sizeof ELIDv to get num characters to match against
44+
NameToMatch = F.getName().take_front(LIDvPos + 5);
45+
} else {
46+
return false;
47+
}
48+
49+
unsigned int StartIdx = 0;
50+
unsigned int EndIdx = 1;
51+
if (NameToMatch == "_Z32__spirv_SubgroupShuffleDownINTELIDv") {
52+
MuxFuncNameToCall = "__mux_sub_group_shuffle_down_";
53+
} else if (NameToMatch == "_Z30__spirv_SubgroupShuffleUpINTELIDv") {
54+
MuxFuncNameToCall = "__mux_sub_group_shuffle_up_";
55+
} else if (NameToMatch == "_Z28__spirv_SubgroupShuffleINTELIDv") {
56+
MuxFuncNameToCall = "__mux_sub_group_shuffle_";
57+
EndIdx = 0;
58+
} else if (NameToMatch == "_Z31__spirv_SubgroupShuffleXorINTELIDv") {
59+
MuxFuncNameToCall = "__mux_sub_group_shuffle_xor_";
60+
EndIdx = 0;
61+
} else {
62+
return false;
63+
}
64+
65+
// We need to create the body for this. First we need to find out what
66+
// the first arguments should be
67+
llvm::StringRef RemainingName =
68+
F.getName().drop_front(NameToMatch.size());
69+
std::string MuxFuncTypeStr = "UNKNOWN";
70+
71+
unsigned int VecWidth = 0;
72+
if (RemainingName.consumeInteger(10, VecWidth)) {
73+
return false;
74+
}
75+
if (!RemainingName.consume_front("_")) {
76+
return false;
77+
}
78+
79+
char TypeCh = RemainingName[0];
80+
Type *BaseType = nullptr;
81+
switch (TypeCh) {
82+
case 'a':
83+
case 'h':
84+
BaseType = llvm::Type::getInt8Ty(F.getContext());
85+
MuxFuncTypeStr = "i8";
86+
break;
87+
case 's':
88+
case 't':
89+
BaseType = llvm::Type::getInt16Ty(F.getContext());
90+
MuxFuncTypeStr = "i16";
91+
break;
92+
93+
case 'i':
94+
case 'j':
95+
BaseType = llvm::Type::getInt32Ty(F.getContext());
96+
MuxFuncTypeStr = "i32";
97+
break;
98+
case 'l':
99+
case 'm':
100+
BaseType = llvm::Type::getInt64Ty(F.getContext());
101+
MuxFuncTypeStr = "i64";
102+
break;
103+
case 'f':
104+
BaseType = llvm::Type::getFloatTy(F.getContext());
105+
MuxFuncTypeStr = "f32";
106+
break;
107+
case 'd':
108+
BaseType = llvm::Type::getDoubleTy(F.getContext());
109+
MuxFuncTypeStr = "f64";
110+
break;
111+
default:
112+
return false;
113+
}
114+
auto *VecType = llvm::FixedVectorType::get(BaseType, VecWidth);
115+
RetVal = VecType;
116+
117+
// Work out the mux function to call's type extension based on v##N##Sfx
118+
MuxFuncNameToCall += "v";
119+
MuxFuncNameToCall += std::to_string(VecWidth);
120+
MuxFuncNameToCall += MuxFuncTypeStr;
121+
122+
unsigned int CurrentIndex = 0;
123+
for (auto &Arg : F.args()) {
124+
if (Arg.hasStructRetAttr()) {
125+
StartIdx++;
126+
EndIdx++;
127+
} else {
128+
if (CurrentIndex >= StartIdx && CurrentIndex <= EndIdx) {
129+
if (Arg.getType() != VecType) {
130+
Updates.push_back(std::pair<unsigned int, llvm::Type *>(
131+
CurrentIndex, VecType));
132+
}
133+
}
134+
}
135+
CurrentIndex++;
136+
}
137+
return true;
138+
};
139+
140+
llvm::SmallVector<Function *, 4> FuncsToProcess;
141+
for (auto &F : M.functions()) {
142+
FuncsToProcess.push_back(&F);
143+
}
144+
145+
for (auto *F : FuncsToProcess) {
146+
llvm::SmallVector<std::pair<unsigned int, llvm::Type *>, 4> ArgUpdates;
147+
llvm::Type *RetType = nullptr;
148+
std::string MuxFuncNameToCall;
149+
if (!FunctionNeedsFixing(*F, ArgUpdates, RetType, MuxFuncNameToCall)) {
150+
continue;
151+
}
152+
if (!F->isDeclaration()) {
153+
continue;
154+
}
155+
Changed = true;
156+
IRBuilder<> IR(BasicBlock::Create(F->getContext(), "", F));
157+
158+
llvm::SmallVector<Type *, 8> Args;
159+
unsigned int ArgIndex = 0;
160+
unsigned int UpdateIndex = 0;
161+
162+
for (auto &Arg : F->args()) {
163+
if (!Arg.hasStructRetAttr()) {
164+
if (UpdateIndex < ArgUpdates.size() &&
165+
std::get<0>(ArgUpdates[UpdateIndex]) == ArgIndex) {
166+
Args.push_back(std::get<1>(ArgUpdates[UpdateIndex]));
167+
UpdateIndex++;
168+
} else {
169+
Args.push_back(Arg.getType());
170+
}
171+
}
172+
ArgIndex++;
173+
}
174+
175+
FunctionType *FT = FunctionType::get(RetType, Args, false);
176+
Function *NewFunc =
177+
Function::Create(FT, F->getLinkage(), MuxFuncNameToCall, M);
178+
llvm::SmallVector<Value *, 8> CallArgs;
179+
auto NewFuncArgItr = NewFunc->args().begin();
180+
Argument *SretPtr = nullptr;
181+
for (auto &Arg : F->args()) {
182+
if (Arg.hasStructRetAttr()) {
183+
SretPtr = &Arg;
184+
} else {
185+
if (Arg.getType() != (*NewFuncArgItr).getType()) {
186+
if (Arg.getType()->isPointerTy()) {
187+
Value *ArgLoad = IR.CreateLoad((*NewFuncArgItr).getType(), &Arg);
188+
CallArgs.push_back(ArgLoad);
189+
} else {
190+
Value *ArgCast = IR.CreateBitCast(&Arg, (*NewFuncArgItr).getType());
191+
CallArgs.push_back(ArgCast);
192+
}
193+
} else {
194+
CallArgs.push_back(&Arg);
195+
}
196+
NewFuncArgItr++;
197+
}
198+
}
199+
200+
Value *Res = IR.CreateCall(NewFunc, CallArgs);
201+
// If the return type is different to the initial function, then bitcast it
202+
// unless it's void in which case we'd expect an StructRet parameter which
203+
// needs stored to.
204+
if (F->getReturnType() != RetType) {
205+
if (F->getReturnType()->isVoidTy()) {
206+
// If we don't have an StructRet parameter then something is wrong with
207+
// the initial function
208+
if (!SretPtr) {
209+
llvm_unreachable(
210+
"No struct ret pointer for Sub group shuffle function");
211+
}
212+
213+
IR.CreateStore(Res, SretPtr);
214+
} else {
215+
Res = IR.CreateBitCast(Res, F->getReturnType());
216+
}
217+
}
218+
if (F->getReturnType()->isVoidTy()) {
219+
IR.CreateRetVoid();
220+
} else {
221+
IR.CreateRet(Res);
222+
}
223+
}
224+
225+
return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
226+
}

llvm/lib/SYCLNativeCPUUtils/PipelineSYCLNativeCPU.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
//
1313
//===----------------------------------------------------------------------===//
1414
#include "llvm/SYCLLowerIR/ConvertToMuxBuiltinsSYCLNativeCPU.h"
15+
#include "llvm/SYCLLowerIR/FixABIMuxBuiltinsSYCLNativeCPU.h"
1516
#include "llvm/SYCLLowerIR/PrepareSYCLNativeCPU.h"
1617
#include "llvm/SYCLLowerIR/RenameKernelSYCLNativeCPU.h"
1718
#include "llvm/SYCLLowerIR/SpecConstants.h"
@@ -65,6 +66,7 @@ void llvm::sycl::utils::addSYCLNativeCPUBackendPasses(
6566
MPM.addPass(ConvertToMuxBuiltinsSYCLNativeCPUPass());
6667
#ifdef NATIVECPU_USE_OCK
6768
MPM.addPass(compiler::utils::TransferKernelMetadataPass());
69+
MPM.addPass(FixABIMuxBuiltinsPass());
6870
// Always enable vectorizer, unless explicitly disabled or -O0 is set.
6971
if (OptLevel != OptimizationLevel::O0 && !SYCLNativeCPUNoVecz) {
7072
MAM.registerPass([] { return vecz::TargetInfoAnalysis(); });

llvm/lib/SYCLNativeCPUUtils/PrepareSYCLNativeCPU.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -464,13 +464,17 @@ PreservedAnalyses PrepareSYCLNativeCPUPass::run(Module &M,
464464
F->eraseFromParent();
465465
ModuleChanged = true;
466466
}
467-
for (auto It = M.begin(); It != M.end();) {
468-
auto Curr = It++;
469-
Function &F = *Curr;
470-
if (F.getNumUses() == 0 && F.isDeclaration() &&
471-
F.getName().starts_with("__mux_")) {
472-
F.eraseFromParent();
473-
ModuleChanged = true;
467+
468+
// We do these twice because we create abi wrappers for mux which may show up
469+
// before we've removed their user
470+
for (unsigned int i = 0; i < 2; i++) {
471+
for (auto It = M.begin(); It != M.end();) {
472+
auto Curr = It++;
473+
Function &F = *Curr;
474+
if (F.getNumUses() == 0 && F.getName().starts_with("__mux_")) {
475+
F.eraseFromParent();
476+
ModuleChanged = true;
477+
}
474478
}
475479
}
476480

0 commit comments

Comments
 (0)