-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[SPIRV] Add support for arbitrary-precision integers larger than 64 bits in SPIR-V backend #161270
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[SPIRV] Add support for arbitrary-precision integers larger than 64 bits in SPIR-V backend #161270
Conversation
|
@llvm/pr-subscribers-backend-spir-v Author: None (YixingZhang007) Changes: spirv-backend Full diff: https://github.com/llvm/llvm-project/pull/161270.diff 6 Files Affected:
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
index 776208bd3e693..dff9f699ebd6f 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
@@ -50,18 +50,24 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI,
unsigned IsBitwidth16 = MI->getFlags() & SPIRV::INST_PRINTER_WIDTH16;
const unsigned NumVarOps = MI->getNumOperands() - StartIndex;
- assert((NumVarOps == 1 || NumVarOps == 2) &&
+ // we support integer up to 1024 bits
+ assert((NumVarOps <= 1024) &&
"Unsupported number of bits for literal variable");
O << ' ';
- uint64_t Imm = MI->getOperand(StartIndex).getImm();
-
- // Handle 64 bit literals.
- if (NumVarOps == 2) {
- Imm |= (MI->getOperand(StartIndex + 1).getImm() << 32);
+ // Handle arbitrary number of 32-bit words for the literal value.
+ if (MI->getOpcode() == SPIRV::OpConstantI){
+ APInt Val(NumVarOps * 32, 0);
+ for (unsigned i = 0; i < NumVarOps; ++i) {
+ Val |= (APInt(NumVarOps * 32, MI->getOperand(StartIndex + i).getImm()) << (i * 32));
+ }
+ O << Val;
+ return;
}
+ uint64_t Imm = MI->getOperand(StartIndex).getImm();
+
// Format and print float values.
if (MI->getOpcode() == SPIRV::OpConstantF && IsBitwidth16 == 0) {
APFloat FP = NumVarOps == 1 ? APFloat(APInt(32, Imm).bitsToFloat())
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 115766ce886c7..05b3371e97cdc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -149,7 +149,7 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
}
unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
- if (Width > 64)
+ if (Width > 1024)
report_fatal_error("Unsupported integer width!");
const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
if (ST.canUseExtension(
@@ -343,7 +343,7 @@ Register SPIRVGlobalRegistry::createConstFP(const ConstantFP *CF,
return Res;
}
-Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
+Register SPIRVGlobalRegistry::getOrCreateConstInt(APInt Val, MachineInstr &I,
SPIRVType *SpvType,
const SPIRVInstrInfo &TII,
bool ZeroAsNull) {
@@ -353,10 +353,11 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
if (MI && (MI->getOpcode() == SPIRV::OpConstantNull ||
MI->getOpcode() == SPIRV::OpConstantI))
return MI->getOperand(0).getReg();
- return createConstInt(CI, I, SpvType, TII, ZeroAsNull);
+ return createConstInt(CI, Val, I, SpvType, TII, ZeroAsNull);
}
-Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI,
+Register SPIRVGlobalRegistry::createConstInt(const Constant *CI,
+ APInt Val,
MachineInstr &I,
SPIRVType *SpvType,
const SPIRVInstrInfo &TII,
@@ -374,15 +375,15 @@ Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI,
MachineInstrBuilder MIB;
if (BitWidth == 1) {
MIB = MIRBuilder
- .buildInstr(CI->isZero() ? SPIRV::OpConstantFalse
+ .buildInstr(Val.isZero() ? SPIRV::OpConstantFalse
: SPIRV::OpConstantTrue)
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
- } else if (!CI->isZero() || !ZeroAsNull) {
+ } else if (!Val.isZero() || !ZeroAsNull) {
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
- addNumImm(APInt(BitWidth, CI->getZExtValue()), MIB);
+ addNumImm(Val, MIB);
} else {
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
.addDef(Res)
@@ -491,7 +492,7 @@ Register SPIRVGlobalRegistry::getOrCreateBaseRegister(
}
assert(Type->getOpcode() == SPIRV::OpTypeInt);
SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
- return getOrCreateConstInt(Val->getUniqueInteger().getZExtValue(), I,
+ return getOrCreateConstInt(APInt(BitWidth, Val->getUniqueInteger().getZExtValue()), I,
SpvBaseType, TII, ZeroAsNull);
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index a648defa0a888..ee217f81fb416 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -515,10 +515,10 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType, bool EmitIR,
bool ZeroAsNull = true);
- Register getOrCreateConstInt(uint64_t Val, MachineInstr &I,
+ Register getOrCreateConstInt(APInt Val, MachineInstr &I,
SPIRVType *SpvType, const SPIRVInstrInfo &TII,
bool ZeroAsNull = true);
- Register createConstInt(const ConstantInt *CI, MachineInstr &I,
+ Register createConstInt(const Constant *CI, APInt Val, MachineInstr &I,
SPIRVType *SpvType, const SPIRVInstrInfo &TII,
bool ZeroAsNull);
Register getOrCreateConstFP(APFloat Val, MachineInstr &I, SPIRVType *SpvType,
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 1aadd9df189a8..3e5566945ec0b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2252,8 +2252,8 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
.addDef(AElt)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(X)
- .addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII, ZeroAsNull))
- .addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull))
+ .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType, TII, ZeroAsNull))
+ .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull))
.constrainAllUses(TII, TRI, RBI);
// B[i]
@@ -2263,8 +2263,8 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
.addDef(BElt)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(Y)
- .addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII, ZeroAsNull))
- .addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull))
+ .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType, TII, ZeroAsNull))
+ .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull))
.constrainAllUses(TII, TRI, RBI);
// A[i] * B[i]
@@ -2283,8 +2283,8 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
.addDef(MaskMul)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(Mul)
- .addUse(GR.getOrCreateConstInt(0, I, EltType, TII, ZeroAsNull))
- .addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull))
+ .addUse(GR.getOrCreateConstInt(APInt(8, 0), I, EltType, TII, ZeroAsNull))
+ .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull))
.constrainAllUses(TII, TRI, RBI);
// Acc = Acc + A[i] * B[i]
@@ -2381,7 +2381,7 @@ bool SPIRVInstructionSelector::selectWaveOpInst(Register ResVReg,
auto BMI = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
- .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I,
+ .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I,
IntTy, TII, !STI.isShader()));
for (unsigned J = 2; J < I.getNumOperands(); J++) {
@@ -2405,7 +2405,7 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
- .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy,
+ .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy,
TII, !STI.isShader()))
.addImm(SPIRV::GroupOperation::Reduce)
.addUse(BallotReg)
@@ -2436,7 +2436,7 @@ bool SPIRVInstructionSelector::selectWaveReduceMax(Register ResVReg,
return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
- .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII,
+ .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy, TII,
!STI.isShader()))
.addImm(SPIRV::GroupOperation::Reduce)
.addUse(I.getOperand(2).getReg())
@@ -2463,7 +2463,7 @@ bool SPIRVInstructionSelector::selectWaveReduceSum(Register ResVReg,
return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
- .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII,
+ .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy, TII,
!STI.isShader()))
.addImm(SPIRV::GroupOperation::Reduce)
.addUse(I.getOperand(2).getReg());
@@ -2689,7 +2689,7 @@ Register SPIRVInstructionSelector::buildZerosVal(const SPIRVType *ResType,
bool ZeroAsNull = !STI.isShader();
if (ResType->getOpcode() == SPIRV::OpTypeVector)
return GR.getOrCreateConstVector(0UL, I, ResType, TII, ZeroAsNull);
- return GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull);
+ return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0), I, ResType, TII, ZeroAsNull);
}
Register SPIRVInstructionSelector::buildZerosValF(const SPIRVType *ResType,
@@ -2720,7 +2720,7 @@ Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes,
AllOnes ? APInt::getAllOnes(BitWidth) : APInt::getOneBitSet(BitWidth, 0);
if (ResType->getOpcode() == SPIRV::OpTypeVector)
return GR.getOrCreateConstVector(One.getZExtValue(), I, ResType, TII);
- return GR.getOrCreateConstInt(One.getZExtValue(), I, ResType, TII);
+ return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), One.getZExtValue()), I, ResType, TII);
}
bool SPIRVInstructionSelector::selectSelect(Register ResVReg,
@@ -2939,8 +2939,7 @@ bool SPIRVInstructionSelector::selectConst(Register ResVReg,
Reg = GR.getOrCreateConstFP(I.getOperand(1).getFPImm()->getValue(), I,
ResType, TII, !STI.isShader());
} else {
- Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getZExtValue(), I,
- ResType, TII, !STI.isShader());
+ Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getValue(), I, ResType, TII, !STI.isShader());
}
return Reg == ResVReg ? true : BuildCOPY(ResVReg, Reg, I);
}
@@ -3765,7 +3764,7 @@ bool SPIRVInstructionSelector::selectFirstBitSet64Overflow(
bool ZeroAsNull = !STI.isShader();
Register FinalElemReg = MRI->createVirtualRegister(GR.getRegClass(I64Type));
Register ConstIntLastIdx = GR.getOrCreateConstInt(
- ComponentCount - 1, I, BaseType, TII, ZeroAsNull);
+ APInt(GR.getScalarOrVectorBitWidth(BaseType), ComponentCount - 1), I, BaseType, TII, ZeroAsNull);
if (!selectOpWithSrcs(FinalElemReg, I64Type, I, {SrcReg, ConstIntLastIdx},
SPIRV::OpVectorExtractDynamic))
@@ -3794,9 +3793,9 @@ bool SPIRVInstructionSelector::selectFirstBitSet64(
SPIRVType *BaseType = GR.retrieveScalarOrVectorIntType(ResType);
bool ZeroAsNull = !STI.isShader();
Register ConstIntZero =
- GR.getOrCreateConstInt(0, I, BaseType, TII, ZeroAsNull);
+ GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 0), I, BaseType, TII, ZeroAsNull);
Register ConstIntOne =
- GR.getOrCreateConstInt(1, I, BaseType, TII, ZeroAsNull);
+ GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 1), I, BaseType, TII, ZeroAsNull);
// SPIRV doesn't support vectors with more than 4 components. Since the
// algoritm below converts i64 -> i32x2 and i64x4 -> i32x8 it can only
@@ -3881,9 +3880,9 @@ bool SPIRVInstructionSelector::selectFirstBitSet64(
if (IsScalarRes) {
NegOneReg =
- GR.getOrCreateConstInt((unsigned)-1, I, ResType, TII, ZeroAsNull);
- Reg0 = GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull);
- Reg32 = GR.getOrCreateConstInt(32, I, ResType, TII, ZeroAsNull);
+ GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), (unsigned)-1), I, ResType, TII, ZeroAsNull);
+ Reg0 = GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0), I, ResType, TII, ZeroAsNull);
+ Reg32 = GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 32), I, ResType, TII, ZeroAsNull);
SelectOp = SPIRV::OpSelectSISCond;
AddOp = SPIRV::OpIAddS;
} else {
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 820e56b362edc..e409234a83568 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -100,11 +100,12 @@ void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB) {
if (Bitwidth == 16)
MIB.getInstr()->setAsmPrinterFlag(SPIRV::ASM_PRINTER_WIDTH16);
return;
- } else if (Bitwidth <= 64) {
- uint64_t FullImm = Imm.getZExtValue();
- uint32_t LowBits = FullImm & 0xffffffff;
- uint32_t HighBits = (FullImm >> 32) & 0xffffffff;
- MIB.addImm(LowBits).addImm(HighBits);
+ } else if (Bitwidth <= 1024) {
+ unsigned NumWords = (Bitwidth + 31) / 32;
+ for (unsigned i = 0; i < NumWords; ++i) {
+ uint32_t Word = Imm.extractBits(32, i * 32).getZExtValue();
+ MIB.addImm(Word);
+ }
return;
}
report_fatal_error("Unsupported constant bitwidth");
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
index 41d4b58ed1157..17ba9b044842c 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
@@ -8,6 +8,10 @@ define i13 @getConstantI13() {
ret i13 42
}
+define i96 @getConstantI96() {
+ ret i96 18446744073709551620
+}
+
;; Capabilities:
; CHECK-DAG: OpExtension "SPV_INTEL_arbitrary_precision_integers"
; CHECK-DAG: OpCapability ArbitraryPrecisionIntegersINTEL
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
|
This PR has been completed and cleaned up, but I will keep it as a draft for now. The issue was originally discovered when |
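This PR extends SPIR-V code generation in LLVM to support arbitrary-precision integers of up to 1024 bits, enabled by the `SPV_INTEL_arbitrary_precision_integers` extension. More specifically, the following changes are made: the `getOrCreateConstInt` and `createConstInt` functions in the `SPIRVGlobalRegistry` class now accept an `APInt Val` parameter instead of `uint64_t Val`, and all relevant call sites in `SPIRVInstructionSelector.cpp` are updated accordingly. In addition, `adjustOpTypeIntWidth` now rejects only widths above 1024 bits (previously 64), `addNumImm` splits constants wider than 32 bits into as many 32-bit literal words as needed (previously at most two), and `SPIRVInstPrinter` prints `OpConstantI` literals spanning any number of words.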