Skip to content

Commit 10d81ca

Browse files
bokrzesiigcbot
authored andcommitted
[LLVM16] Adding support for OpaquePointers in JointMatrixFuncsResolutionPass
This PR adds introduces handling of opaque pointers in LLVM 16. While for the majority of the JointMatrix code such transition was relatively easy, then for some specific cases we needed to put more effort. e.g When argument to the function was provided as an opaque pointer, but it was actually TET/Matrix Type Then we need to traverse up in order to figure out it's actual type, so we can resolve it correctly. Such scenario occured when processing access chain.
1 parent 5542ef5 commit 10d81ca

File tree

44 files changed

+1532
-1473
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+1532
-1473
lines changed

IGC/Compiler/Optimizer/OpenCLPasses/JointMatrixFuncsResolutionPass/JointMatrixFuncsResolutionPass.cpp

Lines changed: 145 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -738,6 +738,12 @@ static bool isSupprtedLargeSlice(const JointMatrixTypeDescription *desc, bool us
738738
return false;
739739
}
740740

741+
bool JointMatrixFuncsResolutionPass::ValidateIntegerBitWidth(
742+
unsigned int bitWidth) {
743+
bool result = bitWidth == 8 || bitWidth == 16 || bitWidth == 32;
744+
return result;
745+
}
746+
741747
// TODO: Currently this function doesn't take into account large slices, when reporting
742748
// supported parameters. This should be fixed.
743749
bool JointMatrixFuncsResolutionPass::ValidateLoadStore
@@ -1023,12 +1029,9 @@ bool JointMatrixFuncsResolutionPass::parseMatrixTypeNameLegacy(const Type *opaqu
10231029
offset += 1; /* Skip type specifier, [f|i] */
10241030
outDescription->bitWidth = parseNumber(name, &offset);
10251031

1026-
bool supportedBitWidth =
1027-
(outDescription->bitWidth == 8 ||
1028-
outDescription->bitWidth == 16 ||
1029-
outDescription->bitWidth == 32);
1030-
IGC_ASSERT_MESSAGE(supportedBitWidth,
1031-
"Unexpected matrix element size.");
1032+
bool isBitWidthSupported =
1033+
ValidateIntegerBitWidth(outDescription->bitWidth);
1034+
IGC_ASSERT_MESSAGE(isBitWidthSupported, "Unexpected matrix element size.");
10321035

10331036
return true;
10341037
}
@@ -1123,12 +1126,9 @@ bool IGC::JointMatrixFuncsResolutionPass::ParseMatrixTypeNameExtTypeDetails(Type
11231126
if (typeParam->isIntegerTy()) {
11241127
outDescription->bitWidth = typeParam->getIntegerBitWidth();
11251128

1126-
if (outDescription->bitWidth != 8 &&
1127-
outDescription->bitWidth != 16 &&
1128-
outDescription->bitWidth != 32)
1129-
{
1130-
std::string msg = "Unexpected Matrix integer size: '"
1131-
+ std::to_string(outDescription->bitWidth) + "'. Only integers of size 8/16/32 are allowed.";
1129+
if (!ValidateIntegerBitWidth(outDescription->bitWidth)) {
1130+
std::string msg = "Unexpected Matrix integer size: '" +
1131+
std::to_string(outDescription->bitWidth) + "'.";
11321132
LLVM_DEBUG(dbgs() << msg << "\n");
11331133
m_Ctx->EmitError(msg.c_str(), nullptr);
11341134
return false;
@@ -1377,11 +1377,77 @@ static bool isAccumulator32x32(const JointMatrixTypeDescription &desc) {
13771377
return (desc.layout == LayoutRowMajor && desc.rows == 32 && desc.columns == 32);
13781378
}
13791379

1380+
#if LLVM_VERSION_MAJOR >= 16
1381+
// When we alloca target extension type, later it is "ptr".
1382+
// We need to figure out the type of "ptr" by traversing up.
1383+
// Example:
1384+
// %0 = alloca target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0)
1385+
// %ptr = call spir_func ptr
1386+
// @_Z19__spirv_AccessChainPU3AS4PU3AS143__spirv_CooperativeMatrixKHR(ptr %0,
1387+
// i64 4)
1388+
Type *JointMatrixFuncsResolutionPass::TryFindTargetExtensionTypeOfOpaquePtr(
1389+
Value *V) {
1390+
if (!V)
1391+
return nullptr;
1392+
1393+
for (auto &use : V->uses()) {
1394+
if (auto *ai = dyn_cast<AllocaInst>(use)) {
1395+
auto aiTy = ai->getAllocatedType();
1396+
if (IGCLLVM::isTargetExtTy(aiTy))
1397+
return aiTy;
1398+
} else if (auto *ci = dyn_cast<CallInst>(use)) {
1399+
auto funcReturnType = ci->getFunction()->getReturnType();
1400+
if (IGCLLVM::isTargetExtTy(funcReturnType))
1401+
return funcReturnType;
1402+
} else if (auto *spaceCast = dyn_cast<AddrSpaceCastInst>(use)) {
1403+
return TryFindTargetExtensionTypeOfOpaquePtr(
1404+
spaceCast->getPointerOperand());
1405+
}
1406+
}
1407+
1408+
return nullptr;
1409+
}
1410+
// When resolving e.g prefetch we need to figure out type of ptr
1411+
// We can do it by traversing up.
1412+
// It is similar to approach TryFindTargetExtensionTypeOfOpaquePtr
1413+
Type *JointMatrixFuncsResolutionPass::TryFindTypeOfOpaquePtr(Value *V) {
1414+
if (!V)
1415+
return nullptr;
1416+
1417+
for (auto &use : V->uses()) {
1418+
if (auto *ai = dyn_cast<AllocaInst>(use)) {
1419+
auto aiTy = ai->getAllocatedType();
1420+
return aiTy;
1421+
} else if (auto *ci = dyn_cast<CallInst>(use)) {
1422+
auto funcReturnType = ci->getFunction()->getReturnType();
1423+
return funcReturnType;
1424+
} else if (auto *spaceCast = dyn_cast<AddrSpaceCastInst>(use)) {
1425+
return TryFindTypeOfOpaquePtr(spaceCast->getPointerOperand());
1426+
} else if (auto *gep = dyn_cast<GetElementPtrInst>(use)) {
1427+
return gep->getResultElementType();
1428+
} else if (auto *bitcast = dyn_cast<BitCastInst>(use)) {
1429+
if (!IGCLLVM::isOpaquePointerTy(bitcast->getSrcTy()))
1430+
return bitcast->getSrcTy();
1431+
1432+
return TryFindTypeOfOpaquePtr(bitcast->getOperand(0));
1433+
}
1434+
}
1435+
1436+
return nullptr;
1437+
}
1438+
#endif
1439+
13801440
Type *JointMatrixFuncsResolutionPass::ResolveType(Type *inputType, JointMatrixTypeDescription *outDesc)
13811441
{
13821442
IGC_ASSERT_EXIT_MESSAGE(inputType && (inputType->isPointerTy() || IGCLLVM::isTargetExtTy(inputType)),
13831443
"Unexpected type in matrix function resolution.");
13841444

1445+
#if LLVM_VERSION_MAJOR >= 16
1446+
IGC_ASSERT_EXIT_MESSAGE(
1447+
!IGCLLVM::isOpaquePointerTy(inputType),
1448+
"Unexpected opaque pointer. Expected TargetExtensionType instead.");
1449+
#endif
1450+
13851451
JointMatrixTypeDescription desc;
13861452
bool parseResult = ParseMatrixTypeName(inputType, &desc);
13871453
IGC_ASSERT_EXIT_MESSAGE(parseResult, "Failed to parse matrix type.");
@@ -1484,32 +1550,52 @@ Instruction *JointMatrixFuncsResolutionPass::ResolvePrefetch(CallInst *CI)
14841550
desc.columns = (unsigned)constIntValue(numColsVal);
14851551

14861552
// Pointer type resolution
1553+
Type *ptrElemType = nullptr;
1554+
1555+
#if LLVM_VERSION_MAJOR >= 16
1556+
if (IGCLLVM::isOpaquePointerTy(ptrVal->getType()))
1557+
ptrElemType = TryFindTypeOfOpaquePtr(ptrVal);
1558+
else
1559+
#endif
14871560
{
1488-
PointerType *ptrType = cast<PointerType>(ptrVal->getType());
1489-
Type *ptrElemType = IGCLLVM::getNonOpaquePtrEltTy(ptrType);
1490-
1491-
if (StructType *structTy = dyn_cast<StructType>(ptrElemType)) {
1492-
if (structTy->getNumElements() == 1) {
1493-
ptrElemType = structTy->getElementType(0);
1494-
// we assume that only custom floating point types are wrapped into structs
1495-
desc.isFloating = true;
1496-
}
1497-
}
1561+
// To be removed after switch to LLVM 16 + full opaque pointers enablement
1562+
PointerType *ptrType = cast<PointerType>(ptrVal->getType());
1563+
ptrElemType = IGCLLVM::getNonOpaquePtrEltTy(ptrType);
1564+
}
14981565

1499-
if (ptrElemType->isHalfTy()) {
1500-
desc.bitWidth = 16;
1501-
desc.isFloating = true;
1502-
} else if (ptrElemType->isFloatTy()) {
1503-
desc.bitWidth = 32;
1504-
desc.isFloating = true;
1505-
} else if (ptrElemType->isDoubleTy()) {
1506-
desc.bitWidth = 64;
1507-
desc.isFloating = true;
1508-
} else if (ptrElemType->isIntegerTy()) {
1509-
desc.bitWidth = cast<IntegerType>(ptrElemType)->getBitWidth();
1510-
} else {
1511-
m_Ctx->EmitError("Failed to resolve matrix prefetch pointer type", ptrVal);
1512-
}
1566+
IGC_ASSERT_MESSAGE(ptrElemType, "Pointer type not found");
1567+
1568+
if (StructType *structTy = dyn_cast<StructType>(ptrElemType)) {
1569+
if (structTy->getNumElements() == 1) {
1570+
ptrElemType = structTy->getElementType(0);
1571+
// we assume that only custom floating point types are wrapped into
1572+
// structs
1573+
desc.isFloating = true;
1574+
}
1575+
}
1576+
1577+
if (ptrElemType->isHalfTy()) {
1578+
desc.bitWidth = 16;
1579+
desc.isFloating = true;
1580+
} else if (ptrElemType->isFloatTy()) {
1581+
desc.bitWidth = 32;
1582+
desc.isFloating = true;
1583+
} else if (ptrElemType->isDoubleTy()) {
1584+
desc.bitWidth = 64;
1585+
desc.isFloating = true;
1586+
} else if (ptrElemType->isIntegerTy()) {
1587+
desc.bitWidth = cast<IntegerType>(ptrElemType)->getBitWidth();
1588+
1589+
if (!ValidateIntegerBitWidth(desc.bitWidth)) {
1590+
std::string msg = "Unexpected Matrix integer size: '" +
1591+
std::to_string(desc.bitWidth) + "'.";
1592+
LLVM_DEBUG(dbgs() << msg << "\n");
1593+
m_Ctx->EmitError(msg.c_str(), ptrVal);
1594+
return nullptr;
1595+
}
1596+
} else {
1597+
m_Ctx->EmitError("Failed to resolve matrix prefetch pointer type",
1598+
ptrVal);
15131599
}
15141600

15151601
LLVMContext &ctx = CI->getContext();
@@ -2268,13 +2354,32 @@ bool JointMatrixFuncsResolutionPass::preprocessAccessChain(Function *F) {
22682354
continue;
22692355
}
22702356

2357+
#if LLVM_VERSION_MAJOR < 16
22712358
if(IGCLLVM::isOpaquePointerTy(CI->getArgOperand(0)->getType()))
22722359
continue;
2360+
#endif
22732361

22742362
LLVM_DEBUG(dbgs() << " - PREPROCESS ACCESS CHAIN: " << *CI << "\n");
22752363

2276-
Type *chainBaseTy =
2277-
IGCLLVM::getNonOpaquePtrEltTy(CI->getArgOperand(0)->getType());
2364+
Type *chainBaseTy = nullptr;
2365+
auto operand0 = CI->getArgOperand(0);
2366+
2367+
#if LLVM_VERSION_MAJOR >= 16
2368+
if (IGCLLVM::isOpaquePointerTy(operand0->getType())) {
2369+
chainBaseTy = TryFindTargetExtensionTypeOfOpaquePtr(operand0);
2370+
IGC_ASSERT_MESSAGE(chainBaseTy,
2371+
"__spirv_AccessChain call 1st argument must be "
2372+
"pointer to target extension type.");
2373+
} else
2374+
#endif
2375+
{
2376+
// to be removed after we switch to LLVM 16 with opaque pointers by
2377+
// default
2378+
chainBaseTy = IGCLLVM::getNonOpaquePtrEltTy(operand0->getType());
2379+
IGC_ASSERT_MESSAGE(
2380+
chainBaseTy, "__spirv_AccessChain call 1st argument is invalid");
2381+
}
2382+
22782383
IGC_ASSERT_MESSAGE(isMatrixType(chainBaseTy),
22792384
"__spirv_AccessChain call 1st argument must be cooperative matrix");
22802385
Value *ptrToMatrix = CI->getArgOperand(0);
@@ -2755,8 +2860,10 @@ void JointMatrixFuncsResolutionPass::visitCallInst(CallInst& CI)
27552860

27562861
StringRef funcName = func->getName();
27572862

2863+
#if LLVM_VERSION_MAJOR < 16
27582864
if(IGCLLVM::isOpaquePointerTy(CI.getType()) || isAnyOperand(CI, IGCLLVM::isOpaquePointerTy))
27592865
return;
2866+
#endif
27602867

27612868
/* Resolve calls to JointMatrix BIs that haven't been resolved yet. In
27622869
* future when returning and passing matrices by argument is

IGC/Compiler/Optimizer/OpenCLPasses/JointMatrixFuncsResolutionPass/JointMatrixFuncsResolutionPass.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ namespace IGC
8787
bool IsJointMatrix,
8888
JointMatrixTypeDescription* outDescription);
8989
#if LLVM_VERSION_MAJOR >= 16
90+
llvm::Type *TryFindTargetExtensionTypeOfOpaquePtr(llvm::Value *V);
91+
llvm::Type *TryFindTypeOfOpaquePtr(llvm::Value *V);
9092
bool ParseMatrixTypeNameExtTypeDetails(llvm::Type* opaqueType, bool IsJointMatrix, IGC::JointMatrixTypeDescription* outDescription);
9193
#endif
9294

@@ -122,6 +124,7 @@ namespace IGC
122124
const JointMatrixTypeDescription *desc,
123125
const std::string& prefix);
124126

127+
bool ValidateIntegerBitWidth(unsigned int width);
125128
bool ValidateLoadStore
126129
(bool isLoad, unsigned operationLayout, const JointMatrixTypeDescription *desc, llvm::Value *ctx);
127130

IGC/Compiler/tests/EmitVISAPass/vectorizer-vector-emission-exp2.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
; UNSUPPORTED: system-windows
2-
; REQUIRES: regkeys
2+
; REQUIRES: pvc-supported, regkeys
33

44
; RUN: igc_opt -S -dce -platformpvc -rev-id B -has-emulated-64-bit-insts -igc-emit-visa --regkey=DumpVISAASMToConsole=1 -simd-mode 16 < %s | FileCheck %s
55

IGC/Compiler/tests/EmitVISAPass/vectorizer-vector-emission-fadd.ll

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
1-
; REQUIRES: regkeys
1+
; REQUIRES: pvc-supported, regkeys
22

33
; RUN: igc_opt -S -dce -platformpvc -rev-id B -has-emulated-64-bit-insts -igc-emit-visa --regkey=DumpVISAASMToConsole=1 -simd-mode 16 < %s | FileCheck %s
44

5-
; CHECK: .decl tmp15 v_type=G type=f num_elts=128 align=wordx32
6-
; CHECK: .decl tmp16 v_type=G type=f num_elts=8 align=dword
5+
; CHECK: .decl vectorized_phi v_type=G type=f num_elts=128 align=wordx32
6+
; CHECK: .decl vector v_type=G type=f num_elts=8 align=dword
77

8-
; CHECK: add (M1, 16) tmp15(0,0)<1> tmp16(0,0)<0;1,0> tmp15(0,0)<1;1,0>
9-
; CHECK: add (M1, 16) tmp15(1,0)<1> tmp16(0,1)<0;1,0> tmp15(1,0)<1;1,0>
10-
; CHECK: add (M1, 16) tmp15(2,0)<1> tmp16(0,2)<0;1,0> tmp15(2,0)<1;1,0>
11-
; CHECK: add (M1, 16) tmp15(3,0)<1> tmp16(0,3)<0;1,0> tmp15(3,0)<1;1,0>
12-
; CHECK: add (M1, 16) tmp15(4,0)<1> tmp16(0,4)<0;1,0> tmp15(4,0)<1;1,0>
13-
; CHECK: add (M1, 16) tmp15(5,0)<1> tmp16(0,5)<0;1,0> tmp15(5,0)<1;1,0>
14-
; CHECK: add (M1, 16) tmp15(6,0)<1> tmp16(0,6)<0;1,0> tmp15(6,0)<1;1,0>
15-
; CHECK: add (M1, 16) tmp15(7,0)<1> tmp16(0,7)<0;1,0> tmp15(7,0)<1;1,0>
8+
; CHECK: add (M1, 16) vectorized_phi(0,0)<1> vector(0,0)<0;1,0> vectorized_phi(0,0)<1;1,0>
9+
; CHECK: add (M1, 16) vectorized_phi(1,0)<1> vector(0,1)<0;1,0> vectorized_phi(1,0)<1;1,0>
10+
; CHECK: add (M1, 16) vectorized_phi(2,0)<1> vector(0,2)<0;1,0> vectorized_phi(2,0)<1;1,0>
11+
; CHECK: add (M1, 16) vectorized_phi(3,0)<1> vector(0,3)<0;1,0> vectorized_phi(3,0)<1;1,0>
12+
; CHECK: add (M1, 16) vectorized_phi(4,0)<1> vector(0,4)<0;1,0> vectorized_phi(4,0)<1;1,0>
13+
; CHECK: add (M1, 16) vectorized_phi(5,0)<1> vector(0,5)<0;1,0> vectorized_phi(5,0)<1;1,0>
14+
; CHECK: add (M1, 16) vectorized_phi(6,0)<1> vector(0,6)<0;1,0> vectorized_phi(6,0)<1;1,0>
15+
; CHECK: add (M1, 16) vectorized_phi(7,0)<1> vector(0,7)<0;1,0> vectorized_phi(7,0)<1;1,0>
1616

1717
define spir_kernel void @blam(half addrspace(1)* %arg, half addrspace(1)* %arg1, half addrspace(1)* %arg2, float %arg3, i8 addrspace(1)* %arg4, float addrspace(1)* %arg5, <8 x i32> %arg6, <8 x i32> %arg7, i8* %arg8, i32 %arg9, i32 %arg10, i32 %arg11, i32 %arg12, i32 %arg13) {
1818
bb:

0 commit comments

Comments
 (0)