diff --git a/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp b/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp index 4b968d5a9bbe1..629b27d61f24b 100644 --- a/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp +++ b/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp @@ -22,16 +22,43 @@ namespace { static constexpr char ACCESS_CHAIN[] = "_Z19__spirv_AccessChain"; static constexpr char MATRIX_TYPE[] = "spirv.CooperativeMatrixKHR"; -// This routine extracts spirv.CooperativeMatrixKHR target extension type -// from sycl::joint_matrix class object if it's used in __spirv_AccessChain -// function call. It's necessary because otherwise OpAccessChain indices would -// be wrong. +// This function finds all calls to the __spirv_AccessChain function and +// transforms their users and operands to make LLVM IR more SPIR-V friendly. bool transformAccessChain(Function *F) { bool ModuleChanged = false; for (auto I : F->users()) { auto *CI = dyn_cast(I); if (!CI) continue; + + // This is a workaround for bfloat16 and tf32 types - they are represented + // in SYCL as structures with int16/float storage. This means that in LLVM + // IR the user of a CallInst to __spirv_AccessChain would not be a + // load/store instruction, but a zero GEP. This zero GEP is a no-op, but can + // confuse a SPIR-V consumer, so let's remove it here. + auto *Unique = CI->getUniqueUndroppableUser(); + if (auto *GEP = dyn_cast_or_null(Unique)) { + if (GEP->hasAllZeroIndices()) { + GEP->replaceAllUsesWith(CI); + GEP->dropAllReferences(); + GEP->eraseFromParent(); + } + } + + // It can happen that the optimizer removes duplicated or dead uses of a + // CallInst to the __spirv_AccessChain function. But it can't remove the + // __spirv_AccessChain call itself, as it's a call to an external function. + // Let's clean up such calls. 
+ if (CI->getNumUses() == 0) { + CI->dropAllReferences(); + CI->eraseFromParent(); + continue; + } + + // The code below extracts the spirv.CooperativeMatrixKHR target extension + // type from the sycl::joint_matrix class object if it's used in a + // __spirv_AccessChain function call. It's necessary because otherwise the + // OpAccessChain indices would be wrong. Instruction *Ptr = dyn_cast(CI->getArgOperand(0)->stripPointerCasts()); if (!Ptr || !isa(Ptr)) diff --git a/llvm/test/SYCLLowerIR/JointMatrixTransform/access-chain-no-uses.ll b/llvm/test/SYCLLowerIR/JointMatrixTransform/access-chain-no-uses.ll new file mode 100644 index 0000000000000..40f9272fbdf44 --- /dev/null +++ b/llvm/test/SYCLLowerIR/JointMatrixTransform/access-chain-no-uses.ll @@ -0,0 +1,22 @@ +; The test checks that an unused call to __spirv_AccessChain is eliminated. + +; RUN: opt -passes=sycl-joint-matrix-transform < %s -S | FileCheck %s + +; CHECK-NOT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain + +; ModuleID = 'test.bc' +source_filename = "test.cpp" +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1" +target triple = "spir64-unknown-unknown" + +%"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix" = type { target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0) } + +define weak_odr dso_local spir_kernel void @test() { +entry: + %0 = alloca %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix", align 8 + %1 = addrspacecast ptr %0 to ptr addrspace(4) + %2 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %1, i64 noundef 0) + ret void +} + +declare dso_local spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef, 
i64 noundef) diff --git a/llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain.ll b/llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain.ll index d43b4a1e91e7a..5373938405717 100644 --- a/llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain.ll +++ b/llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain.ll @@ -19,6 +19,7 @@ entry: %0 = alloca %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix", align 8 %1 = addrspacecast ptr %0 to ptr addrspace(4) %2 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %1, i64 noundef 0) + %3 = load i8, ptr addrspace(4) %2 ret void } diff --git a/llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain_bf16.ll b/llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain_bf16.ll new file mode 100644 index 0000000000000..11e7c53936610 --- /dev/null +++ b/llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain_bf16.ll @@ -0,0 +1,27 @@ +; The test checks that the useless zero GEP used to get i16 from sycl::bfloat16 is removed. 
+ +; RUN: opt -passes=sycl-joint-matrix-transform < %s -S | FileCheck %s + +; CHECK: %[[#Alloca:]] = alloca target("spirv.CooperativeMatrixKHR", i16, 3, 16, 64, 0) +; CHECK: %[[#Cast:]] = addrspacecast ptr %[[#Alloca]] to ptr addrspace(4) +; CHECK: %[[#AC:]] = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef %[[#Cast]], i64 noundef 0) +; CHECK: load i16, ptr addrspace(4) %[[#AC]] + +; ModuleID = 'test.bc' +source_filename = "test.cpp" +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1" +target triple = "spir64-unknown-unknown" + +%"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix" = type { target("spirv.CooperativeMatrixKHR", i16, 3, 16, 64, 0) } + +define weak_odr dso_local spir_kernel void @test() { +entry: + %0 = alloca %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix", align 8 + %1 = addrspacecast ptr %0 to ptr addrspace(4) + %2 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %1, i64 noundef 0) + %3 = getelementptr inbounds { i16 }, ptr addrspace(4) %2, i64 0, i32 0 + %4 = load i16, ptr addrspace(4) %3 + ret void +} + +declare dso_local spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef, i64 noundef)