@@ -22,16 +22,43 @@ namespace {
2222static constexpr char ACCESS_CHAIN[] = " _Z19__spirv_AccessChain" ;
2323static constexpr char MATRIX_TYPE[] = " spirv.CooperativeMatrixKHR" ;
2424
25- // This routine extracts spirv.CooperativeMatrixKHR target extension type
26- // from sycl::joint_matrix class object if it's used in __spirv_AccessChain
27- // function call. It's necessary because otherwise OpAccessChain indices would
28- // be wrong.
25+ // This function finds all calls to __spirv_AccessChain function and transforms
26+ // its users and operands to make LLVM IR more SPIR-V friendly.
2927bool transformAccessChain (Function *F) {
3028 bool ModuleChanged = false ;
3129 for (auto I : F->users ()) {
3230 auto *CI = dyn_cast<CallInst>(I);
3331 if (!CI)
3432 continue ;
33+
34+ // This is a W/A for bfloat16 and tf32 types - they are represented in SYCL
35+ // as structures with int16/float storages. It means, that in LLVM IR
36+ // user of CallInst to __spirv_AccessChain function would be not load/store
37+ // instruction, but a zero GEP. This zero GEP is no-op, but can confuse a
38+ // SPIR-V consumer, so lets remove it here.
39+ auto *Unique = CI->getUniqueUndroppableUser ();
40+ if (auto *GEP = dyn_cast_or_null<GetElementPtrInst>(Unique)) {
41+ if (GEP->hasAllZeroIndices ()) {
42+ GEP->replaceAllUsesWith (CI);
43+ GEP->dropAllReferences ();
44+ GEP->eraseFromParent ();
45+ }
46+ }
47+
48+ // It can happen that the optimizer can remove duplicated or dead uses
49+ // of CallInst to __spirv_AccessChain function. But it can't remove
50+ // __spirv_AccessChain call itself as it's a call to external function.
51+ // Lets clean such calls.
52+ if (CI->getNumUses () == 0 ) {
53+ CI->dropAllReferences ();
54+ CI->eraseFromParent ();
55+ continue ;
56+ }
57+
58+ // This routine extracts spirv.CooperativeMatrixKHR target extension type
59+ // from sycl::joint_matrix class object if it's used in __spirv_AccessChain
60+ // function call. It's necessary because otherwise OpAccessChain indices
61+ // would be wrong.
3562 Instruction *Ptr =
3663 dyn_cast<Instruction>(CI->getArgOperand (0 )->stripPointerCasts ());
3764 if (!Ptr || !isa<AllocaInst>(Ptr))
0 commit comments