Skip to content

Commit 1693d8e

Browse files
authored
[AArch64][SelectionDAG] Vector splitting and promotion for histogram intrinsic (#103037)
Adds support for wider-than-legal vector types for the histogram intrinsic (llvm.experimental.vector.histogram.add) by splitting the vector. Also adds integer promotion for the Inc operand.
1 parent e0fa2f1 commit 1693d8e

File tree

6 files changed

+148
-7
lines changed

6 files changed

+148
-7
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1241,6 +1241,11 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
12411241
Action = TLI.getOperationAction(Node->getOpcode(),
12421242
Node->getOperand(0).getValueType());
12431243
break;
1244+
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
1245+
Action = TLI.getOperationAction(
1246+
Node->getOpcode(),
1247+
cast<MaskedHistogramSDNode>(Node)->getIndex().getValueType());
1248+
break;
12441249
default:
12451250
if (Node->getOpcode() >= ISD::BUILTIN_OP_END) {
12461251
Action = TLI.getCustomOperationAction(*Node);

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2043,6 +2043,9 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
20432043
case ISD::EXPERIMENTAL_VP_SPLICE:
20442044
Res = PromoteIntOp_VP_SPLICE(N, OpNo);
20452045
break;
2046+
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
2047+
Res = PromoteIntOp_VECTOR_HISTOGRAM(N, OpNo);
2048+
break;
20462049
}
20472050

20482051
// If the result is null, the sub-method took care of registering results etc.
@@ -2755,6 +2758,14 @@ SDValue DAGTypeLegalizer::PromoteIntOp_VP_SPLICE(SDNode *N, unsigned OpNo) {
27552758
return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
27562759
}
27572760

2761+
SDValue DAGTypeLegalizer::PromoteIntOp_VECTOR_HISTOGRAM(SDNode *N,
2762+
unsigned OpNo) {
2763+
assert(OpNo == 1 && "Unexpected operand for promotion");
2764+
SmallVector<SDValue, 7> NewOps(N->ops());
2765+
NewOps[1] = GetPromotedInteger(N->getOperand(1));
2766+
return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
2767+
}
2768+
27582769
//===----------------------------------------------------------------------===//
27592770
// Integer Result Expansion
27602771
//===----------------------------------------------------------------------===//

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
425425
SDValue PromoteIntOp_PATCHPOINT(SDNode *N, unsigned OpNo);
426426
SDValue PromoteIntOp_VP_STRIDED(SDNode *N, unsigned OpNo);
427427
SDValue PromoteIntOp_VP_SPLICE(SDNode *N, unsigned OpNo);
428+
SDValue PromoteIntOp_VECTOR_HISTOGRAM(SDNode *N, unsigned OpNo);
428429

429430
void SExtOrZExtPromotedOperands(SDValue &LHS, SDValue &RHS);
430431
void PromoteSetCCOperands(SDValue &LHS,SDValue &RHS, ISD::CondCode Code);
@@ -982,6 +983,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
982983
SDValue SplitVecOp_CMP(SDNode *N);
983984
SDValue SplitVecOp_FP_TO_XINT_SAT(SDNode *N);
984985
SDValue SplitVecOp_VP_CttzElements(SDNode *N);
986+
SDValue SplitVecOp_VECTOR_HISTOGRAM(SDNode *N);
985987

986988
//===--------------------------------------------------------------------===//
987989
// Vector Widening Support: LegalizeVectorTypes.cpp

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3355,6 +3355,9 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
33553355
case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
33563356
Res = SplitVecOp_VP_CttzElements(N);
33573357
break;
3358+
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
3359+
Res = SplitVecOp_VECTOR_HISTOGRAM(N);
3360+
break;
33583361
}
33593362

33603363
// If the result is null, the sub-method took care of registering results etc.
@@ -4374,6 +4377,28 @@ SDValue DAGTypeLegalizer::SplitVecOp_VP_CttzElements(SDNode *N) {
43744377
DAG.getNode(ISD::ADD, DL, ResVT, VLo, ResHi));
43754378
}
43764379

4380+
SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_HISTOGRAM(SDNode *N) {
4381+
MaskedHistogramSDNode *HG = cast<MaskedHistogramSDNode>(N);
4382+
SDLoc DL(HG);
4383+
SDValue Inc = HG->getInc();
4384+
SDValue Ptr = HG->getBasePtr();
4385+
SDValue Scale = HG->getScale();
4386+
SDValue IntID = HG->getIntID();
4387+
EVT MemVT = HG->getMemoryVT();
4388+
MachineMemOperand *MMO = HG->getMemOperand();
4389+
ISD::MemIndexType IndexType = HG->getIndexType();
4390+
4391+
SDValue IndexLo, IndexHi, MaskLo, MaskHi;
4392+
std::tie(IndexLo, IndexHi) = DAG.SplitVector(HG->getIndex(), DL);
4393+
std::tie(MaskLo, MaskHi) = DAG.SplitVector(HG->getMask(), DL);
4394+
SDValue OpsLo[] = {HG->getChain(), Inc, MaskLo, Ptr, IndexLo, Scale, IntID};
4395+
SDValue Lo = DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), MemVT, DL,
4396+
OpsLo, MMO, IndexType);
4397+
SDValue OpsHi[] = {Lo, Inc, MaskHi, Ptr, IndexHi, Scale, IntID};
4398+
return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), MemVT, DL, OpsHi,
4399+
MMO, IndexType);
4400+
}
4401+
43774402
//===----------------------------------------------------------------------===//
43784403
// Result Vector Widening
43794404
//===----------------------------------------------------------------------===//

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1790,10 +1790,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
17901790

17911791
// Histcnt is SVE2 only
17921792
if (Subtarget->hasSVE2()) {
1793-
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::Other,
1793+
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::nxv4i32,
1794+
Custom);
1795+
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::nxv2i64,
17941796
Custom);
1795-
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::i8, Custom);
1796-
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::i16, Custom);
17971797
}
17981798
}
17991799

@@ -28550,11 +28550,10 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
2855028550
assert(CID->getZExtValue() == Intrinsic::experimental_vector_histogram_add &&
2855128551
"Unexpected histogram update operation");
2855228552

28553-
EVT IncVT = Inc.getValueType();
2855428553
EVT IndexVT = Index.getValueType();
2855528554
LLVMContext &Ctx = *DAG.getContext();
2855628555
ElementCount EC = IndexVT.getVectorElementCount();
28557-
EVT MemVT = EVT::getVectorVT(Ctx, IncVT, EC);
28556+
EVT MemVT = EVT::getVectorVT(Ctx, HG->getMemoryVT(), EC);
2855828557
EVT IncExtVT =
2855928558
EVT::getIntegerVT(Ctx, AArch64::SVEBitsPerBlock / EC.getKnownMinValue());
2856028559
EVT IncSplatVT = EVT::getVectorVT(Ctx, IncExtVT, EC);

llvm/test/CodeGen/AArch64/sve2-histcnt.ll

Lines changed: 101 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,10 @@ define void @histogram_i16_literal_1(ptr %base, <vscale x 4 x i32> %indices, <vs
132132
; CHECK-LABEL: histogram_i16_literal_1:
133133
; CHECK: // %bb.0:
134134
; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
135+
; CHECK-NEXT: mov z3.s, #1 // =0x1
135136
; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
136-
; CHECK-NEXT: add z1.s, z2.s, z1.s
137+
; CHECK-NEXT: ptrue p1.s
138+
; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s
137139
; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
138140
; CHECK-NEXT: ret
139141
%buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
@@ -145,8 +147,10 @@ define void @histogram_i16_literal_2(ptr %base, <vscale x 4 x i32> %indices, <vs
145147
; CHECK-LABEL: histogram_i16_literal_2:
146148
; CHECK: // %bb.0:
147149
; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
150+
; CHECK-NEXT: mov z3.s, #2 // =0x2
148151
; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
149-
; CHECK-NEXT: adr z1.s, [z2.s, z1.s, lsl #1]
152+
; CHECK-NEXT: ptrue p1.s
153+
; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s
150154
; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
151155
; CHECK-NEXT: ret
152156
%buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
@@ -169,4 +173,99 @@ define void @histogram_i16_literal_3(ptr %base, <vscale x 4 x i32> %indices, <vs
169173
ret void
170174
}
171175

176+
define void @histogram_i64_4_lane(<vscale x 4 x ptr> %buckets, i64 %inc, <vscale x 4 x i1> %mask) #0 {
177+
; CHECK-LABEL: histogram_i64_4_lane:
178+
; CHECK: // %bb.0:
179+
; CHECK-NEXT: punpklo p1.h, p0.b
180+
; CHECK-NEXT: mov z4.d, x0
181+
; CHECK-NEXT: ptrue p2.d
182+
; CHECK-NEXT: histcnt z2.d, p1/z, z0.d, z0.d
183+
; CHECK-NEXT: ld1d { z3.d }, p1/z, [z0.d]
184+
; CHECK-NEXT: punpkhi p0.h, p0.b
185+
; CHECK-NEXT: mad z2.d, p2/m, z4.d, z3.d
186+
; CHECK-NEXT: st1d { z2.d }, p1, [z0.d]
187+
; CHECK-NEXT: histcnt z0.d, p0/z, z1.d, z1.d
188+
; CHECK-NEXT: ld1d { z2.d }, p0/z, [z1.d]
189+
; CHECK-NEXT: mad z0.d, p2/m, z4.d, z2.d
190+
; CHECK-NEXT: st1d { z0.d }, p0, [z1.d]
191+
; CHECK-NEXT: ret
192+
call void @llvm.experimental.vector.histogram.add.nxv4p0.i64(<vscale x 4 x ptr> %buckets, i64 %inc, <vscale x 4 x i1> %mask)
193+
ret void
194+
}
195+
196+
define void @histogram_i64_8_lane(<vscale x 8 x ptr> %buckets, i64 %inc, <vscale x 8 x i1> %mask) #0 {
197+
; CHECK-LABEL: histogram_i64_8_lane:
198+
; CHECK: // %bb.0:
199+
; CHECK-NEXT: punpklo p2.h, p0.b
200+
; CHECK-NEXT: mov z6.d, x0
201+
; CHECK-NEXT: ptrue p1.d
202+
; CHECK-NEXT: punpklo p3.h, p2.b
203+
; CHECK-NEXT: punpkhi p2.h, p2.b
204+
; CHECK-NEXT: histcnt z4.d, p3/z, z0.d, z0.d
205+
; CHECK-NEXT: ld1d { z5.d }, p3/z, [z0.d]
206+
; CHECK-NEXT: punpkhi p0.h, p0.b
207+
; CHECK-NEXT: mad z4.d, p1/m, z6.d, z5.d
208+
; CHECK-NEXT: st1d { z4.d }, p3, [z0.d]
209+
; CHECK-NEXT: histcnt z0.d, p2/z, z1.d, z1.d
210+
; CHECK-NEXT: ld1d { z4.d }, p2/z, [z1.d]
211+
; CHECK-NEXT: mad z0.d, p1/m, z6.d, z4.d
212+
; CHECK-NEXT: st1d { z0.d }, p2, [z1.d]
213+
; CHECK-NEXT: punpklo p2.h, p0.b
214+
; CHECK-NEXT: punpkhi p0.h, p0.b
215+
; CHECK-NEXT: histcnt z0.d, p2/z, z2.d, z2.d
216+
; CHECK-NEXT: ld1d { z1.d }, p2/z, [z2.d]
217+
; CHECK-NEXT: mad z0.d, p1/m, z6.d, z1.d
218+
; CHECK-NEXT: st1d { z0.d }, p2, [z2.d]
219+
; CHECK-NEXT: histcnt z0.d, p0/z, z3.d, z3.d
220+
; CHECK-NEXT: ld1d { z1.d }, p0/z, [z3.d]
221+
; CHECK-NEXT: mad z0.d, p1/m, z6.d, z1.d
222+
; CHECK-NEXT: st1d { z0.d }, p0, [z3.d]
223+
; CHECK-NEXT: ret
224+
call void @llvm.experimental.vector.histogram.add.nxv8p0.i64(<vscale x 8 x ptr> %buckets, i64 %inc, <vscale x 8 x i1> %mask)
225+
ret void
226+
}
227+
228+
define void @histogram_i32_8_lane(ptr %base, <vscale x 8 x i32> %indices, i32 %inc, <vscale x 8 x i1> %mask) #0 {
229+
; CHECK-LABEL: histogram_i32_8_lane:
230+
; CHECK: // %bb.0:
231+
; CHECK-NEXT: punpklo p1.h, p0.b
232+
; CHECK-NEXT: mov z4.s, w1
233+
; CHECK-NEXT: ptrue p2.s
234+
; CHECK-NEXT: histcnt z2.s, p1/z, z0.s, z0.s
235+
; CHECK-NEXT: ld1w { z3.s }, p1/z, [x0, z0.s, sxtw #2]
236+
; CHECK-NEXT: punpkhi p0.h, p0.b
237+
; CHECK-NEXT: mad z2.s, p2/m, z4.s, z3.s
238+
; CHECK-NEXT: st1w { z2.s }, p1, [x0, z0.s, sxtw #2]
239+
; CHECK-NEXT: histcnt z0.s, p0/z, z1.s, z1.s
240+
; CHECK-NEXT: ld1w { z2.s }, p0/z, [x0, z1.s, sxtw #2]
241+
; CHECK-NEXT: mad z0.s, p2/m, z4.s, z2.s
242+
; CHECK-NEXT: st1w { z0.s }, p0, [x0, z1.s, sxtw #2]
243+
; CHECK-NEXT: ret
244+
%buckets = getelementptr i32, ptr %base, <vscale x 8 x i32> %indices
245+
call void @llvm.experimental.vector.histogram.add.nxv8p0.i32(<vscale x 8 x ptr> %buckets, i32 %inc, <vscale x 8 x i1> %mask)
246+
ret void
247+
}
248+
249+
define void @histogram_i16_8_lane(ptr %base, <vscale x 8 x i32> %indices, i16 %inc, <vscale x 8 x i1> %mask) #0 {
250+
; CHECK-LABEL: histogram_i16_8_lane:
251+
; CHECK: // %bb.0:
252+
; CHECK-NEXT: punpklo p1.h, p0.b
253+
; CHECK-NEXT: mov z4.s, w1
254+
; CHECK-NEXT: ptrue p2.s
255+
; CHECK-NEXT: histcnt z2.s, p1/z, z0.s, z0.s
256+
; CHECK-NEXT: ld1h { z3.s }, p1/z, [x0, z0.s, sxtw #1]
257+
; CHECK-NEXT: punpkhi p0.h, p0.b
258+
; CHECK-NEXT: mad z2.s, p2/m, z4.s, z3.s
259+
; CHECK-NEXT: st1h { z2.s }, p1, [x0, z0.s, sxtw #1]
260+
; CHECK-NEXT: histcnt z0.s, p0/z, z1.s, z1.s
261+
; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z1.s, sxtw #1]
262+
; CHECK-NEXT: mad z0.s, p2/m, z4.s, z2.s
263+
; CHECK-NEXT: st1h { z0.s }, p0, [x0, z1.s, sxtw #1]
264+
; CHECK-NEXT: ret
265+
%buckets = getelementptr i16, ptr %base, <vscale x 8 x i32> %indices
266+
call void @llvm.experimental.vector.histogram.add.nxv8p0.i16(<vscale x 8 x ptr> %buckets, i16 %inc, <vscale x 8 x i1> %mask)
267+
ret void
268+
}
269+
270+
172271
attributes #0 = { "target-features"="+sve2" vscale_range(1, 16) }

0 commit comments

Comments
 (0)