Skip to content

Commit 15eaf6e

Browse files
committed
use iterative approach
1 parent 2dfcd27 commit 15eaf6e

File tree

3 files changed

+161
-165
lines changed

3 files changed

+161
-165
lines changed

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 45 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -3178,100 +3178,74 @@ bool SPIRVInstructionSelector::selectFirstBitSet64Overflow(
31783178
Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
31793179
Register SrcReg, unsigned BitSetOpcode, bool SwapPrimarySide) const {
31803180

3181-
unsigned ComponentCount = GR.getScalarOrVectorComponentCount(ResType);
31823181
// SPIR-V only allow vecs of size 2,3,4. Calling with a larger vec requires
3183-
// creating a return type with an invalid vec size. If that is resolved
3184-
// then this function is valid up to vec8 as the intermediate splitting
3185-
// would create 2 vec4.
3182+
// creating a param reg and return reg with an invalid vec size. If that is
3183+
// resolved then this function is valid for vectors of any component size.
3184+
unsigned ComponentCount = GR.getScalarOrVectorComponentCount(ResType);
31863185
assert(ComponentCount < 5 && "Vec 5+ will generate invalid SPIR-V ops");
31873186

3188-
3189-
SPIRVType *BaseType = GR.retrieveScalarOrVectorIntType(ResType);
31903187
bool ZeroAsNull = STI.isOpenCLEnv();
3191-
Register ConstIntZero =
3192-
GR.getOrCreateConstInt(0, I, BaseType, TII, ZeroAsNull);
3193-
unsigned LeftComponentCount = ComponentCount / 2;
3194-
unsigned RightComponentCount = ComponentCount - LeftComponentCount;
3195-
bool LeftIsVector = LeftComponentCount > 1;
3196-
3197-
// Split the SrcReg in half into 2 smaller vec registers
3198-
// (ie i64x4 -> i64x2, i64x2)
31993188
MachineIRBuilder MIRBuilder(I);
3200-
SPIRVType *OpType = GR.getOrCreateSPIRVIntegerType(64, MIRBuilder);
3201-
SPIRVType *LeftOpType = OpType;
3202-
SPIRVType *LeftResType = BaseType;
3203-
if (LeftIsVector) {
3204-
LeftOpType =
3205-
GR.getOrCreateSPIRVVectorType(OpType, LeftComponentCount, MIRBuilder);
3206-
LeftResType =
3207-
GR.getOrCreateSPIRVVectorType(BaseType, LeftComponentCount, MIRBuilder);
3208-
}
3209-
3210-
SPIRVType *RightOpType =
3211-
GR.getOrCreateSPIRVVectorType(OpType, RightComponentCount, MIRBuilder);
3212-
SPIRVType *RightResType =
3213-
GR.getOrCreateSPIRVVectorType(BaseType, RightComponentCount, MIRBuilder);
3214-
3215-
Register LeftSideIn = MRI->createVirtualRegister(GR.getRegClass(LeftOpType));
3216-
Register RightSideIn =
3217-
MRI->createVirtualRegister(GR.getRegClass(RightOpType));
3218-
3219-
// Extract the left half from the SrcReg into LeftSideIn
3220-
// accounting for the special case when it only has one element
3221-
if (LeftIsVector) {
3189+
SPIRVType *BaseType = GR.retrieveScalarOrVectorIntType(ResType);
3190+
SPIRVType *I64Type = GR.getOrCreateSPIRVIntegerType(64, MIRBuilder);
3191+
SPIRVType *I64x2Type = GR.getOrCreateSPIRVVectorType(I64Type, 2, MIRBuilder);
3192+
SPIRVType *Vec2ResType =
3193+
GR.getOrCreateSPIRVVectorType(BaseType, 2, MIRBuilder);
3194+
3195+
std::vector<Register> PartialRegs;
3196+
3197+
// Loops 0, 2, 4, ... but stops one loop early when ComponentCount is odd
3198+
unsigned CurrentComponent = 0;
3199+
for (; CurrentComponent + 1 < ComponentCount; CurrentComponent += 2) {
3200+
Register SubVecReg = MRI->createVirtualRegister(GR.getRegClass(I64x2Type));
3201+
32223202
auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
32233203
TII.get(SPIRV::OpVectorShuffle))
3224-
.addDef(LeftSideIn)
3225-
.addUse(GR.getSPIRVTypeID(LeftOpType))
3204+
.addDef(SubVecReg)
3205+
.addUse(GR.getSPIRVTypeID(I64x2Type))
32263206
.addUse(SrcReg)
32273207
// Per the spec, repeat the vector if only one vec is needed
32283208
.addUse(SrcReg);
32293209

3230-
for (unsigned J = 0; J < LeftComponentCount; ++J)
3231-
MIB.addImm(J);
3210+
MIB.addImm(CurrentComponent);
3211+
MIB.addImm(CurrentComponent + 1);
32323212

32333213
if (!MIB.constrainAllUses(TII, TRI, RBI))
32343214
return false;
32353215

3236-
} else {
3237-
if (!selectOpWithSrcs(LeftSideIn, LeftOpType, I, {SrcReg, ConstIntZero},
3238-
SPIRV::OpVectorExtractDynamic))
3216+
Register SubVecBitSetReg =
3217+
MRI->createVirtualRegister(GR.getRegClass(Vec2ResType));
3218+
3219+
if (!selectFirstBitSet64(SubVecBitSetReg, Vec2ResType, I, SubVecReg,
3220+
BitSetOpcode, SwapPrimarySide))
32393221
return false;
3222+
3223+
PartialRegs.push_back(SubVecBitSetReg);
32403224
}
32413225

3242-
// Extract the right half from the SrcReg into RightSideIn.
3243-
// Right will always be a vector since the only time one element is left is
3244-
// when Component == 3, and in that case Left is one element.
3245-
auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
3246-
TII.get(SPIRV::OpVectorShuffle))
3247-
.addDef(RightSideIn)
3248-
.addUse(GR.getSPIRVTypeID(RightOpType))
3249-
.addUse(SrcReg)
3250-
// Per the spec, repeat the vector if only one vec is needed
3251-
.addUse(SrcReg);
3226+
// On odd component counts we need to handle one more component
3227+
if (CurrentComponent != ComponentCount) {
3228+
Register FinalElemReg = MRI->createVirtualRegister(GR.getRegClass(I64Type));
3229+
Register ConstIntLastIdx = GR.getOrCreateConstInt(
3230+
ComponentCount - 1, I, BaseType, TII, ZeroAsNull);
32523231

3253-
for (unsigned J = LeftComponentCount; J < ComponentCount; ++J)
3254-
MIB.addImm(J);
3232+
if (!selectOpWithSrcs(FinalElemReg, I64Type, I, {SrcReg, ConstIntLastIdx},
3233+
SPIRV::OpVectorExtractDynamic))
3234+
return false;
32553235

3256-
if (!MIB.constrainAllUses(TII, TRI, RBI))
3257-
return false;
3236+
Register FinalElemBitSetReg =
3237+
MRI->createVirtualRegister(GR.getRegClass(BaseType));
32583238

3259-
// Recursively call selectFirstBitSet64 on the 2 halves
3260-
Register LeftSideOut =
3261-
MRI->createVirtualRegister(GR.getRegClass(LeftResType));
3262-
Register RightSideOut =
3263-
MRI->createVirtualRegister(GR.getRegClass(RightResType));
3239+
if (!selectFirstBitSet64(FinalElemBitSetReg, BaseType, I, FinalElemReg,
3240+
BitSetOpcode, SwapPrimarySide))
3241+
return false;
32643242

3265-
if (!selectFirstBitSet64(LeftSideOut, LeftResType, I, LeftSideIn,
3266-
BitSetOpcode, SwapPrimarySide))
3267-
return false;
3268-
if (!selectFirstBitSet64(RightSideOut, RightResType, I, RightSideIn,
3269-
BitSetOpcode, SwapPrimarySide))
3270-
return false;
3243+
PartialRegs.push_back(FinalElemBitSetReg);
3244+
}
32713245

3272-
// Join the two resulting registers back into the return type
3273-
// (ie i32x2, i32x2 -> i32x4)
3274-
return selectOpWithSrcs(ResVReg, ResType, I, {LeftSideOut, RightSideOut},
3246+
// Join all the resulting registers back into the return type in order
3247+
// (ie i32x2, i32x2, i32x1 -> i32x5)
3248+
return selectOpWithSrcs(ResVReg, ResType, I, PartialRegs,
32753249
SPIRV::OpCompositeConstruct);
32763250
}
32773251

llvm/test/CodeGen/SPIRV/hlsl-intrinsics/firstbithigh.ll

Lines changed: 55 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
; CHECK-DAG: [[u32x3_t:%.+]] = OpTypeVector [[u32_t]] 3
99
; CHECK-DAG: [[u32x4_t:%.+]] = OpTypeVector [[u32_t]] 4
1010
; CHECK-DAG: [[const_0:%.*]] = OpConstant [[u32_t]] 0
11+
; CHECK-DAG: [[const_2:%.*]] = OpConstant [[u32_t]] 2
1112
; CHECK-DAG: [[const_0x2:%.*]] = OpConstantComposite [[u32x2_t]] [[const_0]] [[const_0]]
1213
; CHECK-DAG: [[const_1:%.*]] = OpConstant [[u32_t]] 1
1314
; CHECK-DAG: [[const_32:%.*]] = OpConstant [[u32_t]] 32
@@ -146,32 +147,37 @@ entry:
146147
; CHECK-LABEL: Begin function firstbituhigh_v3xi64
147148
define noundef <3 x i32> @firstbituhigh_v3xi64(<3 x i64> noundef %a) {
148149
entry:
149-
; Split the i64x3 into i64, i64x2
150+
; Preamble
150151
; CHECK: [[a:%.+]] = OpFunctionParameter [[u64x3_t]]
151-
; CHECK: [[left:%.+]] = OpVectorExtractDynamic [[u64_t]] [[a]] [[const_0]]
152-
; CHECK: [[right:%.+]] = OpVectorShuffle [[u64x2_t]] [[a]] [[a]] 1 2
153152

154-
; Do firstbituhigh on i64, i64x2
155-
; CHECK: [[left_cast:%.+]] = OpBitcast [[u32x2_t]] [[left]]
156-
; CHECK: [[left_lsb_bits:%.+]] = OpExtInst [[u32x2_t]] [[glsl_450_ext]] FindUMsb [[left_cast]]
157-
; CHECK: [[left_high_bits:%.+]] = OpVectorExtractDynamic [[u32_t]] [[left_lsb_bits]] [[const_0]]
158-
; CHECK: [[left_low_bits:%.+]] = OpVectorExtractDynamic [[u32_t]] [[left_lsb_bits]] [[const_1]]
159-
; CHECK: [[left_should_use_low:%.+]] = OpIEqual [[bool_t]] [[left_high_bits]] [[const_neg1]]
160-
; CHECK: [[left_ans_bits:%.+]] = OpSelect [[u32_t]] [[left_should_use_low]] [[left_low_bits]] [[left_high_bits]]
161-
; CHECK: [[left_ans_offset:%.+]] = OpSelect [[u32_t]] [[left_should_use_low]] [[const_0]] [[const_32]]
162-
; CHECK: [[left_res:%.+]] = OpIAdd [[u32_t]] [[left_ans_offset]] [[left_ans_bits]]
153+
; Extract first 2 components from %a
154+
; CHECK: [[pt1:%.+]] = OpVectorShuffle [[u64x2_t]] [[a]] [[a]] 0 1
163155

164-
; CHECK: [[right_cast:%.+]] = OpBitcast [[u32x4_t]] [[right]]
165-
; CHECK: [[right_lsb_bits:%.+]] = OpExtInst [[u32x4_t]] [[glsl_450_ext]] FindUMsb [[right_cast]]
166-
; CHECK: [[right_high_bits:%.+]] = OpVectorShuffle [[u32x2_t]] [[right_lsb_bits]] [[right_lsb_bits]] 0 2
167-
; CHECK: [[right_low_bits:%.+]] = OpVectorShuffle [[u32x2_t]] [[right_lsb_bits]] [[right_lsb_bits]] 1 3
168-
; CHECK: [[right_should_use_low:%.+]] = OpIEqual [[boolx2_t]] [[right_high_bits]] [[const_neg1x2]]
169-
; CHECK: [[right_ans_bits:%.+]] = OpSelect [[u32x2_t]] [[right_should_use_low]] [[right_low_bits]] [[right_high_bits]]
170-
; CHECK: [[right_ans_offset:%.+]] = OpSelect [[u32x2_t]] [[right_should_use_low]] [[const_0x2]] [[const_32x2]]
171-
; CHECK: [[right_res:%.+]] = OpIAdd [[u32x2_t]] [[right_ans_offset]] [[right_ans_bits]]
156+
; Do firstbituhigh on the first 2 components
157+
; CHECK: [[pt1_cast:%.+]] = OpBitcast [[u32x4_t]] [[pt1]]
158+
; CHECK: [[pt1_lsb_bits:%.+]] = OpExtInst [[u32x4_t]] [[glsl_450_ext]] FindUMsb [[pt1_cast]]
159+
; CHECK: [[pt1_high_bits:%.+]] = OpVectorShuffle [[u32x2_t]] [[pt1_lsb_bits]] [[pt1_lsb_bits]] 0 2
160+
; CHECK: [[pt1_low_bits:%.+]] = OpVectorShuffle [[u32x2_t]] [[pt1_lsb_bits]] [[pt1_lsb_bits]] 1 3
161+
; CHECK: [[pt1_should_use_low:%.+]] = OpIEqual [[boolx2_t]] [[pt1_high_bits]] [[const_neg1x2]]
162+
; CHECK: [[pt1_ans_bits:%.+]] = OpSelect [[u32x2_t]] [[pt1_should_use_low]] [[pt1_low_bits]] [[pt1_high_bits]]
163+
; CHECK: [[pt1_ans_offset:%.+]] = OpSelect [[u32x2_t]] [[pt1_should_use_low]] [[const_0x2]] [[const_32x2]]
164+
; CHECK: [[pt1_res:%.+]] = OpIAdd [[u32x2_t]] [[pt1_ans_offset]] [[pt1_ans_bits]]
172165

173-
; Merge the resulting i32, i32x2 into the final i32x3 and return it
174-
; CHECK: [[ret:%.+]] = OpCompositeConstruct [[u32x3_t]] [[left_res]] [[right_res]]
166+
; Extract the last component from %a
167+
; CHECK: [[pt2:%.+]] = OpVectorExtractDynamic [[u64_t]] [[a]] [[const_2]]
168+
169+
; Do firstbituhigh on the last component
170+
; CHECK: [[pt2_cast:%.+]] = OpBitcast [[u32x2_t]] [[pt2]]
171+
; CHECK: [[pt2_lsb_bits:%.+]] = OpExtInst [[u32x2_t]] [[glsl_450_ext]] FindUMsb [[pt2_cast]]
172+
; CHECK: [[pt2_high_bits:%.+]] = OpVectorExtractDynamic [[u32_t]] [[pt2_lsb_bits]] [[const_0]]
173+
; CHECK: [[pt2_low_bits:%.+]] = OpVectorExtractDynamic [[u32_t]] [[pt2_lsb_bits]] [[const_1]]
174+
; CHECK: [[pt2_should_use_low:%.+]] = OpIEqual [[bool_t]] [[pt2_high_bits]] [[const_neg1]]
175+
; CHECK: [[pt2_ans_bits:%.+]] = OpSelect [[u32_t]] [[pt2_should_use_low]] [[pt2_low_bits]] [[pt2_high_bits]]
176+
; CHECK: [[pt2_ans_offset:%.+]] = OpSelect [[u32_t]] [[pt2_should_use_low]] [[const_0]] [[const_32]]
177+
; CHECK: [[pt2_res:%.+]] = OpIAdd [[u32_t]] [[pt2_ans_offset]] [[pt2_ans_bits]]
178+
179+
; Merge the parts into the final i32x3 and return it
180+
; CHECK: [[ret:%.+]] = OpCompositeConstruct [[u32x3_t]] [[pt1_res]] [[pt2_res]]
175181
; CHECK: OpReturnValue [[ret]]
176182
%elt.firstbituhigh = call <3 x i32> @llvm.spv.firstbituhigh.v3i64(<3 x i64> %a)
177183
ret <3 x i32> %elt.firstbituhigh
@@ -180,32 +186,37 @@ entry:
180186
; CHECK-LABEL: Begin function firstbituhigh_v4xi64
181187
define noundef <4 x i32> @firstbituhigh_v4xi64(<4 x i64> noundef %a) {
182188
entry:
183-
; Split the i64x4 into 2 i64x2
189+
; Preamble
184190
; CHECK: [[a:%.+]] = OpFunctionParameter [[u64x4_t]]
185-
; CHECK: [[left:%.+]] = OpVectorShuffle [[u64x2_t]] [[a]] [[a]] 0 1
186-
; CHECK: [[right:%.+]] = OpVectorShuffle [[u64x2_t]] [[a]] [[a]] 2 3
187191

188-
; Do firstbithigh on the 2 i64x2
189-
; CHECK: [[left_cast:%.+]] = OpBitcast [[u32x4_t]] [[left]]
190-
; CHECK: [[left_lsb_bits:%.+]] = OpExtInst [[u32x4_t]] [[glsl_450_ext]] FindUMsb [[left_cast]]
191-
; CHECK: [[left_high_bits:%.+]] = OpVectorShuffle [[u32x2_t]] [[left_lsb_bits]] [[left_lsb_bits]] 0 2
192-
; CHECK: [[left_low_bits:%.+]] = OpVectorShuffle [[u32x2_t]] [[left_lsb_bits]] [[left_lsb_bits]] 1 3
193-
; CHECK: [[left_should_use_low:%.+]] = OpIEqual [[boolx2_t]] [[left_high_bits]] [[const_neg1x2]]
194-
; CHECK: [[left_ans_bits:%.+]] = OpSelect [[u32x2_t]] [[left_should_use_low]] [[left_low_bits]] [[left_high_bits]]
195-
; CHECK: [[left_ans_offset:%.+]] = OpSelect [[u32x2_t]] [[left_should_use_low]] [[const_0x2]] [[const_32x2]]
196-
; CHECK: [[left_res:%.+]] = OpIAdd [[u32x2_t]] [[left_ans_offset]] [[left_ans_bits]]
192+
; Extract first 2 components from %a
193+
; CHECK: [[pt1:%.+]] = OpVectorShuffle [[u64x2_t]] [[a]] [[a]] 0 1
194+
195+
; Do firstbituhigh on the first 2 components
196+
; CHECK: [[pt1_cast:%.+]] = OpBitcast [[u32x4_t]] [[pt1]]
197+
; CHECK: [[pt1_lsb_bits:%.+]] = OpExtInst [[u32x4_t]] [[glsl_450_ext]] FindUMsb [[pt1_cast]]
198+
; CHECK: [[pt1_high_bits:%.+]] = OpVectorShuffle [[u32x2_t]] [[pt1_lsb_bits]] [[pt1_lsb_bits]] 0 2
199+
; CHECK: [[pt1_low_bits:%.+]] = OpVectorShuffle [[u32x2_t]] [[pt1_lsb_bits]] [[pt1_lsb_bits]] 1 3
200+
; CHECK: [[pt1_should_use_low:%.+]] = OpIEqual [[boolx2_t]] [[pt1_high_bits]] [[const_neg1x2]]
201+
; CHECK: [[pt1_ans_bits:%.+]] = OpSelect [[u32x2_t]] [[pt1_should_use_low]] [[pt1_low_bits]] [[pt1_high_bits]]
202+
; CHECK: [[pt1_ans_offset:%.+]] = OpSelect [[u32x2_t]] [[pt1_should_use_low]] [[const_0x2]] [[const_32x2]]
203+
; CHECK: [[pt1_res:%.+]] = OpIAdd [[u32x2_t]] [[pt1_ans_offset]] [[pt1_ans_bits]]
204+
205+
; Extract last 2 components from %a
206+
; CHECK: [[pt2:%.+]] = OpVectorShuffle [[u64x2_t]] [[a]] [[a]] 2 3
197207

198-
; CHECK: [[right_cast:%.+]] = OpBitcast [[u32x4_t]] [[right]]
199-
; CHECK: [[right_lsb_bits:%.+]] = OpExtInst [[u32x4_t]] [[glsl_450_ext]] FindUMsb [[right_cast]]
200-
; CHECK: [[right_high_bits:%.+]] = OpVectorShuffle [[u32x2_t]] [[right_lsb_bits]] [[right_lsb_bits]] 0 2
201-
; CHECK: [[right_low_bits:%.+]] = OpVectorShuffle [[u32x2_t]] [[right_lsb_bits]] [[right_lsb_bits]] 1 3
202-
; CHECK: [[right_should_use_low:%.+]] = OpIEqual [[boolx2_t]] [[right_high_bits]] [[const_neg1x2]]
203-
; CHECK: [[right_ans_bits:%.+]] = OpSelect [[u32x2_t]] [[right_should_use_low]] [[right_low_bits]] [[right_high_bits]]
204-
; CHECK: [[right_ans_offset:%.+]] = OpSelect [[u32x2_t]] [[right_should_use_low]] [[const_0x2]] [[const_32x2]]
205-
; CHECK: [[right_res:%.+]] = OpIAdd [[u32x2_t]] [[right_ans_offset]] [[right_ans_bits]]
208+
; Do firstbituhigh on the last 2 components
209+
; CHECK: [[pt2_cast:%.+]] = OpBitcast [[u32x4_t]] [[pt2]]
210+
; CHECK: [[pt2_lsb_bits:%.+]] = OpExtInst [[u32x4_t]] [[glsl_450_ext]] FindUMsb [[pt2_cast]]
211+
; CHECK: [[pt2_high_bits:%.+]] = OpVectorShuffle [[u32x2_t]] [[pt2_lsb_bits]] [[pt2_lsb_bits]] 0 2
212+
; CHECK: [[pt2_low_bits:%.+]] = OpVectorShuffle [[u32x2_t]] [[pt2_lsb_bits]] [[pt2_lsb_bits]] 1 3
213+
; CHECK: [[pt2_should_use_low:%.+]] = OpIEqual [[boolx2_t]] [[pt2_high_bits]] [[const_neg1x2]]
214+
; CHECK: [[pt2_ans_bits:%.+]] = OpSelect [[u32x2_t]] [[pt2_should_use_low]] [[pt2_low_bits]] [[pt2_high_bits]]
215+
; CHECK: [[pt2_ans_offset:%.+]] = OpSelect [[u32x2_t]] [[pt2_should_use_low]] [[const_0x2]] [[const_32x2]]
216+
; CHECK: [[pt2_res:%.+]] = OpIAdd [[u32x2_t]] [[pt2_ans_offset]] [[pt2_ans_bits]]
206217

207-
; Merge the resulting 2 i32x2 into the final i32x4 and return it
208-
; CHECK: [[ret:%.+]] = OpCompositeConstruct [[u32x4_t]] [[left_res]] [[right_res]]
218+
; Merge the parts into the final i32x4 and return it
219+
; CHECK: [[ret:%.+]] = OpCompositeConstruct [[u32x4_t]] [[pt1_res]] [[pt2_res]]
209220
; CHECK: OpReturnValue [[ret]]
210221
%elt.firstbituhigh = call <4 x i32> @llvm.spv.firstbituhigh.v4i64(<4 x i64> %a)
211222
ret <4 x i32> %elt.firstbituhigh

0 commit comments

Comments
 (0)