diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp index 682fca7cc7747..ecbceb5b472fa 100644 --- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp @@ -223,10 +223,10 @@ static void validateLifetimeStart(const SPIRVSubtarget &STI, doInsertBitcast(STI, MRI, GR, I, PtrReg, 0, NewPtrType); } -static void validateGroupAsyncCopyPtr(const SPIRVSubtarget &STI, - MachineRegisterInfo *MRI, - SPIRVGlobalRegistry &GR, MachineInstr &I, - unsigned OpIdx) { +static void validatePtrUnwrapStructField(const SPIRVSubtarget &STI, + MachineRegisterInfo *MRI, + SPIRVGlobalRegistry &GR, + MachineInstr &I, unsigned OpIdx) { MachineFunction *MF = I.getParent()->getParent(); Register OpReg = I.getOperand(OpIdx).getReg(); Register OpTypeReg = getTypeReg(MRI, OpReg); @@ -440,8 +440,8 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const { validateLifetimeStart(STI, MRI, GR, MI); break; case SPIRV::OpGroupAsyncCopy: - validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 3); - validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 4); + validatePtrUnwrapStructField(STI, MRI, GR, MI, 3); + validatePtrUnwrapStructField(STI, MRI, GR, MI, 4); break; case SPIRV::OpGroupWaitEvents: // OpGroupWaitEvents ..., ..., @@ -467,6 +467,49 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const { if (Type->getParent() == Curr && !Curr->pred_empty()) ToMove.insert(const_cast(Type)); } break; + case SPIRV::OpExtInst: { + // prefetch + if (!MI.getOperand(2).isImm() || !MI.getOperand(3).isImm() || + MI.getOperand(2).getImm() != SPIRV::InstructionSet::OpenCL_std) + continue; + switch (MI.getOperand(3).getImm()) { + case SPIRV::OpenCLExtInst::remquo: { + // The last operand must be of a pointer to the return type. + MachineIRBuilder MIB(MI); + SPIRVType *Int32Type = GR.getOrCreateSPIRVIntegerType(32, MIB); + SPIRVType *RetType = MRI->getVRegDef(MI.getOperand(1).getReg()); + assert(RetType && "Expected return type"); + validatePtrTypes( + STI, MRI, GR, MI, MI.getNumOperands() - 1, + RetType->getOpcode() != SPIRV::OpTypeVector + ? Int32Type + : GR.getOrCreateSPIRVVectorType( + Int32Type, RetType->getOperand(2).getImm(), MIB)); + } break; + case SPIRV::OpenCLExtInst::fract: + case SPIRV::OpenCLExtInst::frexp: + case SPIRV::OpenCLExtInst::lgamma_r: + case SPIRV::OpenCLExtInst::modf: + case SPIRV::OpenCLExtInst::sincos: + // The last operand must be of a pointer to the base type represented + // by the previous operand. + assert(MI.getOperand(MI.getNumOperands() - 2).isReg() && + "Expected v-reg"); + validatePtrTypes( + STI, MRI, GR, MI, MI.getNumOperands() - 1, + GR.getSPIRVTypeForVReg( + MI.getOperand(MI.getNumOperands() - 2).getReg())); + break; + case SPIRV::OpenCLExtInst::prefetch: + // Expected `ptr` type is a pointer to float, integer or vector, but + // the pontee value can be wrapped into a struct. + assert(MI.getOperand(MI.getNumOperands() - 2).isReg() && + "Expected v-reg"); + validatePtrUnwrapStructField(STI, MRI, GR, MI, + MI.getNumOperands() - 2); + break; + } + } break; } } for (MachineInstr *MI : ToMove) { diff --git a/llvm/test/CodeGen/SPIRV/pointers/OpExtInst-OpenCL_std-ptr-types.ll b/llvm/test/CodeGen/SPIRV/pointers/OpExtInst-OpenCL_std-ptr-types.ll new file mode 100644 index 0000000000000..8e29876d61d33 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/pointers/OpExtInst-OpenCL_std-ptr-types.ll @@ -0,0 +1,34 @@ +; The goal of the test is to ensure that the output SPIR-V is valid from the perspective of the spirv-val tool. +; RUN: %if spirv-tools %{ llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +%clsid = type { %arr } +%arr = type { [1 x i64] } +%struct_half = type { half } + +define weak_odr dso_local spir_kernel void @foo(ptr addrspace(1) noundef readonly align 2 %_acc, ptr noundef byval(%clsid) align 8 %_acc_id, ptr addrspace(3) noundef align 2 %_arg_loc) { +entry: + %r1 = load i64, ptr %_acc_id, align 8 + %add.ptr.i41 = getelementptr inbounds %struct_half, ptr addrspace(1) %_acc, i64 %r1 + %idx = addrspacecast ptr addrspace(1) %add.ptr.i41 to ptr addrspace(4) + %call.i.i290 = call spir_func noundef ptr addrspace(1) @_Z41__spirv_GenericCastToPtrExplicit_ToGlobalPKvi(ptr addrspace(4) noundef %idx, i32 noundef 5) + call spir_func void @_Z20__spirv_ocl_prefetchPU3AS1Kcm(ptr addrspace(1) noundef %call.i.i290, i64 noundef 0) + + %locidx = addrspacecast ptr addrspace(3) %_arg_loc to ptr addrspace(4) + %ptr1 = tail call spir_func noundef ptr addrspace(3) @_Z40__spirv_GenericCastToPtrExplicit_ToLocalPvi(ptr addrspace(4) noundef %locidx, i32 noundef 4) + %sincos_r = tail call spir_func noundef half @_Z18__spirv_ocl_sincosDF16_PU3AS3DF16_(half noundef 0xH3145, ptr addrspace(3) noundef %ptr1) + + %p1 = addrspacecast ptr addrspace(1) %_acc to ptr addrspace(4) + %ptr2 = tail call spir_func noundef ptr addrspace(1) @_Z41__spirv_GenericCastToPtrExplicit_ToGlobalPvi(ptr addrspace(4) noundef %p1, i32 noundef 5) + %remquo_r = tail call spir_func noundef half @_Z18__spirv_ocl_remquoDF16_DF16_PU3AS1i(half noundef 0xH3A37, half noundef 0xH32F4, ptr addrspace(1) noundef %ptr2) + + ret void +} + +declare dso_local spir_func void @_Z20__spirv_ocl_prefetchPU3AS1Kcm(ptr addrspace(1) noundef, i64 noundef) +declare dso_local spir_func noundef ptr addrspace(1) @_Z41__spirv_GenericCastToPtrExplicit_ToGlobalPKvi(ptr addrspace(4) noundef, i32 noundef) + +declare dso_local spir_func noundef half @_Z18__spirv_ocl_sincosDF16_PU3AS3DF16_(half noundef, ptr addrspace(3) noundef) +declare dso_local spir_func noundef ptr addrspace(3) @_Z40__spirv_GenericCastToPtrExplicit_ToLocalPvi(ptr addrspace(4) noundef, i32 noundef) + +declare dso_local spir_func noundef ptr addrspace(1) @_Z41__spirv_GenericCastToPtrExplicit_ToGlobalPvi(ptr addrspace(4) noundef, i32 noundef) +declare dso_local spir_func noundef half @_Z18__spirv_ocl_remquoDF16_DF16_PU3AS1i(half noundef, half noundef, ptr addrspace(1) noundef)