Skip to content

Commit 4c1d90c

Browse files
Add support for arbitrary integer with bitwidth larger than 64 bits in
spirv-backend
1 parent 9df1099 commit 4c1d90c

File tree

6 files changed

+52
-41
lines changed

6 files changed

+52
-41
lines changed

llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,18 +50,24 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI,
5050
unsigned IsBitwidth16 = MI->getFlags() & SPIRV::INST_PRINTER_WIDTH16;
5151
const unsigned NumVarOps = MI->getNumOperands() - StartIndex;
5252

53-
assert((NumVarOps == 1 || NumVarOps == 2) &&
53+
// we support integer up to 1024 bits
54+
assert((NumVarOps <= 1024) &&
5455
"Unsupported number of bits for literal variable");
5556

5657
O << ' ';
5758

58-
uint64_t Imm = MI->getOperand(StartIndex).getImm();
59-
60-
// Handle 64 bit literals.
61-
if (NumVarOps == 2) {
62-
Imm |= (MI->getOperand(StartIndex + 1).getImm() << 32);
59+
// Handle arbitrary number of 32-bit words for the literal value.
60+
if (MI->getOpcode() == SPIRV::OpConstantI){
61+
APInt Val(NumVarOps * 32, 0);
62+
for (unsigned i = 0; i < NumVarOps; ++i) {
63+
Val |= (APInt(NumVarOps * 32, MI->getOperand(StartIndex + i).getImm()) << (i * 32));
64+
}
65+
O << Val;
66+
return;
6367
}
6468

69+
uint64_t Imm = MI->getOperand(StartIndex).getImm();
70+
6571
// Format and print float values.
6672
if (MI->getOpcode() == SPIRV::OpConstantF && IsBitwidth16 == 0) {
6773
APFloat FP = NumVarOps == 1 ? APFloat(APInt(32, Imm).bitsToFloat())

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
149149
}
150150

151151
unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
152-
if (Width > 64)
152+
if (Width > 1024)
153153
report_fatal_error("Unsupported integer width!");
154154
const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
155155
if (ST.canUseExtension(
@@ -343,7 +343,7 @@ Register SPIRVGlobalRegistry::createConstFP(const ConstantFP *CF,
343343
return Res;
344344
}
345345

346-
Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
346+
Register SPIRVGlobalRegistry::getOrCreateConstInt(APInt Val, MachineInstr &I,
347347
SPIRVType *SpvType,
348348
const SPIRVInstrInfo &TII,
349349
bool ZeroAsNull) {
@@ -353,10 +353,11 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
353353
if (MI && (MI->getOpcode() == SPIRV::OpConstantNull ||
354354
MI->getOpcode() == SPIRV::OpConstantI))
355355
return MI->getOperand(0).getReg();
356-
return createConstInt(CI, I, SpvType, TII, ZeroAsNull);
356+
return createConstInt(CI, Val, I, SpvType, TII, ZeroAsNull);
357357
}
358358

359-
Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI,
359+
Register SPIRVGlobalRegistry::createConstInt(const Constant *CI,
360+
APInt Val,
360361
MachineInstr &I,
361362
SPIRVType *SpvType,
362363
const SPIRVInstrInfo &TII,
@@ -374,15 +375,15 @@ Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI,
374375
MachineInstrBuilder MIB;
375376
if (BitWidth == 1) {
376377
MIB = MIRBuilder
377-
.buildInstr(CI->isZero() ? SPIRV::OpConstantFalse
378+
.buildInstr(Val.isZero() ? SPIRV::OpConstantFalse
378379
: SPIRV::OpConstantTrue)
379380
.addDef(Res)
380381
.addUse(getSPIRVTypeID(SpvType));
381-
} else if (!CI->isZero() || !ZeroAsNull) {
382+
} else if (!Val.isZero() || !ZeroAsNull) {
382383
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
383384
.addDef(Res)
384385
.addUse(getSPIRVTypeID(SpvType));
385-
addNumImm(APInt(BitWidth, CI->getZExtValue()), MIB);
386+
addNumImm(Val, MIB);
386387
} else {
387388
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
388389
.addDef(Res)
@@ -491,7 +492,7 @@ Register SPIRVGlobalRegistry::getOrCreateBaseRegister(
491492
}
492493
assert(Type->getOpcode() == SPIRV::OpTypeInt);
493494
SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
494-
return getOrCreateConstInt(Val->getUniqueInteger().getZExtValue(), I,
495+
return getOrCreateConstInt(APInt(BitWidth, Val->getUniqueInteger().getZExtValue()), I,
495496
SpvBaseType, TII, ZeroAsNull);
496497
}
497498

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -515,10 +515,10 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
515515
Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
516516
SPIRVType *SpvType, bool EmitIR,
517517
bool ZeroAsNull = true);
518-
Register getOrCreateConstInt(uint64_t Val, MachineInstr &I,
518+
Register getOrCreateConstInt(APInt Val, MachineInstr &I,
519519
SPIRVType *SpvType, const SPIRVInstrInfo &TII,
520520
bool ZeroAsNull = true);
521-
Register createConstInt(const ConstantInt *CI, MachineInstr &I,
521+
Register createConstInt(const Constant *CI, APInt Val, MachineInstr &I,
522522
SPIRVType *SpvType, const SPIRVInstrInfo &TII,
523523
bool ZeroAsNull);
524524
Register getOrCreateConstFP(APFloat Val, MachineInstr &I, SPIRVType *SpvType,

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2252,8 +2252,8 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
22522252
.addDef(AElt)
22532253
.addUse(GR.getSPIRVTypeID(ResType))
22542254
.addUse(X)
2255-
.addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII, ZeroAsNull))
2256-
.addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull))
2255+
.addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType, TII, ZeroAsNull))
2256+
.addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull))
22572257
.constrainAllUses(TII, TRI, RBI);
22582258

22592259
// B[i]
@@ -2263,8 +2263,8 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
22632263
.addDef(BElt)
22642264
.addUse(GR.getSPIRVTypeID(ResType))
22652265
.addUse(Y)
2266-
.addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII, ZeroAsNull))
2267-
.addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull))
2266+
.addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType, TII, ZeroAsNull))
2267+
.addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull))
22682268
.constrainAllUses(TII, TRI, RBI);
22692269

22702270
// A[i] * B[i]
@@ -2283,8 +2283,8 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
22832283
.addDef(MaskMul)
22842284
.addUse(GR.getSPIRVTypeID(ResType))
22852285
.addUse(Mul)
2286-
.addUse(GR.getOrCreateConstInt(0, I, EltType, TII, ZeroAsNull))
2287-
.addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull))
2286+
.addUse(GR.getOrCreateConstInt(APInt(8, 0), I, EltType, TII, ZeroAsNull))
2287+
.addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull))
22882288
.constrainAllUses(TII, TRI, RBI);
22892289

22902290
// Acc = Acc + A[i] * B[i]
@@ -2381,7 +2381,7 @@ bool SPIRVInstructionSelector::selectWaveOpInst(Register ResVReg,
23812381
auto BMI = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
23822382
.addDef(ResVReg)
23832383
.addUse(GR.getSPIRVTypeID(ResType))
2384-
.addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I,
2384+
.addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I,
23852385
IntTy, TII, !STI.isShader()));
23862386

23872387
for (unsigned J = 2; J < I.getNumOperands(); J++) {
@@ -2405,7 +2405,7 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
24052405
TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
24062406
.addDef(ResVReg)
24072407
.addUse(GR.getSPIRVTypeID(ResType))
2408-
.addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy,
2408+
.addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy,
24092409
TII, !STI.isShader()))
24102410
.addImm(SPIRV::GroupOperation::Reduce)
24112411
.addUse(BallotReg)
@@ -2436,7 +2436,7 @@ bool SPIRVInstructionSelector::selectWaveReduceMax(Register ResVReg,
24362436
return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
24372437
.addDef(ResVReg)
24382438
.addUse(GR.getSPIRVTypeID(ResType))
2439-
.addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII,
2439+
.addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy, TII,
24402440
!STI.isShader()))
24412441
.addImm(SPIRV::GroupOperation::Reduce)
24422442
.addUse(I.getOperand(2).getReg())
@@ -2463,7 +2463,7 @@ bool SPIRVInstructionSelector::selectWaveReduceSum(Register ResVReg,
24632463
return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
24642464
.addDef(ResVReg)
24652465
.addUse(GR.getSPIRVTypeID(ResType))
2466-
.addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII,
2466+
.addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy, TII,
24672467
!STI.isShader()))
24682468
.addImm(SPIRV::GroupOperation::Reduce)
24692469
.addUse(I.getOperand(2).getReg());
@@ -2689,7 +2689,7 @@ Register SPIRVInstructionSelector::buildZerosVal(const SPIRVType *ResType,
26892689
bool ZeroAsNull = !STI.isShader();
26902690
if (ResType->getOpcode() == SPIRV::OpTypeVector)
26912691
return GR.getOrCreateConstVector(0UL, I, ResType, TII, ZeroAsNull);
2692-
return GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull);
2692+
return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0), I, ResType, TII, ZeroAsNull);
26932693
}
26942694

26952695
Register SPIRVInstructionSelector::buildZerosValF(const SPIRVType *ResType,
@@ -2720,7 +2720,7 @@ Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes,
27202720
AllOnes ? APInt::getAllOnes(BitWidth) : APInt::getOneBitSet(BitWidth, 0);
27212721
if (ResType->getOpcode() == SPIRV::OpTypeVector)
27222722
return GR.getOrCreateConstVector(One.getZExtValue(), I, ResType, TII);
2723-
return GR.getOrCreateConstInt(One.getZExtValue(), I, ResType, TII);
2723+
return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), One.getZExtValue()), I, ResType, TII);
27242724
}
27252725

27262726
bool SPIRVInstructionSelector::selectSelect(Register ResVReg,
@@ -2939,8 +2939,7 @@ bool SPIRVInstructionSelector::selectConst(Register ResVReg,
29392939
Reg = GR.getOrCreateConstFP(I.getOperand(1).getFPImm()->getValue(), I,
29402940
ResType, TII, !STI.isShader());
29412941
} else {
2942-
Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getZExtValue(), I,
2943-
ResType, TII, !STI.isShader());
2942+
Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getValue(), I, ResType, TII, !STI.isShader());
29442943
}
29452944
return Reg == ResVReg ? true : BuildCOPY(ResVReg, Reg, I);
29462945
}
@@ -3765,7 +3764,7 @@ bool SPIRVInstructionSelector::selectFirstBitSet64Overflow(
37653764
bool ZeroAsNull = !STI.isShader();
37663765
Register FinalElemReg = MRI->createVirtualRegister(GR.getRegClass(I64Type));
37673766
Register ConstIntLastIdx = GR.getOrCreateConstInt(
3768-
ComponentCount - 1, I, BaseType, TII, ZeroAsNull);
3767+
APInt(GR.getScalarOrVectorBitWidth(BaseType), ComponentCount - 1), I, BaseType, TII, ZeroAsNull);
37693768

37703769
if (!selectOpWithSrcs(FinalElemReg, I64Type, I, {SrcReg, ConstIntLastIdx},
37713770
SPIRV::OpVectorExtractDynamic))
@@ -3794,9 +3793,9 @@ bool SPIRVInstructionSelector::selectFirstBitSet64(
37943793
SPIRVType *BaseType = GR.retrieveScalarOrVectorIntType(ResType);
37953794
bool ZeroAsNull = !STI.isShader();
37963795
Register ConstIntZero =
3797-
GR.getOrCreateConstInt(0, I, BaseType, TII, ZeroAsNull);
3796+
GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 0), I, BaseType, TII, ZeroAsNull);
37983797
Register ConstIntOne =
3799-
GR.getOrCreateConstInt(1, I, BaseType, TII, ZeroAsNull);
3798+
GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 1), I, BaseType, TII, ZeroAsNull);
38003799

38013800
// SPIRV doesn't support vectors with more than 4 components. Since the
38023801
// algoritm below converts i64 -> i32x2 and i64x4 -> i32x8 it can only
@@ -3881,9 +3880,9 @@ bool SPIRVInstructionSelector::selectFirstBitSet64(
38813880

38823881
if (IsScalarRes) {
38833882
NegOneReg =
3884-
GR.getOrCreateConstInt((unsigned)-1, I, ResType, TII, ZeroAsNull);
3885-
Reg0 = GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull);
3886-
Reg32 = GR.getOrCreateConstInt(32, I, ResType, TII, ZeroAsNull);
3883+
GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), (unsigned)-1), I, ResType, TII, ZeroAsNull);
3884+
Reg0 = GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0), I, ResType, TII, ZeroAsNull);
3885+
Reg32 = GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 32), I, ResType, TII, ZeroAsNull);
38873886
SelectOp = SPIRV::OpSelectSISCond;
38883887
AddOp = SPIRV::OpIAddS;
38893888
} else {

llvm/lib/Target/SPIRV/SPIRVUtils.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,12 @@ void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB) {
100100
if (Bitwidth == 16)
101101
MIB.getInstr()->setAsmPrinterFlag(SPIRV::ASM_PRINTER_WIDTH16);
102102
return;
103-
} else if (Bitwidth <= 64) {
104-
uint64_t FullImm = Imm.getZExtValue();
105-
uint32_t LowBits = FullImm & 0xffffffff;
106-
uint32_t HighBits = (FullImm >> 32) & 0xffffffff;
107-
MIB.addImm(LowBits).addImm(HighBits);
103+
} else if (Bitwidth <= 1024) {
104+
unsigned NumWords = (Bitwidth + 31) / 32;
105+
for (unsigned i = 0; i < NumWords; ++i) {
106+
uint32_t Word = Imm.extractBits(32, i * 32).getZExtValue();
107+
MIB.addImm(Word);
108+
}
108109
return;
109110
}
110111
report_fatal_error("Unsupported constant bitwidth");

llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ define i13 @getConstantI13() {
88
ret i13 42
99
}
1010

11+
define i96 @getConstantI96() {
12+
ret i96 18446744073709551620
13+
}
14+
1115
;; Capabilities:
1216
; CHECK-DAG: OpExtension "SPV_INTEL_arbitrary_precision_integers"
1317
; CHECK-DAG: OpCapability ArbitraryPrecisionIntegersINTEL

0 commit comments

Comments
 (0)