Skip to content

Commit bcbbf3f

Browse files
bokrzesiigcbot
authored andcommitted
Add handling for GEPs when trying to find TargetExtensionType of Opaque Pointer
This patch adds missing case of GetElementPtrInst when trying to figure out the TargetExtensionType of opaque pointer
1 parent 78f1922 commit bcbbf3f

File tree

3 files changed

+70
-16
lines changed

3 files changed

+70
-16
lines changed

IGC/Compiler/Optimizer/OpenCLPasses/JointMatrixFuncsResolutionPass/JointMatrixFuncsResolutionPass.cpp

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1312,12 +1312,12 @@ Type *JointMatrixFuncsResolutionPass::TryFindTargetExtensionTypeOfOpaquePtr(Valu
13121312
auto aiTy = ai->getAllocatedType();
13131313
if (IGCLLVM::isTargetExtTy(aiTy))
13141314
return aiTy;
1315-
} else if (auto *ci = dyn_cast<CallInst>(use)) {
1316-
auto funcReturnType = ci->getFunction()->getReturnType();
1317-
if (IGCLLVM::isTargetExtTy(funcReturnType))
1318-
return funcReturnType;
1319-
} else if (auto *spaceCast = dyn_cast<AddrSpaceCastInst>(use)) {
1320-
return TryFindTargetExtensionTypeOfOpaquePtr(spaceCast->getPointerOperand());
1315+
} else if (auto *cast = dyn_cast<CastInst>(use)) {
1316+
return TryFindTargetExtensionTypeOfOpaquePtr(cast->getOperand(0));
1317+
} else if (auto *gep = dyn_cast<GetElementPtrInst>(use)) {
1318+
auto gepTy = gep->getResultElementType();
1319+
if (IGCLLVM::isTargetExtTy(gepTy))
1320+
return gepTy;
13211321
}
13221322
}
13231323

@@ -1334,11 +1334,8 @@ Type *JointMatrixFuncsResolutionPass::TryFindTypeOfOpaquePtr(Value *V) {
13341334
if (auto *ai = dyn_cast<AllocaInst>(use)) {
13351335
auto aiTy = ai->getAllocatedType();
13361336
return aiTy;
1337-
} else if (auto *ci = dyn_cast<CallInst>(use)) {
1338-
auto funcReturnType = ci->getFunction()->getReturnType();
1339-
return funcReturnType;
1340-
} else if (auto *spaceCast = dyn_cast<AddrSpaceCastInst>(use)) {
1341-
return TryFindTypeOfOpaquePtr(spaceCast->getPointerOperand());
1337+
} else if (auto *cast = dyn_cast<CastInst>(use)) {
1338+
return TryFindTypeOfOpaquePtr(cast->getOperand(0));
13421339
} else if (auto *gep = dyn_cast<GetElementPtrInst>(use)) {
13431340
return gep->getResultElementType();
13441341
} else if (auto *bitcast = dyn_cast<BitCastInst>(use)) {
@@ -1470,7 +1467,10 @@ Instruction *JointMatrixFuncsResolutionPass::ResolvePrefetch(CallInst *CI) {
14701467
ptrElemType = IGCLLVM::getNonOpaquePtrEltTy(ptrType);
14711468
}
14721469

1473-
IGC_ASSERT_MESSAGE(ptrElemType, "Pointer type not found");
1470+
if (!ptrElemType) {
1471+
m_Ctx->EmitError("Pointer type not found when resolving prefetch", ptrVal);
1472+
return nullptr;
1473+
}
14741474

14751475
if (StructType *structTy = dyn_cast<StructType>(ptrElemType)) {
14761476
if (structTy->getNumElements() == 1) {
@@ -2288,8 +2288,11 @@ bool JointMatrixFuncsResolutionPass::preprocessAccessChain(Function *F) {
22882288
#if LLVM_VERSION_MAJOR >= 16
22892289
if (IGCLLVM::isOpaquePointerTy(operand0->getType())) {
22902290
chainBaseTy = TryFindTargetExtensionTypeOfOpaquePtr(operand0);
2291-
IGC_ASSERT_MESSAGE(chainBaseTy, "__spirv_AccessChain call 1st argument must be "
2292-
"pointer to target extension type.");
2291+
2292+
if (!chainBaseTy) {
2293+
m_Ctx->EmitError("__spirv_AccessChain call 1st argument must be pointer to target extension type", operand0);
2294+
continue;
2295+
}
22932296
} else
22942297
#endif
22952298
{
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
;=========================== begin_copyright_notice ============================
2+
;
3+
; Copyright (C) 2025 Intel Corporation
4+
;
5+
; SPDX-License-Identifier: MIT
6+
;
7+
;============================ end_copyright_notice =============================
8+
; REQUIRES: llvm-16-plus
9+
10+
; RUN: igc_opt --opaque-pointers -platformpvc -igc-joint-matrix-resolution -S 2>&1 < %s | FileCheck %s
11+
; ------------------------------------------------
12+
; JointMatrixFuncsResolutionPass
13+
; ------------------------------------------------
14+
; Walk thru uses in order to figure out the TET type
15+
16+
%"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix.3" = type { target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0) }
17+
18+
; CHECK: define spir_kernel void @test()
19+
; CHECK: %addressCast = addrspacecast ptr %gep3 to ptr addrspace(4)
20+
; CHECK: %bitcast2 = bitcast ptr addrspace(4) %addressCast to ptr addrspace(4)
21+
; CHECK-NEXT: %0 = load <8 x i16>, ptr addrspace(4) %bitcast2, align 8
22+
23+
; Resolution flow: Call operand0 -> Bitcast -> Address Space Cast -> GEP -> TET
24+
25+
; Function Attrs: nounwind
26+
define spir_kernel void @test() {
27+
entry:
28+
%alloca = alloca target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0), align 8
29+
%gep = getelementptr inbounds %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix.3", ptr %alloca, i64 0, i32 0
30+
%addressCast = addrspacecast ptr %gep to ptr addrspace(4)
31+
%bitcast2 = bitcast ptr addrspace(4) %addressCast to ptr addrspace(4)
32+
%call = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainPU3AS4PU3AS144__spirv_CooperativeMatrixKHR__short_3_8_16_0l(ptr addrspace(4) %bitcast2, i64 0)
33+
34+
store i16 5, ptr addrspace(4) %call, align 2
35+
ret void
36+
}
37+
38+
; Function Attrs: nounwind
39+
declare spir_func ptr addrspace(4) @_Z19__spirv_AccessChainPU3AS4PU3AS144__spirv_CooperativeMatrixKHR__short_3_8_16_0l(ptr addrspace(4), i64) #0

IGC/Compiler/tests/JointMatrixFuncsResolutionPass/prefetch.ll

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,14 @@
66
;
77
;============================ end_copyright_notice =============================
88
; REQUIRES: llvm-16-plus
9-
; RUN: igc_opt --opaque-pointers -igc-joint-matrix-resolution --platformpvc -S 2>&1 < %s | FileCheck %s
9+
; RUN: igc_opt --opaque-pointers -igc-joint-matrix-resolution --platformpvc -S 2>&1 < %s | FileCheck %s --implicit-check-not error:
1010
; ------------------------------------------------
1111
; Written based on IR generated from this test: llvm/sycl/test-e2e/Matrix/joint_matrix_prefetch.cpp
1212
; The purpose of this test is to check whether we figure out
1313
; the type of pointer (operand 0) during ResolvePrefetch() correctly
1414
; ------------------------------------------------
1515

1616
; CHECK: call void @__builtin_spriv_OpJointMatrixPrefetchINTEL_SG16_8x16_i16(ptr addrspace(4) %add.ptr
17-
; CHECK-NOT: error
1817

1918
%"class.sycl::_V1::ext::oneapi::bfloat16" = type { i16 }
2019
; Function Attrs: nounwind
@@ -27,5 +26,18 @@ define spir_kernel void @test(ptr addrspace(1) align 2 %A, i64 %mul1) {
2726
ret void
2827
}
2928

29+
; CHECK: define spir_kernel void @test2(ptr addrspace(4) align 2 %A) {
30+
; CHECK: %bitcast2 = bitcast ptr addrspace(4) %gep to ptr addrspace(4)
31+
; CHECK: call void @__builtin_spriv_OpJointMatrixPrefetchINTEL_SG16_8x32_i8(ptr addrspace(4) %bitcast2, i64 256, i32 4)
32+
33+
; Comes from 'test-e2e/Matrix/Output/joint_matrix_bf16_fill_k_cache_prefetch.cpp'
34+
define spir_kernel void @test2(ptr addrspace(4) align 2 %A) {
35+
entry:
36+
%gep = getelementptr inbounds i8, ptr addrspace(4) %A, i64 1
37+
%bitcast2 = bitcast ptr addrspace(4) %gep to ptr addrspace(4)
38+
call spir_func void @"_Z38__spirv_CooperativeMatrixPrefetchINTELPU3AS438class.sycl::_V1::ext::oneapi::bfloat16iiiil"(ptr addrspace(4) %bitcast2, i32 8, i32 32, i32 0, i32 0, i64 256) #0
39+
ret void
40+
}
41+
3042
; Function Attrs: nounwind
3143
declare spir_func void @"_Z38__spirv_CooperativeMatrixPrefetchINTELPU3AS438class.sycl::_V1::ext::oneapi::bfloat16iiiil"(ptr addrspace(4) %0, i32 %1, i32 %2, i32 %3, i32 %4, i64 %5) #0

0 commit comments

Comments
 (0)