diff --git a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
index 51318c9c2736d..7cce0ae5359b6 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
@@ -314,6 +314,11 @@ LLVM_ABI LegalityPredicate scalarWiderThan(unsigned TypeIdx, unsigned Size);
 LLVM_ABI LegalityPredicate scalarOrEltNarrowerThan(unsigned TypeIdx,
                                                    unsigned Size);
 
+/// True iff the specified type index is a fixed vector with more elements
+/// than the given number.
+LLVM_ABI LegalityPredicate vectorElementCountIsGreaterThan(unsigned TypeIdx,
+                                                           unsigned Size);
+
 /// True iff the specified type index is a scalar or a vector with an element
 /// type that's wider than the given size.
 LLVM_ABI LegalityPredicate scalarOrEltWiderThan(unsigned TypeIdx,
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp b/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
index 30c2d089c3121..757a1fdba7fbe 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
@@ -155,6 +155,15 @@ LegalityPredicate LegalityPredicates::scalarOrEltNarrowerThan(unsigned TypeIdx,
   };
 }
 
+LegalityPredicate
+LegalityPredicates::vectorElementCountIsGreaterThan(unsigned TypeIdx,
+                                                    unsigned Size) {
+  return [=](const LegalityQuery &Query) {
+    const LLT QueryTy = Query.Types[TypeIdx];
+    return QueryTy.isFixedVector() && QueryTy.getNumElements() > Size;
+  };
+}
+
 LegalityPredicate LegalityPredicates::scalarOrEltWiderThan(unsigned TypeIdx,
                                                            unsigned Size) {
   return [=](const LegalityQuery &Query) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td b/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td
index 2fde2b0bc0b1f..f93240dc35993 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td
@@ -25,6 +25,11 @@ class Op<bits<16> Opcode, dag outs, dag ins, string asmstr, list<dag> pattern =
   let Pattern = pattern;
 }
 
+class PureOp<bits<16> Opcode, dag outs, dag ins, string asmstr,
+             list<dag> pattern = []> : Op<Opcode, outs, ins, asmstr, pattern> {
+  let hasSideEffects = 0;
+}
+
 class UnknownOp<dag outs, dag ins, string asmstr, list<dag> pattern = []>
     : Op<0, outs, ins, asmstr, pattern> {
   let isPseudo = 1;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index a61351eba03f8..799a82c96b0f0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -163,52 +163,74 @@ def OpExecutionModeId: Op<331, (outs), (ins ID:$entry, ExecutionMode:$mode, vari
 
 // 3.42.6 Type-Declaration Instructions
 
-def OpTypeVoid: Op<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">;
-def OpTypeBool: Op<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">;
-def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness),
-                  "$type = OpTypeInt $width $signedness">;
-def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width, variable_ops),
-                  "$type = OpTypeFloat $width">;
-def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
-                  "$type = OpTypeVector $compType $compCount">;
-def OpTypeMatrix: Op<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount),
-                  "$type = OpTypeMatrix $colType $colCount">;
-def OpTypeImage: Op<25, (outs TYPE:$res), (ins TYPE:$sampTy, Dim:$dim, i32imm:$depth,
-                    i32imm:$arrayed, i32imm:$MS, i32imm:$sampled, ImageFormat:$imFormat, variable_ops),
-                  "$res = OpTypeImage $sampTy $dim $depth $arrayed $MS $sampled $imFormat">;
-def OpTypeSampler: Op<26, (outs TYPE:$res), (ins), "$res = OpTypeSampler">;
-def OpTypeSampledImage: Op<27, (outs TYPE:$res), (ins TYPE:$imageType), - "$res = OpTypeSampledImage $imageType">; -def OpTypeArray: Op<28, (outs TYPE:$type), (ins TYPE:$elementType, ID:$length), - "$type = OpTypeArray $elementType $length">; -def OpTypeRuntimeArray: Op<29, (outs TYPE:$type), (ins TYPE:$elementType), - "$type = OpTypeRuntimeArray $elementType">; -def OpTypeStruct: Op<30, (outs TYPE:$res), (ins variable_ops), "$res = OpTypeStruct">; -def OpTypeStructContinuedINTEL: Op<6090, (outs), (ins variable_ops), - "OpTypeStructContinuedINTEL">; -def OpTypeOpaque: Op<31, (outs TYPE:$res), (ins StringImm:$name, variable_ops), - "$res = OpTypeOpaque $name">; -def OpTypePointer: Op<32, (outs TYPE:$res), (ins StorageClass:$storage, TYPE:$type), - "$res = OpTypePointer $storage $type">; -def OpTypeFunction: Op<33, (outs TYPE:$funcType), (ins TYPE:$returnType, variable_ops), - "$funcType = OpTypeFunction $returnType">; -def OpTypeEvent: Op<34, (outs TYPE:$res), (ins), "$res = OpTypeEvent">; -def OpTypeDeviceEvent: Op<35, (outs TYPE:$res), (ins), "$res = OpTypeDeviceEvent">; -def OpTypeReserveId: Op<36, (outs TYPE:$res), (ins), "$res = OpTypeReserveId">; -def OpTypeQueue: Op<37, (outs TYPE:$res), (ins), "$res = OpTypeQueue">; -def OpTypePipe: Op<38, (outs TYPE:$res), (ins AccessQualifier:$a), "$res = OpTypePipe $a">; -def OpTypeForwardPointer: Op<39, (outs), (ins TYPE:$ptrType, StorageClass:$storageClass), - "OpTypeForwardPointer $ptrType $storageClass">; -def OpTypePipeStorage: Op<322, (outs TYPE:$res), (ins), "$res = OpTypePipeStorage">; -def OpTypeNamedBarrier: Op<327, (outs TYPE:$res), (ins), "$res = OpTypeNamedBarrier">; -def OpTypeAccelerationStructureNV: Op<5341, (outs TYPE:$res), (ins), - "$res = OpTypeAccelerationStructureNV">; -def OpTypeCooperativeMatrixNV: Op<5358, (outs TYPE:$res), - (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols), - "$res = OpTypeCooperativeMatrixNV $compType $scope $rows $cols">; -def OpTypeCooperativeMatrixKHR: Op<4456, (outs TYPE:$res), - (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols, ID:$use), - "$res = OpTypeCooperativeMatrixKHR $compType $scope $rows $cols $use">; +def OpTypeVoid : PureOp<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">; +def OpTypeBool : PureOp<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">; +def OpTypeInt + : PureOp<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness), + "$type = OpTypeInt $width $signedness">; +def OpTypeFloat + : PureOp<22, (outs TYPE:$type), (ins i32imm:$width, variable_ops), + "$type = OpTypeFloat $width">; +def OpTypeVector + : PureOp<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount), + "$type = OpTypeVector $compType $compCount">; +def OpTypeMatrix + : PureOp<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount), + "$type = OpTypeMatrix $colType $colCount">; +def OpTypeImage : PureOp<25, (outs TYPE:$res), + (ins TYPE:$sampTy, Dim:$dim, i32imm:$depth, + i32imm:$arrayed, i32imm:$MS, i32imm:$sampled, + ImageFormat:$imFormat, variable_ops), + "$res = OpTypeImage $sampTy $dim $depth $arrayed $MS " + "$sampled $imFormat">; +def OpTypeSampler : PureOp<26, (outs TYPE:$res), (ins), "$res = OpTypeSampler">; +def OpTypeSampledImage : PureOp<27, (outs TYPE:$res), (ins TYPE:$imageType), + "$res = OpTypeSampledImage $imageType">; +def OpTypeArray + : PureOp<28, (outs TYPE:$type), (ins TYPE:$elementType, ID:$length), + "$type = OpTypeArray $elementType $length">; +def OpTypeRuntimeArray : PureOp<29, (outs TYPE:$type), (ins TYPE:$elementType), + "$type = OpTypeRuntimeArray 
$elementType">; +def OpTypeStruct + : PureOp<30, (outs TYPE:$res), (ins variable_ops), "$res = OpTypeStruct">; +def OpTypeStructContinuedINTEL + : PureOp<6090, (outs), (ins variable_ops), "OpTypeStructContinuedINTEL">; +def OpTypeOpaque + : PureOp<31, (outs TYPE:$res), (ins StringImm:$name, variable_ops), + "$res = OpTypeOpaque $name">; +def OpTypePointer + : PureOp<32, (outs TYPE:$res), (ins StorageClass:$storage, TYPE:$type), + "$res = OpTypePointer $storage $type">; +def OpTypeFunction + : PureOp<33, (outs TYPE:$funcType), (ins TYPE:$returnType, variable_ops), + "$funcType = OpTypeFunction $returnType">; +def OpTypeEvent : PureOp<34, (outs TYPE:$res), (ins), "$res = OpTypeEvent">; +def OpTypeDeviceEvent + : PureOp<35, (outs TYPE:$res), (ins), "$res = OpTypeDeviceEvent">; +def OpTypeReserveId + : PureOp<36, (outs TYPE:$res), (ins), "$res = OpTypeReserveId">; +def OpTypeQueue : PureOp<37, (outs TYPE:$res), (ins), "$res = OpTypeQueue">; +def OpTypePipe : PureOp<38, (outs TYPE:$res), (ins AccessQualifier:$a), + "$res = OpTypePipe $a">; +def OpTypeForwardPointer + : PureOp<39, (outs), (ins TYPE:$ptrType, StorageClass:$storageClass), + "OpTypeForwardPointer $ptrType $storageClass">; +def OpTypePipeStorage + : PureOp<322, (outs TYPE:$res), (ins), "$res = OpTypePipeStorage">; +def OpTypeNamedBarrier + : PureOp<327, (outs TYPE:$res), (ins), "$res = OpTypeNamedBarrier">; +def OpTypeAccelerationStructureNV + : PureOp<5341, (outs TYPE:$res), (ins), + "$res = OpTypeAccelerationStructureNV">; +def OpTypeCooperativeMatrixNV + : PureOp<5358, (outs TYPE:$res), + (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols), + "$res = OpTypeCooperativeMatrixNV $compType $scope $rows $cols">; +def OpTypeCooperativeMatrixKHR + : PureOp<4456, (outs TYPE:$res), + (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols, ID:$use), + "$res = OpTypeCooperativeMatrixKHR $compType $scope $rows $cols " + "$use">; // 3.42.7 Constant-Creation Instructions @@ -222,31 +244,46 @@ defm OpConstant: IntFPImm<43, "OpConstant">; def ConstPseudoTrue: IntImmLeaf; def ConstPseudoFalse: IntImmLeaf; -def OpConstantTrue: Op<41, (outs iID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantTrue $src_ty", - [(set iID:$dst, (assigntype ConstPseudoTrue, TYPE:$src_ty))]>; -def OpConstantFalse: Op<42, (outs iID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantFalse $src_ty", - [(set iID:$dst, (assigntype ConstPseudoFalse, TYPE:$src_ty))]>; - -def OpConstantComposite: Op<44, (outs ID:$res), (ins TYPE:$type, variable_ops), - "$res = OpConstantComposite $type">; -def OpConstantCompositeContinuedINTEL: Op<6091, (outs), (ins variable_ops), - "OpConstantCompositeContinuedINTEL">; - -def OpConstantSampler: Op<45, (outs ID:$res), - (ins TYPE:$t, SamplerAddressingMode:$s, i32imm:$p, SamplerFilterMode:$f), - "$res = OpConstantSampler $t $s $p $f">; -def OpConstantNull: Op<46, (outs ID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantNull $src_ty">; - -def OpSpecConstantTrue: Op<48, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantTrue $t">; -def OpSpecConstantFalse: Op<49, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantFalse $t">; -def OpSpecConstant: Op<50, (outs ID:$res), (ins TYPE:$type, i32imm:$imm, variable_ops), - "$res = OpSpecConstant $type $imm">; -def OpSpecConstantComposite: Op<51, (outs ID:$res), (ins TYPE:$type, variable_ops), - "$res = OpSpecConstantComposite $type">; -def OpSpecConstantCompositeContinuedINTEL: Op<6092, (outs), (ins variable_ops), - "OpSpecConstantCompositeContinuedINTEL">; -def OpSpecConstantOp: Op<52, (outs ID:$res), (ins TYPE:$t, 
SpecConstantOpOperands:$c, ID:$o, variable_ops), - "$res = OpSpecConstantOp $t $c $o">; +def OpConstantTrue + : PureOp<41, (outs iID:$dst), (ins TYPE:$src_ty), + "$dst = OpConstantTrue $src_ty", + [(set iID:$dst, (assigntype ConstPseudoTrue, TYPE:$src_ty))]>; +def OpConstantFalse + : PureOp<42, (outs iID:$dst), (ins TYPE:$src_ty), + "$dst = OpConstantFalse $src_ty", + [(set iID:$dst, (assigntype ConstPseudoFalse, TYPE:$src_ty))]>; + +def OpConstantComposite + : PureOp<44, (outs ID:$res), (ins TYPE:$type, variable_ops), + "$res = OpConstantComposite $type">; +def OpConstantCompositeContinuedINTEL + : PureOp<6091, (outs), (ins variable_ops), + "OpConstantCompositeContinuedINTEL">; + +def OpConstantSampler : PureOp<45, (outs ID:$res), + (ins TYPE:$t, SamplerAddressingMode:$s, + i32imm:$p, SamplerFilterMode:$f), + "$res = OpConstantSampler $t $s $p $f">; +def OpConstantNull : PureOp<46, (outs ID:$dst), (ins TYPE:$src_ty), + "$dst = OpConstantNull $src_ty">; + +def OpSpecConstantTrue + : PureOp<48, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantTrue $t">; +def OpSpecConstantFalse + : PureOp<49, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantFalse $t">; +def OpSpecConstant + : PureOp<50, (outs ID:$res), (ins TYPE:$type, i32imm:$imm, variable_ops), + "$res = OpSpecConstant $type $imm">; +def OpSpecConstantComposite + : PureOp<51, (outs ID:$res), (ins TYPE:$type, variable_ops), + "$res = OpSpecConstantComposite $type">; +def OpSpecConstantCompositeContinuedINTEL + : PureOp<6092, (outs), (ins variable_ops), + "OpSpecConstantCompositeContinuedINTEL">; +def OpSpecConstantOp + : PureOp<52, (outs ID:$res), + (ins TYPE:$t, SpecConstantOpOperands:$c, ID:$o, variable_ops), + "$res = OpSpecConstantOp $t $c $o">; // 3.42.8 Memory Instructions diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 021353ab716f7..fd3d050f47060 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -594,6 +594,7 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) { bool HasDefs = I.getNumDefs() > 0; Register ResVReg = HasDefs ? I.getOperand(0).getReg() : Register(0); SPIRVType *ResType = HasDefs ? GR.getSPIRVTypeForVReg(ResVReg) : nullptr; + I.dump(); assert(!HasDefs || ResType || I.getOpcode() == TargetOpcode::G_GLOBAL_VALUE || I.getOpcode() == TargetOpcode::G_IMPLICIT_DEF); if (spvSelect(ResVReg, ResType, I)) { @@ -1526,33 +1527,57 @@ bool SPIRVInstructionSelector::selectUnmergeValues(MachineInstr &I) const { unsigned ArgI = I.getNumOperands() - 1; Register SrcReg = I.getOperand(ArgI).isReg() ? I.getOperand(ArgI).getReg() : Register(0); - SPIRVType *DefType = + SPIRVType *SrcType = SrcReg.isValid() ? 
GR.getSPIRVTypeForVReg(SrcReg) : nullptr; - if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector) + if (!SrcType || SrcType->getOpcode() != SPIRV::OpTypeVector) report_fatal_error( "cannot select G_UNMERGE_VALUES with a non-vector argument"); SPIRVType *ScalarType = - GR.getSPIRVTypeForVReg(DefType->getOperand(1).getReg()); + GR.getSPIRVTypeForVReg(SrcType->getOperand(1).getReg()); MachineBasicBlock &BB = *I.getParent(); bool Res = false; + unsigned CurrentIndex = 0; for (unsigned i = 0; i < I.getNumDefs(); ++i) { Register ResVReg = I.getOperand(i).getReg(); SPIRVType *ResType = GR.getSPIRVTypeForVReg(ResVReg); if (!ResType) { - // There was no "assign type" actions, let's fix this now - ResType = ScalarType; + LLT ResLLT = MRI->getType(ResVReg); + assert(ResLLT.isValid()); + if (ResLLT.isVector()) { + ResType = GR.getOrCreateSPIRVVectorType( + ScalarType, ResLLT.getNumElements(), I, TII); + } else { + ResType = ScalarType; + } MRI->setRegClass(ResVReg, GR.getRegClass(ResType)); - MRI->setType(ResVReg, LLT::scalar(GR.getScalarOrVectorBitWidth(ResType))); GR.assignSPIRVTypeToVReg(ResType, ResVReg, *GR.CurMF); } - auto MIB = - BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract)) - .addDef(ResVReg) - .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(SrcReg) - .addImm(static_cast(i)); - Res |= MIB.constrainAllUses(TII, TRI, RBI); + + if (ResType->getOpcode() == SPIRV::OpTypeVector) { + Register UndefReg = GR.getOrCreateUndef(I, SrcType, TII); + auto MIB = + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpVectorShuffle)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(SrcReg) + .addUse(UndefReg); + unsigned NumElements = GR.getScalarOrVectorComponentCount(ResType); + for (unsigned j = 0; j < NumElements; ++j) { + MIB.addImm(CurrentIndex + j); + } + CurrentIndex += NumElements; + Res |= MIB.constrainAllUses(TII, TRI, RBI); + } else { + auto MIB = + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(SrcReg) + .addImm(CurrentIndex); + CurrentIndex++; + Res |= MIB.constrainAllUses(TII, TRI, RBI); + } } return Res; } @@ -3119,6 +3144,14 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, return selectInsertElt(ResVReg, ResType, I); case Intrinsic::spv_gep: return selectGEP(ResVReg, ResType, I); + case Intrinsic::spv_bitcast: { + Register OpReg = I.getOperand(2).getReg(); + SPIRVType *OpType = + OpReg.isValid() ? GR.getSPIRVTypeForVReg(OpReg) : nullptr; + if (!GR.isBitcastCompatible(ResType, OpType)) + report_fatal_error("incompatible result and operand types in a bitcast"); + return selectOpWithSrcs(ResVReg, ResType, I, {OpReg}, SPIRV::OpBitcast); + } case Intrinsic::spv_unref_global: case Intrinsic::spv_init_global: { MachineInstr *MI = MRI->getVRegDef(I.getOperand(1).getReg()); diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp index 28a1690ef0be1..61de82afad389 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp @@ -73,7 +73,10 @@ class SPIRVLegalizePointerCast : public FunctionPass { // Returns the loaded value. 
Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType, FixedVectorType *TargetType, Value *Source) { - assert(TargetType->getNumElements() <= SourceType->getNumElements()); + const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout(); + [[maybe_unused]] TypeSize TargetTypeSize = DL.getTypeSizeInBits(TargetType); + [[maybe_unused]] TypeSize SourceTypeSize = DL.getTypeSizeInBits(SourceType); + assert(TargetTypeSize <= SourceTypeSize); LoadInst *NewLoad = B.CreateLoad(SourceType, Source); buildAssignType(B, SourceType, NewLoad); Value *AssignValue = NewLoad; diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp index 53074ea3b2597..fb623a2c10fe9 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp @@ -14,11 +14,13 @@ #include "SPIRV.h" #include "SPIRVGlobalRegistry.h" #include "SPIRVSubtarget.h" +#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h" #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h" #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" #include "llvm/CodeGen/MachineInstr.h" #include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/TargetOpcodes.h" +#include "llvm/IR/IntrinsicsSPIRV.h" using namespace llvm; using namespace llvm::LegalizeActions; @@ -101,6 +103,13 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64}; + auto allShaderVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, + v3s1, v3s8, v3s16, v3s32, v3s64, + v4s1, v4s8, v4s16, v4s32, v4s64}; + + auto allNonShaderVectors = {v8s1, v8s8, v8s16, v8s32, v8s64, + v16s1, v16s8, v16s16, v16s32, v16s64}; + auto allScalarsAndVectors = { s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, @@ -148,15 +157,46 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { return IsExtendedInts && Ty.isValid(); }; - for (auto Opc : getTypeFoldingSupportedOpcodes()) - getActionDefinitionsBuilder(Opc).custom(); + // TODO: So far we only legalize vectors for Shaders. + // We need to legalize for kernels as well. For Kernels + // vector sizes of 8 and 16 are allowed as well. - getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal(); + for (auto Opc : getTypeFoldingSupportedOpcodes()) { + if (Opc != G_EXTRACT_VECTOR_ELT) + getActionDefinitionsBuilder(Opc).custom(); + } - // TODO: add proper rules for vectors legalization. 
- getActionDefinitionsBuilder( - {G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR}) - .alwaysLegal(); + if (ST.canUseExtInstSet(SPIRV::InstructionSet::GLSL_std_450)) { + getActionDefinitionsBuilder(G_SHUFFLE_VECTOR) + .lowerIf(vectorElementCountIsGreaterThan(0, 4)) + .lowerIf(vectorElementCountIsGreaterThan(1, 4)) + .legalFor(allShaderVectors); + getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT) + .moreElementsToNextPow2(1) + .fewerElementsIf(vectorElementCountIsGreaterThan(1, 4), + LegalizeMutations::changeElementCountTo( + 1, ElementCount::getFixed(4))) + .custom(); + getActionDefinitionsBuilder(G_BUILD_VECTOR) + .legalFor(allShaderVectors) + .fewerElementsIf(vectorElementCountIsGreaterThan(0, 4), + LegalizeMutations::changeElementCountTo( + 0, ElementCount::getFixed(4))); + getActionDefinitionsBuilder(G_BITCAST) + .moreElementsToNextPow2(0) + .moreElementsToNextPow2(1) + .fewerElementsIf(vectorElementCountIsGreaterThan(0, 4), + LegalizeMutations::changeElementCountTo( + 0, ElementCount::getFixed(4))) + .lowerIf(vectorElementCountIsGreaterThan(1, 4)) + .custom(); + getActionDefinitionsBuilder(G_CONCAT_VECTORS) + .legalFor(allShaderVectors) + .lower(); + } else + getActionDefinitionsBuilder( + {G_SHUFFLE_VECTOR, G_BUILD_VECTOR, G_SPLAT_VECTOR}) + .alwaysLegal(); // Vector Reduction Operations getActionDefinitionsBuilder( @@ -287,6 +327,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { // Pointer-handling. getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0}); + getActionDefinitionsBuilder(G_GLOBAL_VALUE).legalFor(allPtrs); + // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32. getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32}); @@ -374,6 +416,11 @@ bool SPIRVLegalizerInfo::legalizeCustom( default: // TODO: implement legalization for other opcodes. return true; + case TargetOpcode::G_BITCAST: + return legalizeBitcast(Helper, MI); + case TargetOpcode::G_INTRINSIC: + return legalizeIntrinsic(Helper, MI); + case TargetOpcode::G_IS_FPCLASS: return legalizeIsFPClass(Helper, MI, LocObserver); case TargetOpcode::G_ICMP: { @@ -400,6 +447,41 @@ bool SPIRVLegalizerInfo::legalizeCustom( } } +bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper, + MachineInstr &MI) const { + MachineIRBuilder &MIRBuilder = Helper.MIRBuilder; + MachineRegisterInfo &MRI = *MIRBuilder.getMRI(); + + auto IntrinsicID = cast(MI).getIntrinsicID(); + if (IntrinsicID == Intrinsic::spv_bitcast) { + Register DstReg = MI.getOperand(0).getReg(); + Register SrcReg = MI.getOperand(2).getReg(); + LLT DstTy = MRI.getType(DstReg); + LLT SrcTy = MRI.getType(SrcReg); + + bool isLongVector = (DstTy.isVector() && DstTy.getNumElements() > 4) || + (SrcTy.isVector() && SrcTy.getNumElements() > 4); + + if (isLongVector) { + MIRBuilder.buildBitcast(DstReg, SrcReg); + MI.eraseFromParent(); + } + return true; + } + return true; +} + +bool SPIRVLegalizerInfo::legalizeBitcast(LegalizerHelper &Helper, + MachineInstr &MI) const { + MachineIRBuilder &MIRBuilder = Helper.MIRBuilder; + Register DstReg = MI.getOperand(0).getReg(); + Register SrcReg = MI.getOperand(1).getReg(); + SmallVector DstRegs = {DstReg}; + MIRBuilder.buildIntrinsic(Intrinsic::spv_bitcast, DstRegs).addUse(SrcReg); + MI.eraseFromParent(); + return true; +} + // Note this code was copied from LegalizerHelper::lowerISFPCLASS and adjusted // to ensure that all instructions created during the lowering have SPIR-V types // assigned to them. 
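The shader path above caps legal vector types at four elements: vectorElementCountIsGreaterThan(TypeIdx, 4) sends anything longer to lowerIf or fewerElementsIf, with changeElementCountTo splitting it down to 4-element pieces. Below is a minimal standalone sketch of that decision for two sample types; ToyLLT, changeElementCountTo4 and the sample values are illustrative stand-ins, not LLVM API.

// Standalone sketch (not LLVM API): models how vectorElementCountIsGreaterThan
// combined with a changeElementCountTo-style mutation treats shader vectors.
#include <cstdio>

struct ToyLLT {           // stand-in for llvm::LLT
  bool IsVector = false;
  unsigned NumElts = 1;   // element count for vectors
  unsigned EltBits = 32;  // element width in bits
};

// Mirrors the predicate added in LegalityPredicates.cpp: true only for fixed
// vectors whose element count exceeds Size.
static bool vectorElementCountIsGreaterThan(const ToyLLT &Ty, unsigned Size) {
  return Ty.IsVector && Ty.NumElts > Size;
}

// Mirrors LegalizeMutations::changeElementCountTo(_, ElementCount::getFixed(4)).
static ToyLLT changeElementCountTo4(ToyLLT Ty) {
  Ty.NumElts = 4;
  return Ty;
}

int main() {
  ToyLLT V3S32{true, 3, 32}; // <3 x s32>: within the 2..4 element shader set
  ToyLLT V8S16{true, 8, 16}; // <8 x s16>: too long, gets narrowed to 4 lanes
  for (ToyLLT Ty : {V3S32, V8S16}) {
    if (vectorElementCountIsGreaterThan(Ty, 4)) {
      ToyLLT Narrow = changeElementCountTo4(Ty);
      std::printf("<%u x s%u> -> fewerElements -> <%u x s%u>\n", Ty.NumElts,
                  Ty.EltBits, Narrow.NumElts, Narrow.EltBits);
    } else {
      std::printf("<%u x s%u> -> legal\n", Ty.NumElts, Ty.EltBits);
    }
  }
}

Running the sketch prints which sample types stay legal and which get narrowed, mirroring what the ruleset does for G_BUILD_VECTOR and G_SHUFFLE_VECTOR types on the GLSL path.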
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h index eeefa4239c778..86e7e711caa60 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h @@ -29,11 +29,15 @@ class SPIRVLegalizerInfo : public LegalizerInfo { public: bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI, LostDebugLocObserver &LocObserver) const override; + bool legalizeIntrinsic(LegalizerHelper &Helper, + MachineInstr &MI) const override; + SPIRVLegalizerInfo(const SPIRVSubtarget &ST); private: bool legalizeIsFPClass(LegalizerHelper &Helper, MachineInstr &MI, LostDebugLocObserver &LocObserver) const; + bool legalizeBitcast(LegalizerHelper &Helper, MachineInstr &MI) const; }; } // namespace llvm #endif // LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp index d17528dd882bf..dc01c37594be0 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp @@ -17,7 +17,8 @@ #include "SPIRV.h" #include "SPIRVSubtarget.h" #include "SPIRVUtils.h" -#include "llvm/IR/Attributes.h" +#include "llvm/IR/IntrinsicsSPIRV.h" +#include "llvm/Support/Debug.h" #include #define DEBUG_TYPE "spirv-postlegalizer" @@ -45,6 +46,11 @@ extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB, static bool mayBeInserted(unsigned Opcode) { switch (Opcode) { + case TargetOpcode::G_CONSTANT: + case TargetOpcode::G_UNMERGE_VALUES: + case TargetOpcode::G_EXTRACT_VECTOR_ELT: + case TargetOpcode::G_INTRINSIC: + case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS: case TargetOpcode::G_SMAX: case TargetOpcode::G_UMAX: case TargetOpcode::G_SMIN: @@ -53,70 +59,234 @@ static bool mayBeInserted(unsigned Opcode) { case TargetOpcode::G_FMINIMUM: case TargetOpcode::G_FMAXNUM: case TargetOpcode::G_FMAXIMUM: + case TargetOpcode::G_IMPLICIT_DEF: + case TargetOpcode::G_BUILD_VECTOR: return true; default: return isTypeFoldingSupported(Opcode); } } -static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR, - MachineIRBuilder MIB) { +static bool processInstr(MachineInstr *I, MachineFunction &MF, + SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB) { MachineRegisterInfo &MRI = MF.getRegInfo(); + const unsigned Opcode = I->getOpcode(); + Register ResVReg = I->getOperand(0).getReg(); + SPIRVType *ResType = nullptr; + bool Handled = false; - for (MachineBasicBlock &MBB : MF) { - for (MachineInstr &I : MBB) { - const unsigned Opcode = I.getOpcode(); - if (Opcode == TargetOpcode::G_UNMERGE_VALUES) { - unsigned ArgI = I.getNumOperands() - 1; - Register SrcReg = I.getOperand(ArgI).isReg() - ? I.getOperand(ArgI).getReg() - : Register(0); - SPIRVType *DefType = - SrcReg.isValid() ? 
GR->getSPIRVTypeForVReg(SrcReg) : nullptr; - if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector) - report_fatal_error( - "cannot select G_UNMERGE_VALUES with a non-vector argument"); + switch (Opcode) { + case TargetOpcode::G_CONSTANT: { + const LLT &Ty = MRI.getType(ResVReg); + unsigned BitWidth = Ty.getScalarSizeInBits(); + ResType = GR->getOrCreateSPIRVIntegerType(BitWidth, MIB); + Handled = true; + break; + } + case TargetOpcode::G_UNMERGE_VALUES: { + Register SrcReg = I->getOperand(I->getNumOperands() - 1).getReg(); + if (SPIRVType *DefType = GR->getSPIRVTypeForVReg(SrcReg)) { + if (DefType->getOpcode() == SPIRV::OpTypeVector) { SPIRVType *ScalarType = GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg()); - for (unsigned i = 0; i < I.getNumDefs(); ++i) { - Register ResVReg = I.getOperand(i).getReg(); - SPIRVType *ResType = GR->getSPIRVTypeForVReg(ResVReg); - if (!ResType) { - // There was no "assign type" actions, let's fix this now - ResType = ScalarType; - setRegClassType(ResVReg, ResType, GR, &MRI, *GR->CurMF, true); + for (unsigned i = 0; i < I->getNumDefs(); ++i) { + Register DefReg = I->getOperand(i).getReg(); + if (!GR->getSPIRVTypeForVReg(DefReg)) { + LLT DefLLT = MRI.getType(DefReg); + SPIRVType *ResType; + if (DefLLT.isVector()) { + const SPIRVInstrInfo *TII = + MF.getSubtarget().getInstrInfo(); + ResType = GR->getOrCreateSPIRVVectorType( + ScalarType, DefLLT.getNumElements(), *I, *TII); + } else { + ResType = ScalarType; + } + setRegClassType(DefReg, ResType, GR, &MRI, MF); } } - } else if (mayBeInserted(Opcode) && I.getNumDefs() == 1 && - I.getNumOperands() > 1 && I.getOperand(1).isReg()) { - // Legalizer may have added a new instructions and introduced new - // registers, we must decorate them as if they were introduced in a - // non-automatic way - Register ResVReg = I.getOperand(0).getReg(); - // Check if the register defined by the instruction is newly generated - // or already processed - // Check if we have type defined for operands of the new instruction - bool IsKnownReg = MRI.getRegClassOrNull(ResVReg); - SPIRVType *ResVType = GR->getSPIRVTypeForVReg( - IsKnownReg ? ResVReg : I.getOperand(1).getReg()); - if (!ResVType) - continue; - // Set type & class - if (!IsKnownReg) - setRegClassType(ResVReg, ResVType, GR, &MRI, *GR->CurMF, true); - // If this is a simple operation that is to be reduced by TableGen - // definition we must apply some of pre-legalizer rules here - if (isTypeFoldingSupported(Opcode)) { - processInstr(I, MIB, MRI, GR, GR->getSPIRVTypeForVReg(ResVReg)); - if (IsKnownReg && MRI.hasOneUse(ResVReg)) { - MachineInstr &UseMI = *MRI.use_instr_begin(ResVReg); - if (UseMI.getOpcode() == SPIRV::ASSIGN_TYPE) - continue; + Handled = true; + } + } + break; + } + case TargetOpcode::G_EXTRACT_VECTOR_ELT: { + LLVM_DEBUG(dbgs() << "Processing G_EXTRACT_VECTOR_ELT: " << *I); + Register VecReg = I->getOperand(1).getReg(); + if (SPIRVType *VecType = GR->getSPIRVTypeForVReg(VecReg)) { + LLVM_DEBUG(dbgs() << " Found vector type: " << *VecType << "\n"); + if (VecType->getOpcode() != SPIRV::OpTypeVector) { + VecType->dump(); + } + assert(VecType->getOpcode() == SPIRV::OpTypeVector); + ResType = GR->getScalarOrVectorComponentType(VecType); + Handled = true; + } else { + LLVM_DEBUG(dbgs() << " Vector operand " << VecReg + << " has no type. Looking at uses of " << ResVReg + << ".\n"); + // If not handled yet, then check if it is used in a G_BUILD_VECTOR. + // If so get the type from there. 
+ for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) { + LLVM_DEBUG(dbgs() << " Use: " << Use); + if (Use.getOpcode() == TargetOpcode::G_BUILD_VECTOR) { + LLVM_DEBUG(dbgs() << " Use is G_BUILD_VECTOR.\n"); + Register BuildVecResReg = Use.getOperand(0).getReg(); + if (SPIRVType *BuildVecType = + GR->getSPIRVTypeForVReg(BuildVecResReg)) { + LLVM_DEBUG(dbgs() << " Found G_BUILD_VECTOR result type: " + << *BuildVecType << "\n"); + ResType = GR->getScalarOrVectorComponentType(BuildVecType); + Handled = true; + break; + } else { + LLVM_DEBUG(dbgs() << " G_BUILD_VECTOR result " << BuildVecResReg + << " has no type yet.\n"); } - insertAssignInstr(ResVReg, nullptr, ResVType, GR, MIB, MRI); } } } + if (!Handled) { + LLVM_DEBUG( + dbgs() << " Could not determine type for G_EXTRACT_VECTOR_ELT.\n"); + } + break; + } + case TargetOpcode::G_BUILD_VECTOR: { + // First check if any of the operands have a type. + for (unsigned i = 1; i < I->getNumOperands(); ++i) { + if (SPIRVType *OpType = + GR->getSPIRVTypeForVReg(I->getOperand(i).getReg())) { + const LLT &ResLLT = MRI.getType(ResVReg); + ResType = GR->getOrCreateSPIRVVectorType( + OpType, ResLLT.getNumElements(), MIB, false); + Handled = true; + break; + } + } + if (Handled) { + break; + } + // If that did not work, then check the uses. + for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) { + if (Use.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) { + Register ExtractResReg = Use.getOperand(0).getReg(); + if (SPIRVType *ScalarType = GR->getSPIRVTypeForVReg(ExtractResReg)) { + const LLT &ResLLT = MRI.getType(ResVReg); + ResType = GR->getOrCreateSPIRVVectorType( + ScalarType, ResLLT.getNumElements(), MIB, false); + Handled = true; + break; + } + } + } + break; + } + case TargetOpcode::G_IMPLICIT_DEF: { + for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) { + const unsigned UseOpc = Use.getOpcode(); + assert(UseOpc == TargetOpcode::G_BUILD_VECTOR || + UseOpc == TargetOpcode::G_SHUFFLE_VECTOR); + // It's possible that the use instruction has not been processed yet. + // We should look at the operands of the use to determine the type. 
+      for (unsigned i = 1; i < Use.getNumOperands(); ++i) {
+        if (auto *Type = GR->getSPIRVTypeForVReg(Use.getOperand(i).getReg())) {
+          ResType = Type;
+          Handled = true;
+          break;
+        }
+      }
+      if (Handled) {
+        break;
+      }
+    }
+    break;
+  }
+  case TargetOpcode::G_INTRINSIC:
+  case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS: {
+    if (!isSpvIntrinsic(*I, Intrinsic::spv_bitcast))
+      break;
+
+    for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+      const unsigned UseOpc = Use.getOpcode();
+      assert(UseOpc == TargetOpcode::G_EXTRACT_VECTOR_ELT ||
+             UseOpc == TargetOpcode::G_SHUFFLE_VECTOR);
+      Register UseResultReg = Use.getOperand(0).getReg();
+      if (SPIRVType *UseResType = GR->getSPIRVTypeForVReg(UseResultReg)) {
+        SPIRVType *ScalarType = GR->getScalarOrVectorComponentType(UseResType);
+        const LLT &BitcastLLT = MRI.getType(ResVReg);
+        if (BitcastLLT.isVector()) {
+          ResType = GR->getOrCreateSPIRVVectorType(
+              ScalarType, BitcastLLT.getNumElements(), MIB, false);
+        } else {
+          ResType = ScalarType;
+        }
+        Handled = true;
+        break;
+      }
+    }
+    break;
+  }
+  default:
+    if (I->getNumDefs() == 1 && I->getNumOperands() > 1 &&
+        I->getOperand(1).isReg()) {
+      if (SPIRVType *OpType =
+              GR->getSPIRVTypeForVReg(I->getOperand(1).getReg())) {
+        ResType = OpType;
+        Handled = true;
+      }
+    }
+    break;
+  }
+
+  if (Handled && ResType) {
+    LLVM_DEBUG(dbgs() << "Assigned type to " << *I << ": " << *ResType << "\n");
+    GR->assignSPIRVTypeToVReg(ResType, ResVReg, MF);
+  }
+  return Handled;
+}
+
+static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
+                             MachineIRBuilder MIB) {
+  SmallVector<MachineInstr *> Worklist;
+
+  for (MachineBasicBlock &MBB : MF) {
+    for (MachineInstr &I : MBB) {
+      if (I.getNumDefs() > 0 &&
+          !GR->getSPIRVTypeForVReg(I.getOperand(0).getReg()) &&
+          mayBeInserted(I.getOpcode())) {
+        Worklist.push_back(&I);
+      }
+    }
+  }
+
+  if (Worklist.empty()) {
+    return;
+  }
+
+  LLVM_DEBUG(dbgs() << "Initial worklist:\n";
+             for (auto *I : Worklist) { I->dump(); });
+
+  bool Changed = true;
+  while (Changed) {
+    Changed = false;
+    SmallVector<MachineInstr *> NextWorklist;
+
+    for (MachineInstr *I : Worklist) {
+      if (processInstr(I, MF, GR, MIB)) {
+        Changed = true;
+      } else {
+        NextWorklist.push_back(I);
+      }
+    }
+    Worklist = NextWorklist;
+    LLVM_DEBUG(dbgs() << "Worklist size: " << Worklist.size() << "\n");
+  }
+
+  if (!Worklist.empty()) {
+    LLVM_DEBUG(dbgs() << "Remaining worklist:\n";
+               for (auto *I : Worklist) { I->dump(); });
+    assert(Worklist.empty() && "Worklist is not empty");
   }
 }
@@ -159,6 +329,28 @@ bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) {
 
   processNewInstrs(MF, GR, MIB);
 
+  // TODO: Move this into its own function.
+  SmallVector<MachineInstr *> ExtractInstrs;
+  for (MachineBasicBlock &MBB : MF) {
+    for (MachineInstr &MI : MBB) {
+      if (MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) {
+        ExtractInstrs.push_back(&MI);
+      }
+    }
+  }
+
+  for (MachineInstr *MI : ExtractInstrs) {
+    MachineIRBuilder MIB(*MI);
+    Register Dst = MI->getOperand(0).getReg();
+    Register Vec = MI->getOperand(1).getReg();
+    Register Idx = MI->getOperand(2).getReg();
+
+    auto Intr = MIB.buildIntrinsic(Intrinsic::spv_extractelt, Dst, true, false);
+    Intr.addUse(Vec);
+    Intr.addUse(Idx);
+
+    MI->eraseFromParent();
+  }
   return true;
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index db6f2d61e8f29..bb8bf443778f3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -192,6 +192,10 @@ static void buildOpBitcast(SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
         .addUse(OpReg);
 }
 
+// TODO: See if the comment below needs to be more precise. This is a problem
+// for more than just pointers: a bitcast between any two types that map to
+// the same LLT will cause a problem, e.g. a bitcast from a float to an int.
+
 // We do instruction selections early instead of calling MIB.buildBitcast()
 // generating the general op code G_BITCAST. When MachineVerifier validates
 // G_BITCAST we see a check of a kind: if Source Type is equal to Destination
@@ -237,7 +241,6 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
   SmallVector<MachineInstr *> ToErase;
   for (MachineBasicBlock &MBB : MF) {
     for (MachineInstr &MI : MBB) {
-      if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast) &&
-          !isSpvIntrinsic(MI, Intrinsic::spv_ptrcast))
+      if (!isSpvIntrinsic(MI, Intrinsic::spv_ptrcast))
         continue;
       assert(MI.getOperand(2).isReg());
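For reference, the index bookkeeping in the updated selectUnmergeValues earlier in this patch can be modeled in isolation: each destination consumes the next run of source lanes, emitted as an OpVectorShuffle mask for a vector destination or a single OpCompositeExtract index for a scalar one. The sketch below is a standalone model of that loop, not backend code; DestElementCounts is a made-up stand-in for the G_UNMERGE_VALUES result types.

// Standalone model of the CurrentIndex accumulation in selectUnmergeValues.
#include <cstdio>
#include <vector>

int main() {
  // Destinations produced by unmerging an 8-element vector into two halves.
  std::vector<unsigned> DestElementCounts = {4, 4};
  unsigned CurrentIndex = 0;
  for (unsigned NumElements : DestElementCounts) {
    if (NumElements > 1) {
      // Vector destination: shuffle the next NumElements lanes out of the
      // source (the second shuffle operand is an undef of the source type).
      std::printf("OpVectorShuffle mask:");
      for (unsigned J = 0; J < NumElements; ++J)
        std::printf(" %u", CurrentIndex + J);
      std::printf("\n");
    } else {
      // Scalar destination: extract a single lane.
      std::printf("OpCompositeExtract index: %u\n", CurrentIndex);
    }
    CurrentIndex += NumElements;
  }
}

With two 4-element destinations this prints the masks 0 1 2 3 and 4 5 6 7, which is the split the new selection code emits for an 8-lane unmerge.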