Skip to content

Commit 534b26a

Browse files
author
Krzysztof Parzyszek
committed
[Hexagon] Improve inserting/extracting to/from scalar predicates
Fixes llvm/llvm-project#59042.
1 parent f34fe2a commit 534b26a

File tree

4 files changed

+297
-78
lines changed

4 files changed

+297
-78
lines changed

llvm/lib/Target/Hexagon/HexagonISelLowering.cpp

Lines changed: 108 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -2641,60 +2641,13 @@ HexagonTargetLowering::extractVector(SDValue VecV, SDValue IdxV,
26412641
MVT VecTy = ty(VecV);
26422642
assert(!ValTy.isVector() ||
26432643
VecTy.getVectorElementType() == ValTy.getVectorElementType());
2644+
if (VecTy.getVectorElementType() == MVT::i1)
2645+
return extractVectorPred(VecV, IdxV, dl, ValTy, ResTy, DAG);
2646+
26442647
unsigned VecWidth = VecTy.getSizeInBits();
26452648
unsigned ValWidth = ValTy.getSizeInBits();
26462649
unsigned ElemWidth = VecTy.getVectorElementType().getSizeInBits();
26472650
assert((VecWidth % ElemWidth) == 0);
2648-
auto *IdxN = dyn_cast<ConstantSDNode>(IdxV);
2649-
2650-
// Special case for v{8,4,2}i1 (the only boolean vectors legal in Hexagon
2651-
// without any coprocessors).
2652-
if (ElemWidth == 1) {
2653-
assert(VecWidth == VecTy.getVectorNumElements() &&
2654-
"Vector elements should equal vector width size");
2655-
assert(VecWidth == 8 || VecWidth == 4 || VecWidth == 2);
2656-
// Check if this is an extract of the lowest bit.
2657-
if (IdxN) {
2658-
// Extracting the lowest bit is a no-op, but it changes the type,
2659-
// so it must be kept as an operation to avoid errors related to
2660-
// type mismatches.
2661-
if (IdxN->isZero() && ValTy.getSizeInBits() == 1)
2662-
return DAG.getNode(HexagonISD::TYPECAST, dl, MVT::i1, VecV);
2663-
}
2664-
2665-
// If the value extracted is a single bit, use tstbit.
2666-
if (ValWidth == 1) {
2667-
SDValue A0 = getInstr(Hexagon::C2_tfrpr, dl, MVT::i32, {VecV}, DAG);
2668-
SDValue M0 = DAG.getConstant(8 / VecWidth, dl, MVT::i32);
2669-
SDValue I0 = DAG.getNode(ISD::MUL, dl, MVT::i32, IdxV, M0);
2670-
return DAG.getNode(HexagonISD::TSTBIT, dl, MVT::i1, A0, I0);
2671-
}
2672-
2673-
// Each bool vector (v2i1, v4i1, v8i1) always occupies 8 bits in
2674-
// a predicate register. The elements of the vector are repeated
2675-
// in the register (if necessary) so that the total number is 8.
2676-
// The extracted subvector will need to be expanded in such a way.
2677-
unsigned Scale = VecWidth / ValWidth;
2678-
2679-
// Generate (p2d VecV) >> 8*Idx to move the interesting bytes to
2680-
// position 0.
2681-
assert(ty(IdxV) == MVT::i32);
2682-
unsigned VecRep = 8 / VecWidth;
2683-
SDValue S0 = DAG.getNode(ISD::MUL, dl, MVT::i32, IdxV,
2684-
DAG.getConstant(8*VecRep, dl, MVT::i32));
2685-
SDValue T0 = DAG.getNode(HexagonISD::P2D, dl, MVT::i64, VecV);
2686-
SDValue T1 = DAG.getNode(ISD::SRL, dl, MVT::i64, T0, S0);
2687-
while (Scale > 1) {
2688-
// The longest possible subvector is at most 32 bits, so it is always
2689-
// contained in the low subregister.
2690-
T1 = DAG.getTargetExtractSubreg(Hexagon::isub_lo, dl, MVT::i32, T1);
2691-
T1 = expandPredicate(T1, dl, DAG);
2692-
Scale /= 2;
2693-
}
2694-
2695-
return DAG.getNode(HexagonISD::D2P, dl, ResTy, T1);
2696-
}
2697-
26982651
assert(VecWidth == 32 || VecWidth == 64);
26992652

27002653
// Cast everything to scalar integer types.
@@ -2704,7 +2657,7 @@ HexagonTargetLowering::extractVector(SDValue VecV, SDValue IdxV,
27042657
SDValue WidthV = DAG.getConstant(ValWidth, dl, MVT::i32);
27052658
SDValue ExtV;
27062659

2707-
if (IdxN) {
2660+
if (auto *IdxN = dyn_cast<ConstantSDNode>(IdxV)) {
27082661
unsigned Off = IdxN->getZExtValue() * ElemWidth;
27092662
if (VecWidth == 64 && ValWidth == 32) {
27102663
assert(Off == 0 || Off == 32);
@@ -2735,36 +2688,68 @@ HexagonTargetLowering::extractVector(SDValue VecV, SDValue IdxV,
27352688
}
27362689

27372690
// extractVectorPred: extract a single bit or a boolean subvector of type
// ValTy from the boolean vector VecV (v2i1, v4i1 or v8i1 -- enforced by the
// asserts below) at element index IdxV. Single-bit extracts are returned as
// MVT::i1; subvector extracts are returned as ResTy.
// In this diff view the lines whose marker ends in '-' are the removed,
// formerly inlined predicate handling of insertVector; they are not part of
// the new function body.
SDValue
2738-
HexagonTargetLowering::insertVector(SDValue VecV, SDValue ValV, SDValue IdxV,
2739-
const SDLoc &dl, MVT ValTy,
2740-
SelectionDAG &DAG) const {
2691+
HexagonTargetLowering::extractVectorPred(SDValue VecV, SDValue IdxV,
2692+
const SDLoc &dl, MVT ValTy, MVT ResTy,
2693+
SelectionDAG &DAG) const {
2694+
// Special case for v{8,4,2}i1 (the only boolean vectors legal in Hexagon
2695+
// without any coprocessors).
27412696
MVT VecTy = ty(VecV);
2742-
if (VecTy.getVectorElementType() == MVT::i1) {
2743-
MVT ValTy = ty(ValV);
2744-
assert(ValTy.getVectorElementType() == MVT::i1);
2745-
SDValue ValR = DAG.getNode(HexagonISD::P2D, dl, MVT::i64, ValV);
2746-
unsigned VecLen = VecTy.getVectorNumElements();
2747-
unsigned Scale = VecLen / ValTy.getVectorNumElements();
2748-
assert(Scale > 1);
2749-
2750-
for (unsigned R = Scale; R > 1; R /= 2) {
2751-
ValR = contractPredicate(ValR, dl, DAG);
2752-
ValR = getCombine(DAG.getUNDEF(MVT::i32), ValR, dl, MVT::i64, DAG);
2753-
}
2697+
unsigned VecWidth = VecTy.getSizeInBits();
2698+
unsigned ValWidth = ValTy.getSizeInBits();
2699+
assert(VecWidth == VecTy.getVectorNumElements() &&
2700+
"Vector elements should equal vector width size");
2701+
assert(VecWidth == 8 || VecWidth == 4 || VecWidth == 2);
2702+
2703+
// Check if this is an extract of the lowest bit.
2704+
if (auto *IdxN = dyn_cast<ConstantSDNode>(IdxV)) {
2705+
// Extracting the lowest bit is a no-op, but it changes the type,
2706+
// so it must be kept as an operation to avoid errors related to
2707+
// type mismatches.
2708+
if (IdxN->isZero() && ValTy.getSizeInBits() == 1)
2709+
return DAG.getNode(HexagonISD::TYPECAST, dl, MVT::i1, VecV);
2710+
}
2711+
2712+
// If the value extracted is a single bit, use tstbit.
2713+
if (ValWidth == 1) {
2714+
SDValue A0 = getInstr(Hexagon::C2_tfrpr, dl, MVT::i32, {VecV}, DAG);
// The bit to test is Idx * (8 / VecWidth): the element index scaled by the
// number of predicate bits each element occupies.
2715+
SDValue M0 = DAG.getConstant(8 / VecWidth, dl, MVT::i32);
2716+
SDValue I0 = DAG.getNode(ISD::MUL, dl, MVT::i32, IdxV, M0);
2717+
return DAG.getNode(HexagonISD::TSTBIT, dl, MVT::i1, A0, I0);
2718+
}
2719+
2720+
// Each bool vector (v2i1, v4i1, v8i1) always occupies 8 bits in
2721+
// a predicate register. The elements of the vector are repeated
2722+
// in the register (if necessary) so that the total number is 8.
2723+
// The extracted subvector will need to be expanded in such a way.
2724+
unsigned Scale = VecWidth / ValWidth;
2725+
2726+
// Generate (p2d VecV) >> 8*Idx to move the interesting bytes to
2727+
// position 0.
2728+
assert(ty(IdxV) == MVT::i32);
// The shift amount actually built below is Idx * 8 * VecRep bits, i.e.
// VecRep bytes of the P2D value per vector element.
2729+
unsigned VecRep = 8 / VecWidth;
2730+
SDValue S0 = DAG.getNode(ISD::MUL, dl, MVT::i32, IdxV,
2731+
DAG.getConstant(8*VecRep, dl, MVT::i32));
2732+
SDValue T0 = DAG.getNode(HexagonISD::P2D, dl, MVT::i64, VecV);
2733+
SDValue T1 = DAG.getNode(ISD::SRL, dl, MVT::i64, T0, S0);
// One expandPredicate step (32-bit in, 64-bit out) per halving of Scale.
2734+
while (Scale > 1) {
27542735
// The longest possible subvector is at most 32 bits, so it is always
27552736
// contained in the low subregister.
2756-
ValR = DAG.getTargetExtractSubreg(Hexagon::isub_lo, dl, MVT::i32, ValR);
2757-
2758-
unsigned ValBytes = 64 / Scale;
2759-
SDValue Width = DAG.getConstant(ValBytes*8, dl, MVT::i32);
2760-
SDValue Idx = DAG.getNode(ISD::MUL, dl, MVT::i32, IdxV,
2761-
DAG.getConstant(8, dl, MVT::i32));
2762-
SDValue VecR = DAG.getNode(HexagonISD::P2D, dl, MVT::i64, VecV);
2763-
SDValue Ins = DAG.getNode(HexagonISD::INSERT, dl, MVT::i32,
2764-
{VecR, ValR, Width, Idx});
2765-
return DAG.getNode(HexagonISD::D2P, dl, VecTy, Ins);
2737+
T1 = DAG.getTargetExtractSubreg(Hexagon::isub_lo, dl, MVT::i32, T1);
2738+
T1 = expandPredicate(T1, dl, DAG);
2739+
Scale /= 2;
27662740
}
27672741

// Convert the expanded 64-bit value back into a predicate of type ResTy.
2742+
return DAG.getNode(HexagonISD::D2P, dl, ResTy, T1);
2743+
}
2744+
2745+
SDValue
2746+
HexagonTargetLowering::insertVector(SDValue VecV, SDValue ValV, SDValue IdxV,
2747+
const SDLoc &dl, MVT ValTy,
2748+
SelectionDAG &DAG) const {
2749+
MVT VecTy = ty(VecV);
2750+
if (VecTy.getVectorElementType() == MVT::i1)
2751+
return insertVectorPred(VecV, ValV, IdxV, dl, ValTy, DAG);
2752+
27682753
unsigned VecWidth = VecTy.getSizeInBits();
27692754
unsigned ValWidth = ValTy.getSizeInBits();
27702755
assert(VecWidth == 32 || VecWidth == 64);
@@ -2799,13 +2784,53 @@ HexagonTargetLowering::insertVector(SDValue VecV, SDValue ValV, SDValue IdxV,
27992784
return DAG.getNode(ISD::BITCAST, dl, VecTy, InsV);
28002785
}
28012786

2787+
// insertVectorPred: insert ValV (of type ValTy) into the boolean vector VecV
// at element index IdxV and return the updated vector, typed like VecV.
// Two paths:
//  - ValTy == MVT::i1: the insert is done on a 32-bit register copy of the
//    predicate (C2_tfrpr -> INSERT -> C2_tfrrp).
//  - otherwise (a boolean subvector): the insert is done on the 64-bit P2D
//    form of the predicate and converted back with D2P.
SDValue
2788+
HexagonTargetLowering::insertVectorPred(SDValue VecV, SDValue ValV,
2789+
SDValue IdxV, const SDLoc &dl,
2790+
MVT ValTy, SelectionDAG &DAG) const {
2791+
MVT VecTy = ty(VecV);
2792+
unsigned VecLen = VecTy.getVectorNumElements();
2793+
// Scalar case: a single i1 value is being inserted.
2794+
if (ValTy == MVT::i1) {
2795+
SDValue ToReg = getInstr(Hexagon::C2_tfrpr, dl, MVT::i32, {VecV}, DAG);
// ValV is sign-extended (or truncated) to i32 before the bit-field insert.
2796+
SDValue Ext = DAG.getSExtOrTrunc(ValV, dl, MVT::i32);
// Width = 8 / VecLen is the number of bits written per element; the bit
// offset of the destination field is IdxV * Width.
2797+
SDValue Width = DAG.getConstant(8 / VecLen, dl, MVT::i32);
2798+
SDValue Idx = DAG.getNode(ISD::MUL, dl, MVT::i32, IdxV, Width);
2799+
SDValue Ins =
2800+
DAG.getNode(HexagonISD::INSERT, dl, MVT::i32, {ToReg, Ext, Width, Idx});
2801+
return getInstr(Hexagon::C2_tfrrp, dl, VecTy, {Ins}, DAG);
2802+
}
2803+
// Subvector case from here on.
// NOTE(review): with the i1 case handled above and this assert querying the
// vector element type, the non-vector arm of the ternary below looks
// unreachable -- confirm.
2804+
assert(ValTy.getVectorElementType() == MVT::i1);
2805+
SDValue ValR = ValTy.isVector()
2806+
? DAG.getNode(HexagonISD::P2D, dl, MVT::i64, ValV)
2807+
: DAG.getSExtOrTrunc(ValV, dl, MVT::i64);
2808+
// Scale is how many times longer the destination vector is than the
// inserted subvector; it must be at least 2.
2809+
unsigned Scale = VecLen / ValTy.getVectorNumElements();
2810+
assert(Scale > 1);
2811+
// Apply contractPredicate once per halving of Scale. Each step takes the
// 64-bit value to 32 bits, and getCombine re-forms an i64 from it and an
// undef i32 (which half each operand lands in is decided by getCombine,
// not visible here).
2812+
for (unsigned R = Scale; R > 1; R /= 2) {
2813+
ValR = contractPredicate(ValR, dl, DAG);
2814+
ValR = getCombine(DAG.getUNDEF(MVT::i32), ValR, dl, MVT::i64, DAG);
2815+
}
2816+
// Insert a field of 64 / Scale bits at bit offset IdxV * (64 / Scale) into
// the 64-bit form of VecV, then convert the result back to a predicate.
2817+
SDValue Width = DAG.getConstant(64 / Scale, dl, MVT::i32);
2818+
SDValue Idx = DAG.getNode(ISD::MUL, dl, MVT::i32, IdxV, Width);
2819+
SDValue VecR = DAG.getNode(HexagonISD::P2D, dl, MVT::i64, VecV);
2820+
SDValue Ins =
2821+
DAG.getNode(HexagonISD::INSERT, dl, MVT::i64, {VecR, ValR, Width, Idx});
2822+
return DAG.getNode(HexagonISD::D2P, dl, VecTy, Ins);
2823+
}
2824+
28022825
// expandPredicate: widen a 32-bit value to 64 bits by treating it as four
// bytes and sign-extending each byte to a 16-bit element. An undef input
// yields an undef i64. Vec32 must be exactly 32 bits wide (asserted).
// In this diff view the '-' line is the previous form, which emitted the
// S2_vsxtbh instruction directly; the '+' lines build the same shape of
// operation from generic nodes (bitcast, SIGN_EXTEND, bitcast).
SDValue
28032826
HexagonTargetLowering::expandPredicate(SDValue Vec32, const SDLoc &dl,
28042827
SelectionDAG &DAG) const {
28052828
assert(ty(Vec32).getSizeInBits() == 32);
28062829
if (isUndef(Vec32))
28072830
return DAG.getUNDEF(MVT::i64);
2808-
return getInstr(Hexagon::S2_vsxtbh, dl, MVT::i64, {Vec32}, DAG);
// Reinterpret as v4i8, sign-extend to v4i16, reinterpret as i64.
2831+
SDValue P = DAG.getBitcast(MVT::v4i8, Vec32);
2832+
SDValue X = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v4i16, P);
2833+
return DAG.getBitcast(MVT::i64, X);
28092834
}
28102835

28112836
SDValue
@@ -2814,7 +2839,12 @@ HexagonTargetLowering::contractPredicate(SDValue Vec64, const SDLoc &dl,
28142839
assert(ty(Vec64).getSizeInBits() == 64);
28152840
if (isUndef(Vec64))
28162841
return DAG.getUNDEF(MVT::i32);
2817-
return getInstr(Hexagon::S2_vtrunehb, dl, MVT::i32, {Vec64}, DAG);
2842+
// Collect even bytes:
2843+
SDValue A = DAG.getBitcast(MVT::v8i8, Vec64);
2844+
SDValue S = DAG.getVectorShuffle(MVT::v8i8, dl, A, DAG.getUNDEF(MVT::v8i8),
2845+
{0, 2, 4, 6, 1, 3, 5, 7});
2846+
return extractVector(S, DAG.getConstant(0, dl, MVT::i32), dl, MVT::v4i8,
2847+
MVT::i32, DAG);
28182848
}
28192849

28202850
SDValue

llvm/lib/Target/Hexagon/HexagonISelLowering.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,8 +378,12 @@ class HexagonTargetLowering : public TargetLowering {
378378
SelectionDAG &DAG) const;
379379
SDValue extractVector(SDValue VecV, SDValue IdxV, const SDLoc &dl,
380380
MVT ValTy, MVT ResTy, SelectionDAG &DAG) const;
381+
SDValue extractVectorPred(SDValue VecV, SDValue IdxV, const SDLoc &dl,
382+
MVT ValTy, MVT ResTy, SelectionDAG &DAG) const;
381383
SDValue insertVector(SDValue VecV, SDValue ValV, SDValue IdxV,
382384
const SDLoc &dl, MVT ValTy, SelectionDAG &DAG) const;
385+
SDValue insertVectorPred(SDValue VecV, SDValue ValV, SDValue IdxV,
386+
const SDLoc &dl, MVT ValTy, SelectionDAG &DAG) const;
383387
SDValue expandPredicate(SDValue Vec32, const SDLoc &dl,
384388
SelectionDAG &DAG) const;
385389
SDValue contractPredicate(SDValue Vec64, const SDLoc &dl,
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
2+
; RUN: llc -march=hexagon < %s | FileCheck %s
3+
4+
; f0: load a <2 x i1> from %a0, extract the element at the variable index
; %a1 and sign-extend it to i32. The expected code scales the index by 4
; (asl #2) before the tstbit.
define i32 @f0(ptr %a0, i32 %a1) #0 {
5+
; CHECK-LABEL: f0:
6+
; CHECK: // %bb.0:
7+
; CHECK-NEXT: {
8+
; CHECK-NEXT: r0 = memub(r0+#0)
9+
; CHECK-NEXT: }
10+
; CHECK-NEXT: {
11+
; CHECK-NEXT: r1 = asl(r1,#2)
12+
; CHECK-NEXT: }
13+
; CHECK-NEXT: {
14+
; CHECK-NEXT: p0 = tstbit(r0,r1)
15+
; CHECK-NEXT: }
16+
; CHECK-NEXT: {
17+
; CHECK-NEXT: r0 = mux(p0,#-1,#0)
18+
; CHECK-NEXT: }
19+
; CHECK-NEXT: {
20+
; CHECK-NEXT: jumpr r31
21+
; CHECK-NEXT: }
22+
%v0 = load <2 x i1>, ptr %a0
23+
%v1 = extractelement <2 x i1> %v0, i32 %a1
24+
%v2 = sext i1 %v1 to i32
25+
ret i32 %v2
26+
}
27+
28+
; f1: load a <4 x i1> from %a0, extract the element at the variable index
; %a1 and sign-extend it to i32. The expected code scales the index by 2
; (asl #1) before the tstbit.
define i32 @f1(ptr %a0, i32 %a1) #0 {
29+
; CHECK-LABEL: f1:
30+
; CHECK: // %bb.0:
31+
; CHECK-NEXT: {
32+
; CHECK-NEXT: r0 = memub(r0+#0)
33+
; CHECK-NEXT: }
34+
; CHECK-NEXT: {
35+
; CHECK-NEXT: r1 = asl(r1,#1)
36+
; CHECK-NEXT: }
37+
; CHECK-NEXT: {
38+
; CHECK-NEXT: p0 = tstbit(r0,r1)
39+
; CHECK-NEXT: }
40+
; CHECK-NEXT: {
41+
; CHECK-NEXT: r0 = mux(p0,#-1,#0)
42+
; CHECK-NEXT: }
43+
; CHECK-NEXT: {
44+
; CHECK-NEXT: jumpr r31
45+
; CHECK-NEXT: }
46+
%v0 = load <4 x i1>, ptr %a0
47+
%v1 = extractelement <4 x i1> %v0, i32 %a1
48+
%v2 = sext i1 %v1 to i32
49+
ret i32 %v2
50+
}
51+
52+
; f2: load an <8 x i1> from %a0, extract the element at the variable index
; %a1 and sign-extend it to i32. The expected code passes the index to
; tstbit unscaled.
define i32 @f2(ptr %a0, i32 %a1) #0 {
53+
; CHECK-LABEL: f2:
54+
; CHECK: // %bb.0:
55+
; CHECK-NEXT: {
56+
; CHECK-NEXT: r0 = memub(r0+#0)
57+
; CHECK-NEXT: }
58+
; CHECK-NEXT: {
59+
; CHECK-NEXT: p0 = tstbit(r0,r1)
60+
; CHECK-NEXT: }
61+
; CHECK-NEXT: {
62+
; CHECK-NEXT: r0 = mux(p0,#-1,#0)
63+
; CHECK-NEXT: }
64+
; CHECK-NEXT: {
65+
; CHECK-NEXT: jumpr r31
66+
; CHECK-NEXT: }
67+
%v0 = load <8 x i1>, ptr %a0
68+
%v1 = extractelement <8 x i1> %v0, i32 %a1
69+
%v2 = sext i1 %v1 to i32
70+
ret i32 %v2
71+
}
72+
73+
attributes #0 = { nounwind "target-features"="-packets" }

0 commit comments

Comments
 (0)