@@ -738,6 +738,12 @@ static bool isSupprtedLargeSlice(const JointMatrixTypeDescription *desc, bool us
738738 return false ;
739739}
740740
741+ bool JointMatrixFuncsResolutionPass::ValidateIntegerBitWidth (
742+ unsigned int bitWidth) {
743+ bool result = bitWidth == 8 || bitWidth == 16 || bitWidth == 32 ;
744+ return result;
745+ }
746+
741747// TODO: Currently this function doesn't take into account large slices, when reporting
742748// supported parameters. This should be fixed.
743749bool JointMatrixFuncsResolutionPass::ValidateLoadStore
@@ -1023,12 +1029,9 @@ bool JointMatrixFuncsResolutionPass::parseMatrixTypeNameLegacy(const Type *opaqu
10231029 offset += 1 ; /* Skip type specifier, [f|i] */
10241030 outDescription->bitWidth = parseNumber (name, &offset);
10251031
1026- bool supportedBitWidth =
1027- (outDescription->bitWidth == 8 ||
1028- outDescription->bitWidth == 16 ||
1029- outDescription->bitWidth == 32 );
1030- IGC_ASSERT_MESSAGE (supportedBitWidth,
1031- " Unexpected matrix element size." );
1032+ bool isBitWidthSupported =
1033+ ValidateIntegerBitWidth (outDescription->bitWidth );
1034+ IGC_ASSERT_MESSAGE (isBitWidthSupported, " Unexpected matrix element size." );
10321035
10331036 return true ;
10341037}
@@ -1123,12 +1126,9 @@ bool IGC::JointMatrixFuncsResolutionPass::ParseMatrixTypeNameExtTypeDetails(Type
11231126 if (typeParam->isIntegerTy ()) {
11241127 outDescription->bitWidth = typeParam->getIntegerBitWidth ();
11251128
1126- if (outDescription->bitWidth != 8 &&
1127- outDescription->bitWidth != 16 &&
1128- outDescription->bitWidth != 32 )
1129- {
1130- std::string msg = " Unexpected Matrix integer size: '"
1131- + std::to_string (outDescription->bitWidth ) + " '. Only integers of size 8/16/32 are allowed." ;
1129+ if (!ValidateIntegerBitWidth (outDescription->bitWidth )) {
1130+ std::string msg = " Unexpected Matrix integer size: '" +
1131+ std::to_string (outDescription->bitWidth ) + " '." ;
11321132 LLVM_DEBUG (dbgs () << msg << " \n " );
11331133 m_Ctx->EmitError (msg.c_str (), nullptr );
11341134 return false ;
@@ -1377,11 +1377,77 @@ static bool isAccumulator32x32(const JointMatrixTypeDescription &desc) {
13771377 return (desc.layout == LayoutRowMajor && desc.rows == 32 && desc.columns == 32 );
13781378}
13791379
1380+ #if LLVM_VERSION_MAJOR >= 16
1381+ // When we alloca target extension type, later it is "ptr".
1382+ // We need to figure out the type of "ptr" by traversing up.
1383+ // Example:
1384+ // %0 = alloca target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0)
1385+ // %ptr = call spir_func ptr
1386+ // @_Z19__spirv_AccessChainPU3AS4PU3AS143__spirv_CooperativeMatrixKHR(ptr %0,
1387+ // i64 4)
1388+ Type *JointMatrixFuncsResolutionPass::TryFindTargetExtensionTypeOfOpaquePtr (
1389+ Value *V) {
1390+ if (!V)
1391+ return nullptr ;
1392+
1393+ for (auto &use : V->uses ()) {
1394+ if (auto *ai = dyn_cast<AllocaInst>(use)) {
1395+ auto aiTy = ai->getAllocatedType ();
1396+ if (IGCLLVM::isTargetExtTy (aiTy))
1397+ return aiTy;
1398+ } else if (auto *ci = dyn_cast<CallInst>(use)) {
1399+ auto funcReturnType = ci->getFunction ()->getReturnType ();
1400+ if (IGCLLVM::isTargetExtTy (funcReturnType))
1401+ return funcReturnType;
1402+ } else if (auto *spaceCast = dyn_cast<AddrSpaceCastInst>(use)) {
1403+ return TryFindTargetExtensionTypeOfOpaquePtr (
1404+ spaceCast->getPointerOperand ());
1405+ }
1406+ }
1407+
1408+ return nullptr ;
1409+ }
1410+ // When resolving e.g prefetch we need to figure out type of ptr
1411+ // We can do it by traversing up.
1412+ // It is similar to approach TryFindTargetExtensionTypeOfOpaquePtr
1413+ Type *JointMatrixFuncsResolutionPass::TryFindTypeOfOpaquePtr (Value *V) {
1414+ if (!V)
1415+ return nullptr ;
1416+
1417+ for (auto &use : V->uses ()) {
1418+ if (auto *ai = dyn_cast<AllocaInst>(use)) {
1419+ auto aiTy = ai->getAllocatedType ();
1420+ return aiTy;
1421+ } else if (auto *ci = dyn_cast<CallInst>(use)) {
1422+ auto funcReturnType = ci->getFunction ()->getReturnType ();
1423+ return funcReturnType;
1424+ } else if (auto *spaceCast = dyn_cast<AddrSpaceCastInst>(use)) {
1425+ return TryFindTypeOfOpaquePtr (spaceCast->getPointerOperand ());
1426+ } else if (auto *gep = dyn_cast<GetElementPtrInst>(use)) {
1427+ return gep->getResultElementType ();
1428+ } else if (auto *bitcast = dyn_cast<BitCastInst>(use)) {
1429+ if (!IGCLLVM::isOpaquePointerTy (bitcast->getSrcTy ()))
1430+ return bitcast->getSrcTy ();
1431+
1432+ return TryFindTypeOfOpaquePtr (bitcast->getOperand (0 ));
1433+ }
1434+ }
1435+
1436+ return nullptr ;
1437+ }
1438+ #endif
1439+
13801440Type *JointMatrixFuncsResolutionPass::ResolveType (Type *inputType, JointMatrixTypeDescription *outDesc)
13811441{
13821442 IGC_ASSERT_EXIT_MESSAGE (inputType && (inputType->isPointerTy () || IGCLLVM::isTargetExtTy (inputType)),
13831443 " Unexpected type in matrix function resolution." );
13841444
1445+ #if LLVM_VERSION_MAJOR >= 16
1446+ IGC_ASSERT_EXIT_MESSAGE (
1447+ !IGCLLVM::isOpaquePointerTy (inputType),
1448+ " Unexpected opaque pointer. Expected TargetExtensionType instead." );
1449+ #endif
1450+
13851451 JointMatrixTypeDescription desc;
13861452 bool parseResult = ParseMatrixTypeName (inputType, &desc);
13871453 IGC_ASSERT_EXIT_MESSAGE (parseResult, " Failed to parse matrix type." );
@@ -1484,32 +1550,52 @@ Instruction *JointMatrixFuncsResolutionPass::ResolvePrefetch(CallInst *CI)
14841550 desc.columns = (unsigned )constIntValue (numColsVal);
14851551
14861552 // Pointer type resolution
1553+ Type *ptrElemType = nullptr ;
1554+
1555+ #if LLVM_VERSION_MAJOR >= 16
1556+ if (IGCLLVM::isOpaquePointerTy (ptrVal->getType ()))
1557+ ptrElemType = TryFindTypeOfOpaquePtr (ptrVal);
1558+ else
1559+ #endif
14871560 {
1488- PointerType *ptrType = cast<PointerType>(ptrVal->getType ());
1489- Type *ptrElemType = IGCLLVM::getNonOpaquePtrEltTy (ptrType);
1490-
1491- if (StructType *structTy = dyn_cast<StructType>(ptrElemType)) {
1492- if (structTy->getNumElements () == 1 ) {
1493- ptrElemType = structTy->getElementType (0 );
1494- // we assume that only custom floating point types are wrapped into structs
1495- desc.isFloating = true ;
1496- }
1497- }
1561+ // To be removed after switch to LLVM 16 + full opaque pointers enablement
1562+ PointerType *ptrType = cast<PointerType>(ptrVal->getType ());
1563+ ptrElemType = IGCLLVM::getNonOpaquePtrEltTy (ptrType);
1564+ }
14981565
1499- if (ptrElemType->isHalfTy ()) {
1500- desc.bitWidth = 16 ;
1501- desc.isFloating = true ;
1502- } else if (ptrElemType->isFloatTy ()) {
1503- desc.bitWidth = 32 ;
1504- desc.isFloating = true ;
1505- } else if (ptrElemType->isDoubleTy ()) {
1506- desc.bitWidth = 64 ;
1507- desc.isFloating = true ;
1508- } else if (ptrElemType->isIntegerTy ()) {
1509- desc.bitWidth = cast<IntegerType>(ptrElemType)->getBitWidth ();
1510- } else {
1511- m_Ctx->EmitError (" Failed to resolve matrix prefetch pointer type" , ptrVal);
1512- }
1566+ IGC_ASSERT_MESSAGE (ptrElemType, " Pointer type not found" );
1567+
1568+ if (StructType *structTy = dyn_cast<StructType>(ptrElemType)) {
1569+ if (structTy->getNumElements () == 1 ) {
1570+ ptrElemType = structTy->getElementType (0 );
1571+ // we assume that only custom floating point types are wrapped into
1572+ // structs
1573+ desc.isFloating = true ;
1574+ }
1575+ }
1576+
1577+ if (ptrElemType->isHalfTy ()) {
1578+ desc.bitWidth = 16 ;
1579+ desc.isFloating = true ;
1580+ } else if (ptrElemType->isFloatTy ()) {
1581+ desc.bitWidth = 32 ;
1582+ desc.isFloating = true ;
1583+ } else if (ptrElemType->isDoubleTy ()) {
1584+ desc.bitWidth = 64 ;
1585+ desc.isFloating = true ;
1586+ } else if (ptrElemType->isIntegerTy ()) {
1587+ desc.bitWidth = cast<IntegerType>(ptrElemType)->getBitWidth ();
1588+
1589+ if (!ValidateIntegerBitWidth (desc.bitWidth )) {
1590+ std::string msg = " Unexpected Matrix integer size: '" +
1591+ std::to_string (desc.bitWidth ) + " '." ;
1592+ LLVM_DEBUG (dbgs () << msg << " \n " );
1593+ m_Ctx->EmitError (msg.c_str (), ptrVal);
1594+ return nullptr ;
1595+ }
1596+ } else {
1597+ m_Ctx->EmitError (" Failed to resolve matrix prefetch pointer type" ,
1598+ ptrVal);
15131599 }
15141600
15151601 LLVMContext &ctx = CI->getContext ();
@@ -2268,13 +2354,32 @@ bool JointMatrixFuncsResolutionPass::preprocessAccessChain(Function *F) {
22682354 continue ;
22692355 }
22702356
2357+ #if LLVM_VERSION_MAJOR < 16
22712358 if (IGCLLVM::isOpaquePointerTy (CI->getArgOperand (0 )->getType ()))
22722359 continue ;
2360+ #endif
22732361
22742362 LLVM_DEBUG (dbgs () << " - PREPROCESS ACCESS CHAIN: " << *CI << " \n " );
22752363
2276- Type *chainBaseTy =
2277- IGCLLVM::getNonOpaquePtrEltTy (CI->getArgOperand (0 )->getType ());
2364+ Type *chainBaseTy = nullptr ;
2365+ auto operand0 = CI->getArgOperand (0 );
2366+
2367+ #if LLVM_VERSION_MAJOR >= 16
2368+ if (IGCLLVM::isOpaquePointerTy (operand0->getType ())) {
2369+ chainBaseTy = TryFindTargetExtensionTypeOfOpaquePtr (operand0);
2370+ IGC_ASSERT_MESSAGE (chainBaseTy,
2371+ " __spirv_AccessChain call 1st argument must be "
2372+ " pointer to target extension type." );
2373+ } else
2374+ #endif
2375+ {
2376+ // to be removed after we switch to LLVM 16 with opaque pointers by
2377+ // default
2378+ chainBaseTy = IGCLLVM::getNonOpaquePtrEltTy (operand0->getType ());
2379+ IGC_ASSERT_MESSAGE (
2380+ chainBaseTy, " __spirv_AccessChain call 1st argument is invalid" );
2381+ }
2382+
22782383 IGC_ASSERT_MESSAGE (isMatrixType (chainBaseTy),
22792384 " __spirv_AccessChain call 1st argument must be cooperative matrix" );
22802385 Value *ptrToMatrix = CI->getArgOperand (0 );
@@ -2755,8 +2860,10 @@ void JointMatrixFuncsResolutionPass::visitCallInst(CallInst& CI)
27552860
27562861 StringRef funcName = func->getName ();
27572862
2863+ #if LLVM_VERSION_MAJOR < 16
27582864 if (IGCLLVM::isOpaquePointerTy (CI.getType ()) || isAnyOperand (CI, IGCLLVM::isOpaquePointerTy))
27592865 return ;
2866+ #endif
27602867
27612868 /* Resolve calls to JointMatrix BIs that haven't been resolved yet. In
27622869 * future when returning and passing matrices by argument is
0 commit comments