From 204cf4b4b1107a3d864ce22c147740bd2be77a3f Mon Sep 17 00:00:00 2001 From: ebinjose02 Date: Fri, 28 Feb 2025 17:39:46 +0530 Subject: [PATCH 1/2] FEAT : - Added lowering for llvm intrinsic sadd.with.overflow - Used the logic that overflow occurs when both operands are greater than zero and sum is less than any of the operands. - Likewise overflow occurs when both operands are less than zero and sum is greater than any of the operands. --- .../Target/SPIRV/SPIRVInstructionSelector.cpp | 185 +++++++++++++++++- .../llvm-intrinsics/sadd.with.overflow.ll | 161 +++++++++++++++ 2 files changed, 345 insertions(+), 1 deletion(-) create mode 100644 llvm/test/CodeGen/SPIRV/llvm-intrinsics/sadd.with.overflow.ll diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index c52b67e72a88c..8f6efc3b5f43c 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -202,6 +202,9 @@ class SPIRVInstructionSelector : public InstructionSelector { bool selectOverflowArith(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, unsigned Opcode) const; + bool selectSignedOverflowArith(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I, bool isVector) const; + bool selectIntegerDot(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, bool Signed) const; @@ -509,7 +512,6 @@ static bool mayApplyGenericSelection(unsigned Opcode) { switch (Opcode) { case TargetOpcode::G_CONSTANT: return false; - case TargetOpcode::G_SADDO: case TargetOpcode::G_SSUBO: return true; } @@ -728,6 +730,11 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg, ResType->getOpcode() == SPIRV::OpTypeVector ? SPIRV::OpIAddCarryV : SPIRV::OpIAddCarryS); + case TargetOpcode::G_SADDO: + return selectSignedOverflowArith(ResVReg, ResType, I, + ResType->getOpcode() == SPIRV::OpTypeVector + ? 
true + : false); case TargetOpcode::G_USUBO: return selectOverflowArith(ResVReg, ResType, I, ResType->getOpcode() == SPIRV::OpTypeVector @@ -1376,6 +1383,182 @@ bool SPIRVInstructionSelector::selectOverflowArith(Register ResVReg, .constrainAllUses(TII, TRI, RBI); } +bool SPIRVInstructionSelector::selectSignedOverflowArith(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I, + bool isVector) const { + + +//Checking overflow based on the logic that if two operands are positive and the sum is +//less than one of the operands then an overflow occured. Likewise if two operands are +//negative and if sum is greater than one operand then also overflow occured. + + Type *ResTy = nullptr; + StringRef ResName; + MachineIRBuilder MIRBuilder(I); + if (!GR.findValueAttrs(&I, ResTy, ResName)) + report_fatal_error( + "Not enough info to select the signed arithmetic instruction"); + if (!ResTy || !ResTy->isStructTy()) + report_fatal_error( + "Expect struct type result for the signed arithmetic instruction"); + + StructType *ResStructTy = cast(ResTy); + Type *ResElemTy = ResStructTy->getElementType(0); + Type *OverflowTy = ResStructTy->getElementType(1); + ResTy = StructType::get(ResElemTy, OverflowTy); + SPIRVType *StructType = GR.getOrCreateSPIRVType( + ResTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, false); + if (!StructType) { + report_fatal_error("Failed to create SPIR-V type for struct"); + } + SPIRVType *BoolType = GR.getOrCreateSPIRVBoolType(I, TII); + unsigned N = GR.getScalarOrVectorComponentCount(ResType); + if (N > 1) + BoolType = GR.getOrCreateSPIRVVectorType(BoolType, N, I, TII); + Register BoolTypeReg = GR.getSPIRVTypeID(BoolType); + Register ZeroReg = buildZerosVal(ResType, I); + Register StructVReg = MRI->createGenericVirtualRegister(LLT::scalar(64)); + MRI->setRegClass(StructVReg, &SPIRV::IDRegClass); + + if (ResName.size() > 0) + buildOpName(StructVReg, ResName, MIRBuilder); + + MachineBasicBlock &BB = *I.getParent(); + Register SumVReg = 
MRI->createGenericVirtualRegister(LLT::scalar(64)); + MRI->setRegClass(SumVReg, &SPIRV::IDRegClass); + SPIRVType *IntType = GR.getOrCreateSPIRVType(ResElemTy,MIRBuilder,SPIRV::AccessQualifier::ReadWrite,true); + + auto SumMIB = BuildMI(BB, MIRBuilder.getInsertPt(), I.getDebugLoc(), TII.get(isVector ? SPIRV::OpIAddV : SPIRV::OpIAddS)) + .addDef(SumVReg) + .addUse(GR.getSPIRVTypeID(IntType)); + for (unsigned i = I.getNumDefs(); i < I.getNumOperands(); ++i) + SumMIB.addUse(I.getOperand(i).getReg()); + bool Result = SumMIB.constrainAllUses(TII, TRI, RBI); + + Register OverflowVReg = MRI->createGenericVirtualRegister(LLT::scalar(1)); + MRI->setRegClass(OverflowVReg, &SPIRV::IDRegClass); + unsigned i = I.getNumDefs(); + + Register posCheck1 = MRI->createGenericVirtualRegister(LLT::scalar(1)); + MRI->setRegClass(posCheck1, &SPIRV::IDRegClass); + Register posCheck2 = MRI->createGenericVirtualRegister(LLT::scalar(1)); + MRI->setRegClass(posCheck2, &SPIRV::IDRegClass); + Register posCheck3 = MRI->createGenericVirtualRegister(LLT::scalar(1)); + MRI->setRegClass(posCheck3, &SPIRV::IDRegClass); + Register posOverflow = MRI->createGenericVirtualRegister(LLT::scalar(1)); + MRI->setRegClass(posOverflow, &SPIRV::IDRegClass); + Register posOverflowCheck = MRI->createGenericVirtualRegister(LLT::scalar(1)); + MRI->setRegClass(posOverflowCheck, &SPIRV::IDRegClass); + + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSGreaterThan)) + .addDef(posCheck1) + .addUse(GR.getSPIRVTypeID(BoolType)) + .addUse(I.getOperand(i).getReg()) + .addUse(ZeroReg); + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSGreaterThan)) + .addDef(posCheck2) + .addUse(GR.getSPIRVTypeID(BoolType)) + .addUse(I.getOperand(i+1).getReg()) + .addUse(ZeroReg); + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSLessThan)) + .addDef(posCheck3) + .addUse(GR.getSPIRVTypeID(BoolType)) + .addUse(SumVReg) + .addUse(I.getOperand(i+1).getReg()); + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpLogicalAnd)) + 
.addDef(posOverflow) + .addUse(GR.getSPIRVTypeID(BoolType)) + .addUse(posCheck1) + .addUse(posCheck2); + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpLogicalAnd)) + .addDef(posOverflowCheck) + .addUse(GR.getSPIRVTypeID(BoolType)) + .addUse(posOverflow) + .addUse(posCheck3); + + Register negCheck1 = MRI->createGenericVirtualRegister(LLT::scalar(1)); + MRI->setRegClass(negCheck1, &SPIRV::IDRegClass); + Register negCheck2 = MRI->createGenericVirtualRegister(LLT::scalar(1)); + MRI->setRegClass(negCheck2, &SPIRV::IDRegClass); + Register negCheck3 = MRI->createGenericVirtualRegister(LLT::scalar(1)); + MRI->setRegClass(negCheck3, &SPIRV::IDRegClass); + Register negOverflow = MRI->createGenericVirtualRegister(LLT::scalar(1)); + MRI->setRegClass(negOverflow, &SPIRV::IDRegClass); + Register negOverflowCheck = MRI->createGenericVirtualRegister(LLT::scalar(1)); + MRI->setRegClass(negOverflowCheck, &SPIRV::IDRegClass); + + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSLessThan)) + .addDef(negCheck1) + .addUse(GR.getSPIRVTypeID(BoolType)) + .addUse(I.getOperand(i).getReg()) + .addUse(ZeroReg); + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSLessThan)) + .addDef(negCheck2) + .addUse(GR.getSPIRVTypeID(BoolType)) + .addUse(I.getOperand(i+1).getReg()) + .addUse(ZeroReg); + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSGreaterThan)) + .addDef(negCheck3) + .addUse(GR.getSPIRVTypeID(BoolType)) + .addUse(SumVReg) + .addUse(I.getOperand(i+1).getReg()); + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpLogicalAnd)) + .addDef(negOverflow) + .addUse(GR.getSPIRVTypeID(BoolType)) + .addUse(negCheck1) + .addUse(negCheck2); + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpLogicalAnd)) + .addDef(negOverflowCheck) + .addUse(GR.getSPIRVTypeID(BoolType)) + .addUse(negOverflow) + .addUse(negCheck3); + + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpLogicalOr)) + .addDef(OverflowVReg) + .addUse(GR.getSPIRVTypeID(BoolType)) + .addUse(negOverflowCheck) + .addUse(posOverflowCheck); 
+ + // Construct the result struct containing sum and overflow flag + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeConstruct)) + .addDef(StructVReg) + .addUse(GR.getSPIRVTypeID(StructType)) + .addUse(SumVReg) + .addUse(OverflowVReg); + + Register HigherVReg = MRI->createGenericVirtualRegister(LLT::scalar(64)); + MRI->setRegClass(HigherVReg, &SPIRV::iIDRegClass); + + for (unsigned i = 0; i < I.getNumDefs(); ++i) { + auto MIB = + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract)) + .addDef(i == 1 ? HigherVReg : I.getOperand(i).getReg()) + .addUse(i == 1 ? GR.getSPIRVTypeID(BoolType) : GR.getSPIRVTypeID(ResType)) + .addUse(StructVReg) + .addImm(i); + Result &= MIB.constrainAllUses(TII, TRI, RBI); + } + Register FalseReg = MRI->createGenericVirtualRegister(LLT::scalar(1)); + MRI->setRegClass(FalseReg, &SPIRV::IDRegClass); + + // OpConstantNull of the bool (or bool-vector) type acts as the false value + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull)) + .addDef(FalseReg) + .addUse(GR.getSPIRVTypeID(BoolType)); + + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpLogicalNotEqual)) + .addDef(I.getOperand(1).getReg()) + .addUse(BoolTypeReg) + .addUse(HigherVReg) + .addUse(FalseReg) + .constrainAllUses(TII, TRI, RBI); + return true; + + + +} + bool SPIRVInstructionSelector::selectAtomicCmpXchg(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const { diff --git a/llvm/test/CodeGen/SPIRV/llvm-intrinsics/sadd.with.overflow.ll b/llvm/test/CodeGen/SPIRV/llvm-intrinsics/sadd.with.overflow.ll new file mode 100644 index 0000000000000..a2ccfa1d11523 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/llvm-intrinsics/sadd.with.overflow.ll @@ -0,0 +1,161 @@ +; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s +;
RUN: %if spirv-tools %{ llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +;===---------------------------------------------------------------------===// +; Type definitions. +; CHECK-DAG: %[[I16:.*]] = OpTypeInt 16 0 +; CHECK-DAG: %[[Bool:.*]] = OpTypeBool +; CHECK-DAG: %[[I32:.*]] = OpTypeInt 32 0 +; CHECK-DAG: %[[I64:.*]] = OpTypeInt 64 0 +; CHECK-DAG: %[[PtrI16:.*]] = OpTypePointer Function %[[I16]] +; CHECK-DAG: %[[PtrI32:.*]] = OpTypePointer Function %[[I32]] +; CHECK-DAG: %[[PtrI64:.*]] = OpTypePointer Function %[[I64]] +; CHECK-DAG: %[[StructI16:.*]] = OpTypeStruct %[[I16]] %[[Bool]] +; CHECK-DAG: %[[StructI32:.*]] = OpTypeStruct %[[I32]] %[[Bool]] +; CHECK-DAG: %[[StructI64:.*]] = OpTypeStruct %[[I64]] %[[Bool]] +; CHECK-DAG: %[[ZeroI16:.*]] = OpConstant %[[I16]] 0 +; CHECK-DAG: %[[ZeroI32:.*]] = OpConstant %[[I32]] 0 +; CHECK-DAG: %[[ZeroI64:.*]] = OpConstantNull %[[I64]] +; CHECK-DAG: %[[V4I32:.*]] = OpTypeVector %[[I32]] 4 +; CHECK-DAG: %[[V4Bool:.*]] = OpTypeVector %[[Bool]] 4 +; CHECK-DAG: %[[PtrV4I32:.*]] = OpTypePointer Function %[[V4I32]] +; CHECK-DAG: %[[StructV4I32:.*]] = OpTypeStruct %[[V4I32]] %[[V4Bool]] +; CHECK-DAG: %[[ZeroV4I32:.*]] = OpConstantNull %[[V4I32]] +;===---------------------------------------------------------------------===// +; Function for i16 sadd.with.overflow. 
+; CHECK: OpFunction +; CHECK: %[[A16:.*]] = OpFunctionParameter %[[I16]] +; CHECK: %[[B16:.*]] = OpFunctionParameter %[[I16]] +; CHECK: %[[Ptr16:.*]] = OpFunctionParameter %[[PtrI16]] +; CHECK: %[[Sum16:.*]] = OpIAdd %[[I16]] %[[A16]] %[[B16]] +; CHECK: %[[PosCmp16_1:.*]] = OpSGreaterThan %[[Bool]] %[[A16]] %[[ZeroI16]] +; CHECK: %[[PosCmp16_2:.*]] = OpSGreaterThan %[[Bool]] %[[B16]] %[[ZeroI16]] +; CHECK: %[[PosCmp16_3:.*]] = OpSLessThan %[[Bool]] %[[Sum16]] %[[B16]] +; CHECK: %[[PosCond16:.*]] = OpLogicalAnd %[[Bool]] %[[PosCmp16_1]] %[[PosCmp16_2]] +; CHECK: %[[PosOverflow16:.*]] = OpLogicalAnd %[[Bool]] %[[PosCond16]] %[[PosCmp16_3]] +; CHECK: %[[NegCmp16_1:.*]] = OpSLessThan %[[Bool]] %[[A16]] %[[ZeroI16]] +; CHECK: %[[NegCmp16_2:.*]] = OpSLessThan %[[Bool]] %[[B16]] %[[ZeroI16]] +; CHECK: %[[NegCmp16_3:.*]] = OpSGreaterThan %[[Bool]] %[[Sum16]] %[[B16]] +; CHECK: %[[NegCond16:.*]] = OpLogicalAnd %[[Bool]] %[[NegCmp16_1]] %[[NegCmp16_2]] +; CHECK: %[[NegOverflow16:.*]] = OpLogicalAnd %[[Bool]] %[[NegCond16]] %[[NegCmp16_3]] +; CHECK: %[[Overflow16:.*]] = OpLogicalOr %[[Bool]] %[[NegOverflow16]] %[[PosOverflow16]] +; CHECK: %[[Comp16:.*]] = OpCompositeConstruct %[[StructI16]] %[[Sum16]] %[[Overflow16]] +; CHECK: %[[ExtOver16:.*]] = OpCompositeExtract %[[Bool]] %[[Comp16]] 1 +; CHECK: %[[Final16:.*]] = OpLogicalNotEqual %[[Bool]] %[[ExtOver16]] %[[#]] +; CHECK: OpReturn +define spir_func void @smulo_i16(i16 %a, i16 %b, ptr nocapture %c) { +entry: + %res = tail call { i16, i1 } @llvm.sadd.with.overflow.i16(i16 %a, i16 %b) + %overflow = extractvalue { i16, i1 } %res, 1 + %sum = extractvalue { i16, i1 } %res, 0 + %stored = select i1 %overflow, i16 0, i16 %sum + store i16 %stored, ptr %c, align 1 + ret void +} + +;===---------------------------------------------------------------------===// +; Function for i32 sadd.with.overflow.
+; CHECK: OpFunction +; CHECK: %[[A32:.*]] = OpFunctionParameter %[[I32]] +; CHECK: %[[B32:.*]] = OpFunctionParameter %[[I32]] +; CHECK: %[[Ptr32:.*]] = OpFunctionParameter %[[PtrI32]] +; CHECK: %[[Sum32:.*]] = OpIAdd %[[I32]] %[[A32]] %[[B32]] +; CHECK: %[[PosCmp32_1:.*]] = OpSGreaterThan %[[Bool]] %[[A32]] %[[ZeroI32]] +; CHECK: %[[PosCmp32_2:.*]] = OpSGreaterThan %[[Bool]] %[[B32]] %[[ZeroI32]] +; CHECK: %[[PosCmp32_3:.*]] = OpSLessThan %[[Bool]] %[[Sum32]] %[[B32]] +; CHECK: %[[PosCond32:.*]] = OpLogicalAnd %[[Bool]] %[[PosCmp32_1]] %[[PosCmp32_2]] +; CHECK: %[[PosOverflow32:.*]] = OpLogicalAnd %[[Bool]] %[[PosCond32]] %[[PosCmp32_3]] +; CHECK: %[[NegCmp32_1:.*]] = OpSLessThan %[[Bool]] %[[A32]] %[[ZeroI32]] +; CHECK: %[[NegCmp32_2:.*]] = OpSLessThan %[[Bool]] %[[B32]] %[[ZeroI32]] +; CHECK: %[[NegCmp32_3:.*]] = OpSGreaterThan %[[Bool]] %[[Sum32]] %[[B32]] +; CHECK: %[[NegCond32:.*]] = OpLogicalAnd %[[Bool]] %[[NegCmp32_1]] %[[NegCmp32_2]] +; CHECK: %[[NegOverflow32:.*]] = OpLogicalAnd %[[Bool]] %[[NegCond32]] %[[NegCmp32_3]] +; CHECK: %[[Overflow32:.*]] = OpLogicalOr %[[Bool]] %[[NegOverflow32]] %[[PosOverflow32]] +; CHECK: %[[Comp32:.*]] = OpCompositeConstruct %[[StructI32]] %[[Sum32]] %[[Overflow32]] +; CHECK: %[[ExtOver32:.*]] = OpCompositeExtract %[[Bool]] %[[Comp32]] 1 +; CHECK: %[[Final32:.*]] = OpLogicalNotEqual %[[Bool]] %[[ExtOver32]] %[[#]] +; CHECK: OpReturn +define spir_func void @smulo_i32(i32 %a, i32 %b, ptr nocapture %c) { +entry: + %res = tail call { i32, i1 } @llvm.sadd.with.overflow.i32(i32 %a, i32 %b) + %overflow = extractvalue { i32, i1 } %res, 1 + %sum = extractvalue { i32, i1 } %res, 0 + %stored = select i1 %overflow, i32 0, i32 %sum + store i32 %stored, ptr %c, align 4 + ret void +} + +;===---------------------------------------------------------------------===// +; Function for i64 sadd.with.overflow.
+; CHECK: OpFunction +; CHECK: %[[A64:.*]] = OpFunctionParameter %[[I64]] +; CHECK: %[[B64:.*]] = OpFunctionParameter %[[I64]] +; CHECK: %[[Ptr64:.*]] = OpFunctionParameter %[[PtrI64]] +; CHECK: %[[Sum64:.*]] = OpIAdd %[[I64]] %[[A64]] %[[B64]] +; CHECK: %[[PosCmp64_1:.*]] = OpSGreaterThan %[[Bool]] %[[A64]] %[[ZeroI64]] +; CHECK: %[[PosCmp64_2:.*]] = OpSGreaterThan %[[Bool]] %[[B64]] %[[ZeroI64]] +; CHECK: %[[PosCmp64_3:.*]] = OpSLessThan %[[Bool]] %[[Sum64]] %[[B64]] +; CHECK: %[[PosCond64:.*]] = OpLogicalAnd %[[Bool]] %[[PosCmp64_1]] %[[PosCmp64_2]] +; CHECK: %[[PosOverflow64:.*]] = OpLogicalAnd %[[Bool]] %[[PosCond64]] %[[PosCmp64_3]] +; CHECK: %[[NegCmp64_1:.*]] = OpSLessThan %[[Bool]] %[[A64]] %[[ZeroI64]] +; CHECK: %[[NegCmp64_2:.*]] = OpSLessThan %[[Bool]] %[[B64]] %[[ZeroI64]] +; CHECK: %[[NegCmp64_3:.*]] = OpSGreaterThan %[[Bool]] %[[Sum64]] %[[B64]] +; CHECK: %[[NegCond64:.*]] = OpLogicalAnd %[[Bool]] %[[NegCmp64_1]] %[[NegCmp64_2]] +; CHECK: %[[NegOverflow64:.*]] = OpLogicalAnd %[[Bool]] %[[NegCond64]] %[[NegCmp64_3]] +; CHECK: %[[Overflow64:.*]] = OpLogicalOr %[[Bool]] %[[NegOverflow64]] %[[PosOverflow64]] +; CHECK: %[[Comp64:.*]] = OpCompositeConstruct %[[StructI64]] %[[Sum64]] %[[Overflow64]] +; CHECK: %[[ExtOver64:.*]] = OpCompositeExtract %[[Bool]] %[[Comp64]] 1 +; CHECK: %[[Final64:.*]] = OpLogicalNotEqual %[[Bool]] %[[ExtOver64]] %[[#]] +; CHECK: OpReturn +define spir_func void @smulo_i64(i64 %a, i64 %b, ptr nocapture %c) { +entry: + %res = tail call { i64, i1 } @llvm.sadd.with.overflow.i64(i64 %a, i64 %b) + %overflow = extractvalue { i64, i1 } %res, 1 + %sum = extractvalue { i64, i1 } %res, 0 + %stored = select i1 %overflow, i64 0, i64 %sum + store i64 %stored, ptr %c, align 8 + ret void +} + +;===---------------------------------------------------------------------===// +; Function for vector (4 x i32) sadd.with.overflow.
+ +; CHECK: OpFunction +; CHECK: %[[A4:.*]] = OpFunctionParameter %[[V4I32]] +; CHECK: %[[B4:.*]] = OpFunctionParameter %[[V4I32]] +; CHECK: %[[Ptr4:.*]] = OpFunctionParameter %[[PtrV4I32]] +; CHECK: %[[Sum4:.*]] = OpIAdd %[[V4I32]] %[[A4]] %[[B4]] +; CHECK: %[[PosCmp4_1:.*]] = OpSGreaterThan %[[V4Bool]] %[[A4]] %[[ZeroV4I32]] +; CHECK: %[[PosCmp4_2:.*]] = OpSGreaterThan %[[V4Bool]] %[[B4]] %[[ZeroV4I32]] +; CHECK: %[[PosCmp4_3:.*]] = OpSLessThan %[[V4Bool]] %[[Sum4]] %[[B4]] +; CHECK: %[[PosCond4:.*]] = OpLogicalAnd %[[V4Bool]] %[[PosCmp4_1]] %[[PosCmp4_2]] +; CHECK: %[[PosOverflow4:.*]] = OpLogicalAnd %[[V4Bool]] %[[PosCond4]] %[[PosCmp4_3]] +; CHECK: %[[NegCmp4_1:.*]] = OpSLessThan %[[V4Bool]] %[[A4]] %[[ZeroV4I32]] +; CHECK: %[[NegCmp4_2:.*]] = OpSLessThan %[[V4Bool]] %[[B4]] %[[ZeroV4I32]] +; CHECK: %[[NegCmp4_3:.*]] = OpSGreaterThan %[[V4Bool]] %[[Sum4]] %[[B4]] +; CHECK: %[[NegCond4:.*]] = OpLogicalAnd %[[V4Bool]] %[[NegCmp4_1]] %[[NegCmp4_2]] +; CHECK: %[[NegOverflow4:.*]] = OpLogicalAnd %[[V4Bool]] %[[NegCond4]] %[[NegCmp4_3]] +; CHECK: %[[Overflow4:.*]] = OpLogicalOr %[[V4Bool]] %[[NegOverflow4]] %[[PosOverflow4]] +; CHECK: %[[Comp4:.*]] = OpCompositeConstruct %[[StructV4I32]] %[[Sum4]] %[[Overflow4]] +; CHECK: %[[ExtOver4:.*]] = OpCompositeExtract %[[V4Bool]] %[[Comp4]] 1 +; CHECK: %[[Final4:.*]] = OpLogicalNotEqual %[[V4Bool]] %[[ExtOver4]] %[[#]] +; CHECK: OpReturn +define spir_func void @smulo_v4i32(<4 x i32> %a, <4 x i32> %b, ptr nocapture %c) { +entry: + %res = tail call { <4 x i32>, <4 x i1> } @llvm.sadd.with.overflow.v4i32(<4 x i32> %a, <4 x i32> %b) + %overflow = extractvalue { <4 x i32>, <4 x i1> } %res, 1 + %sum = extractvalue { <4 x i32>, <4 x i1> } %res, 0 + %stored = select <4 x i1> %overflow, <4 x i32> zeroinitializer, <4 x i32> %sum + store <4 x i32> %stored, ptr %c, align 16 + ret void +} + +;===---------------------------------------------------------------------===// +; Declarations of the intrinsics.
+declare { i16, i1 } @llvm.sadd.with.overflow.i16(i16, i16) +declare { i32, i1 } @llvm.sadd.with.overflow.i32(i32, i32) +declare { i64, i1 } @llvm.sadd.with.overflow.i64(i64, i64) +declare { <4 x i32>, <4 x i1> } @llvm.sadd.with.overflow.v4i32(<4 x i32>, <4 x i32>) From 1dfb08ce291171de681c35905b874351569bafca Mon Sep 17 00:00:00 2001 From: EbinJose2002 Date: Tue, 4 Mar 2025 04:28:28 +0000 Subject: [PATCH 2/2] FEAT : - Formatted the code as per clang format --- .../Target/SPIRV/SPIRVInstructionSelector.cpp | 158 +++++++++--------- 1 file changed, 78 insertions(+), 80 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 8f6efc3b5f43c..dce53b07b54ea 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -731,10 +731,9 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg, ? SPIRV::OpIAddCarryV : SPIRV::OpIAddCarryS); case TargetOpcode::G_SADDO: - return selectSignedOverflowArith(ResVReg, ResType, I, - ResType->getOpcode() == SPIRV::OpTypeVector - ? true - : false); + return selectSignedOverflowArith( + ResVReg, ResType, I, + ResType->getOpcode() == SPIRV::OpTypeVector); case TargetOpcode::G_USUBO: return selectOverflowArith(ResVReg, ResType, I, ResType->getOpcode() == SPIRV::OpTypeVector @@ -1383,15 +1382,14 @@ bool SPIRVInstructionSelector::selectOverflowArith(Register ResVReg, .constrainAllUses(TII, TRI, RBI); } -bool SPIRVInstructionSelector::selectSignedOverflowArith(Register ResVReg, - const SPIRVType *ResType, - MachineInstr &I, - bool isVector) const { - +bool SPIRVInstructionSelector::selectSignedOverflowArith( + Register ResVReg, const SPIRVType *ResType, MachineInstr &I, + bool isVector) const { -//Checking overflow based on the logic that if two operands are positive and the sum is -//less than one of the operands then an overflow occured.
Likewise if two operands are -//negative and if sum is greater than one operand then also overflow occured. + // Overflow is detected from the operand signs: when both operands are + // positive and the sum is less than the second operand, or when both are + // negative and the sum is greater than the second operand, the addition + // has overflowed. Type *ResTy = nullptr; StringRef ResName; @@ -1401,14 +1399,14 @@ bool SPIRVInstructionSelector::selectSignedOverflowArith(Register ResVReg, "Not enough info to select the signed arithmetic instruction"); if (!ResTy || !ResTy->isStructTy()) report_fatal_error( - "Expect struct type result for the signed arithmetic instruction"); - + "Expect struct type result for the signed arithmetic instruction"); + StructType *ResStructTy = cast(ResTy); Type *ResElemTy = ResStructTy->getElementType(0); Type *OverflowTy = ResStructTy->getElementType(1); ResTy = StructType::get(ResElemTy, OverflowTy); SPIRVType *StructType = GR.getOrCreateSPIRVType( - ResTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, false); + ResTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, false); if (!StructType) { report_fatal_error("Failed to create SPIR-V type for struct"); } @@ -1420,16 +1418,18 @@ bool SPIRVInstructionSelector::selectSignedOverflowArith(Register ResVReg, Register ZeroReg = buildZerosVal(ResType, I); Register StructVReg = MRI->createGenericVirtualRegister(LLT::scalar(64)); MRI->setRegClass(StructVReg, &SPIRV::IDRegClass); - + if (ResName.size() > 0) buildOpName(StructVReg, ResName, MIRBuilder); - + MachineBasicBlock &BB = *I.getParent(); Register SumVReg = MRI->createGenericVirtualRegister(LLT::scalar(64)); MRI->setRegClass(SumVReg, &SPIRV::IDRegClass); - SPIRVType *IntType = GR.getOrCreateSPIRVType(ResElemTy,MIRBuilder,SPIRV::AccessQualifier::ReadWrite,true); + SPIRVType *IntType = GR.getOrCreateSPIRVType( + ResElemTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true); - auto SumMIB = BuildMI(BB,
MIRBuilder.getInsertPt(), I.getDebugLoc(), TII.get(isVector ? SPIRV::OpIAddV : SPIRV::OpIAddS)) + auto SumMIB = BuildMI(BB, MIRBuilder.getInsertPt(), I.getDebugLoc(), + TII.get(isVector ? SPIRV::OpIAddV : SPIRV::OpIAddS)) .addDef(SumVReg) .addUse(GR.getSPIRVTypeID(IntType)); for (unsigned i = I.getNumDefs(); i < I.getNumOperands(); ++i) @@ -1450,32 +1450,32 @@ bool SPIRVInstructionSelector::selectSignedOverflowArith(Register ResVReg, MRI->setRegClass(posOverflow, &SPIRV::IDRegClass); Register posOverflowCheck = MRI->createGenericVirtualRegister(LLT::scalar(1)); MRI->setRegClass(posOverflowCheck, &SPIRV::IDRegClass); - + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSGreaterThan)) - .addDef(posCheck1) - .addUse(GR.getSPIRVTypeID(BoolType)) - .addUse(I.getOperand(i).getReg()) - .addUse(ZeroReg); + .addDef(posCheck1) + .addUse(GR.getSPIRVTypeID(BoolType)) + .addUse(I.getOperand(i).getReg()) + .addUse(ZeroReg); BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSGreaterThan)) - .addDef(posCheck2) - .addUse(GR.getSPIRVTypeID(BoolType)) - .addUse(I.getOperand(i+1).getReg()) - .addUse(ZeroReg); + .addDef(posCheck2) + .addUse(GR.getSPIRVTypeID(BoolType)) + .addUse(I.getOperand(i + 1).getReg()) + .addUse(ZeroReg); BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSLessThan)) - .addDef(posCheck3) - .addUse(GR.getSPIRVTypeID(BoolType)) - .addUse(SumVReg) - .addUse(I.getOperand(i+1).getReg()); + .addDef(posCheck3) + .addUse(GR.getSPIRVTypeID(BoolType)) + .addUse(SumVReg) + .addUse(I.getOperand(i + 1).getReg()); BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpLogicalAnd)) - .addDef(posOverflow) - .addUse(GR.getSPIRVTypeID(BoolType)) - .addUse(posCheck1) - .addUse(posCheck2); + .addDef(posOverflow) + .addUse(GR.getSPIRVTypeID(BoolType)) + .addUse(posCheck1) + .addUse(posCheck2); BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpLogicalAnd)) - .addDef(posOverflowCheck) - .addUse(GR.getSPIRVTypeID(BoolType)) - .addUse(posOverflow) - .addUse(posCheck3); + .addDef(posOverflowCheck) 
+ .addUse(GR.getSPIRVTypeID(BoolType)) + .addUse(posOverflow) + .addUse(posCheck3); Register negCheck1 = MRI->createGenericVirtualRegister(LLT::scalar(1)); MRI->setRegClass(negCheck1, &SPIRV::IDRegClass); @@ -1487,54 +1487,55 @@ bool SPIRVInstructionSelector::selectSignedOverflowArith(Register ResVReg, MRI->setRegClass(negOverflow, &SPIRV::IDRegClass); Register negOverflowCheck = MRI->createGenericVirtualRegister(LLT::scalar(1)); MRI->setRegClass(negOverflowCheck, &SPIRV::IDRegClass); - + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSLessThan)) - .addDef(negCheck1) - .addUse(GR.getSPIRVTypeID(BoolType)) - .addUse(I.getOperand(i).getReg()) - .addUse(ZeroReg); + .addDef(negCheck1) + .addUse(GR.getSPIRVTypeID(BoolType)) + .addUse(I.getOperand(i).getReg()) + .addUse(ZeroReg); BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSLessThan)) - .addDef(negCheck2) - .addUse(GR.getSPIRVTypeID(BoolType)) - .addUse(I.getOperand(i+1).getReg()) - .addUse(ZeroReg); + .addDef(negCheck2) + .addUse(GR.getSPIRVTypeID(BoolType)) + .addUse(I.getOperand(i + 1).getReg()) + .addUse(ZeroReg); BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSGreaterThan)) - .addDef(negCheck3) - .addUse(GR.getSPIRVTypeID(BoolType)) - .addUse(SumVReg) - .addUse(I.getOperand(i+1).getReg()); + .addDef(negCheck3) + .addUse(GR.getSPIRVTypeID(BoolType)) + .addUse(SumVReg) + .addUse(I.getOperand(i + 1).getReg()); BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpLogicalAnd)) - .addDef(negOverflow) - .addUse(GR.getSPIRVTypeID(BoolType)) - .addUse(negCheck1) - .addUse(negCheck2); + .addDef(negOverflow) + .addUse(GR.getSPIRVTypeID(BoolType)) + .addUse(negCheck1) + .addUse(negCheck2); BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpLogicalAnd)) - .addDef(negOverflowCheck) - .addUse(GR.getSPIRVTypeID(BoolType)) - .addUse(negOverflow) - .addUse(negCheck3); + .addDef(negOverflowCheck) + .addUse(GR.getSPIRVTypeID(BoolType)) + .addUse(negOverflow) + .addUse(negCheck3); BuildMI(BB, I, I.getDebugLoc(), 
TII.get(SPIRV::OpLogicalOr)) - .addDef(OverflowVReg) - .addUse(GR.getSPIRVTypeID(BoolType)) - .addUse(negOverflowCheck) - .addUse(posOverflowCheck); - + .addDef(OverflowVReg) + .addUse(GR.getSPIRVTypeID(BoolType)) + .addUse(negOverflowCheck) + .addUse(posOverflowCheck); + // Construct the result struct containing sum and overflow flag BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeConstruct)) - .addDef(StructVReg) - .addUse(GR.getSPIRVTypeID(StructType)) - .addUse(SumVReg) - .addUse(OverflowVReg); + .addDef(StructVReg) + .addUse(GR.getSPIRVTypeID(StructType)) + .addUse(SumVReg) + .addUse(OverflowVReg); Register HigherVReg = MRI->createGenericVirtualRegister(LLT::scalar(64)); MRI->setRegClass(HigherVReg, &SPIRV::iIDRegClass); - + for (unsigned i = 0; i < I.getNumDefs(); ++i) { auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract)) .addDef(i == 1 ? HigherVReg : I.getOperand(i).getReg()) - .addUse(i == 1 ? GR.getSPIRVTypeID(BoolType) : GR.getSPIRVTypeID(ResType)) + .addUse(i == 1 ? GR.getSPIRVTypeID(BoolType) + : GR.getSPIRVTypeID(ResType)) .addUse(StructVReg) .addImm(i); Result &= MIB.constrainAllUses(TII, TRI, RBI); @@ -1548,15 +1549,14 @@ bool SPIRVInstructionSelector::selectSignedOverflowArith(Register ResVReg, .addUse(GR.getSPIRVTypeID(BoolType)); - BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpLogicalNotEqual)) - .addDef(I.getOperand(1).getReg()) - .addUse(BoolTypeReg) - .addUse(HigherVReg) - .addUse(FalseReg) - .constrainAllUses(TII, TRI, RBI); - return true; - - - + // Fold the last instruction's status into Result so failures are reported. + Result &= + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpLogicalNotEqual)) + .addDef(I.getOperand(1).getReg()) + .addUse(BoolTypeReg) + .addUse(HigherVReg) + .addUse(FalseReg) + .constrainAllUses(TII, TRI, RBI); + return Result; } bool SPIRVInstructionSelector::selectAtomicCmpXchg(Register ResVReg,