Skip to content

Commit a0873ee

Browse files
committed
Responding to review comments
- making sure transform only operates on smin nodes
- adding extra tests dealing with interesting edge cases

Change-Id: Ia1114ec9b93c4de3552b867e0d745beccdae69f1
1 parent 7e977ac commit a0873ee

File tree

2 files changed

+97
-18
lines changed

2 files changed

+97
-18
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,6 +1143,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
11431143
ISD::SIGN_EXTEND_INREG, ISD::CONCAT_VECTORS,
11441144
ISD::EXTRACT_SUBVECTOR, ISD::INSERT_SUBVECTOR,
11451145
ISD::STORE, ISD::BUILD_VECTOR});
1146+
setTargetDAGCombine(ISD::SMIN);
11461147
setTargetDAGCombine(ISD::TRUNCATE);
11471148
setTargetDAGCombine(ISD::LOAD);
11481149

@@ -20998,23 +20999,18 @@ static SDValue performBuildVectorCombine(SDNode *N,
2099820999

2099921000
// A special combine for the sqdmulh family of instructions.
2100021001
// smin( sra( mul( sext v0, sext v1 ), SHIFT_AMOUNT ),
21001-
// SATURATING_VAL ) can be reduced to sext(sqdmulh(...))
21002+
// SATURATING_VAL ) can be reduced to sqdmulh(...)
2100221003
static SDValue trySQDMULHCombine(SDNode *N, SelectionDAG &DAG) {
2100321004

21004-
if (N->getOpcode() != ISD::TRUNCATE)
21005+
if (N->getOpcode() != ISD::SMIN)
2100521006
return SDValue();
2100621007

2100721008
EVT VT = N->getValueType(0);
2100821009

2100921010
if (!VT.isVector() || VT.getScalarSizeInBits() > 64)
2101021011
return SDValue();
2101121012

21012-
SDValue SMin = N->getOperand(0);
21013-
21014-
if (SMin.getOpcode() != ISD::SMIN)
21015-
return SDValue();
21016-
21017-
ConstantSDNode *Clamp = isConstOrConstSplat(SMin.getOperand(1));
21013+
ConstantSDNode *Clamp = isConstOrConstSplat(N->getOperand(1));
2101821014

2101921015
if (!Clamp)
2102021016
return SDValue();
@@ -21034,8 +21030,8 @@ static SDValue trySQDMULHCombine(SDNode *N, SelectionDAG &DAG) {
2103421030
return SDValue();
2103521031
}
2103621032

21037-
SDValue Sra = SMin.getOperand(0);
21038-
if (Sra.getOpcode() != ISD::SRA)
21033+
SDValue Sra = N->getOperand(0);
21034+
if (Sra.getOpcode() != ISD::SRA || !Sra.hasOneUse())
2103921035
return SDValue();
2104021036

2104121037
ConstantSDNode *RightShiftVec = isConstOrConstSplat(Sra.getOperand(1));
@@ -21062,11 +21058,27 @@ static SDValue trySQDMULHCombine(SDNode *N, SelectionDAG &DAG) {
2106221058
SExt0Type.getFixedSizeInBits() > 128)
2106321059
return SDValue();
2106421060

21065-
SDValue V0 = SExt0.getOperand(0);
21066-
SDValue V1 = SExt1.getOperand(0);
21061+
// Source vectors with width < 64 are illegal and will need to be extended
21062+
unsigned SourceVectorWidth = SExt0Type.getFixedSizeInBits();
21063+
SDValue V0 = (SourceVectorWidth < 64) ? SExt0 : SExt0.getOperand(0);
21064+
SDValue V1 = (SourceVectorWidth < 64) ? SExt1 : SExt1.getOperand(0);
21065+
21066+
SDLoc DL(N);
21067+
SDValue SQDMULH =
21068+
DAG.getNode(AArch64ISD::SQDMULH, DL, V0.getValueType(), V0, V1);
21069+
EVT DestVT = N->getValueType(0);
21070+
if (DestVT.getScalarSizeInBits() > SExt0Type.getScalarSizeInBits())
21071+
return DAG.getNode(ISD::SIGN_EXTEND, DL, DestVT, SQDMULH);
21072+
21073+
return SQDMULH;
21074+
}
21075+
21076+
static SDValue performSMINCombine(SDNode *N, SelectionDAG &DAG) {
21077+
if (SDValue V = trySQDMULHCombine(N, DAG)) {
21078+
return V;
21079+
}
2106721080

21068-
SDLoc DL(SMin);
21069-
return DAG.getNode(AArch64ISD::SQDMULH, DL, SExt0Type, V0, V1);
21081+
return SDValue();
2107021082
}
2107121083

2107221084
static SDValue performTruncateCombine(SDNode *N, SelectionDAG &DAG,
@@ -21083,10 +21095,6 @@ static SDValue performTruncateCombine(SDNode *N, SelectionDAG &DAG,
2108321095
return DAG.getNode(N0.getOpcode(), DL, VT, Op);
2108421096
}
2108521097

21086-
if (SDValue V = trySQDMULHCombine(N, DAG)) {
21087-
return DAG.getNode(ISD::TRUNCATE, DL, VT, V);
21088-
}
21089-
2109021098
// Performing the following combine produces a preferable form for ISEL.
2109121099
// i32 (trunc (extract Vi64, idx)) -> i32 (extract (nvcast Vi32), idx*2))
2109221100
if (DCI.isAfterLegalizeDAG() && N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
@@ -26824,6 +26832,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
2682426832
return performAddSubCombine(N, DCI);
2682526833
case ISD::BUILD_VECTOR:
2682626834
return performBuildVectorCombine(N, DCI, DAG);
26835+
case ISD::SMIN:
26836+
return performSMINCombine(N, DAG);
2682726837
case ISD::TRUNCATE:
2682826838
return performTruncateCombine(N, DAG, DCI);
2682926839
case AArch64ISD::ANDS:

llvm/test/CodeGen/AArch64/saturating-vec-smull.ll

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,25 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
22
; RUN: llc -mtriple=aarch64-none-elf < %s | FileCheck %s
33

4+
5+
define <2 x i16> @saturating_2xi16(<2 x i16> %a, <2 x i16> %b) {
6+
; CHECK-LABEL: saturating_2xi16:
7+
; CHECK: // %bb.0:
8+
; CHECK-NEXT: shl v0.2s, v0.2s, #16
9+
; CHECK-NEXT: shl v1.2s, v1.2s, #16
10+
; CHECK-NEXT: sshr v0.2s, v0.2s, #16
11+
; CHECK-NEXT: sshr v1.2s, v1.2s, #16
12+
; CHECK-NEXT: sqdmulh v0.2s, v1.2s, v0.2s
13+
; CHECK-NEXT: ret
14+
%as = sext <2 x i16> %a to <2 x i32>
15+
%bs = sext <2 x i16> %b to <2 x i32>
16+
%m = mul <2 x i32> %bs, %as
17+
%sh = ashr <2 x i32> %m, splat (i32 15)
18+
%ma = tail call <2 x i32> @llvm.smin.v4i32(<2 x i32> %sh, <2 x i32> splat (i32 32767))
19+
%t = trunc <2 x i32> %ma to <2 x i16>
20+
ret <2 x i16> %t
21+
}
22+
423
define <4 x i16> @saturating_4xi16(<4 x i16> %a, <4 x i16> %b) {
524
; CHECK-LABEL: saturating_4xi16:
625
; CHECK: // %bb.0:
@@ -71,3 +90,53 @@ define <8 x i32> @saturating_8xi32(<8 x i32> %a, <8 x i32> %b) {
7190
%t = trunc <8 x i64> %ma to <8 x i32>
7291
ret <8 x i32> %t
7392
}
93+
94+
define <2 x i64> @saturating_2xi32_2xi64(<2 x i32> %a, <2 x i32> %b) {
95+
; CHECK-LABEL: saturating_2xi32_2xi64:
96+
; CHECK: // %bb.0:
97+
; CHECK-NEXT: sqdmulh v0.2s, v1.2s, v0.2s
98+
; CHECK-NEXT: sshll v0.2d, v0.2s, #0
99+
; CHECK-NEXT: ret
100+
%as = sext <2 x i32> %a to <2 x i64>
101+
%bs = sext <2 x i32> %b to <2 x i64>
102+
%m = mul <2 x i64> %bs, %as
103+
%sh = ashr <2 x i64> %m, splat (i64 31)
104+
%ma = tail call <2 x i64> @llvm.smin.v8i64(<2 x i64> %sh, <2 x i64> splat (i64 2147483647))
105+
ret <2 x i64> %ma
106+
}
107+
108+
define <4 x i16> @unsupported_saturation_value_v4i16(<4 x i16> %a, <4 x i16> %b) {
109+
; CHECK-LABEL: unsupported_saturation_value_v4i16:
110+
; CHECK: // %bb.0:
111+
; CHECK-NEXT: smull v0.4s, v1.4h, v0.4h
112+
; CHECK-NEXT: movi v1.4s, #42
113+
; CHECK-NEXT: sshr v0.4s, v0.4s, #15
114+
; CHECK-NEXT: smin v0.4s, v0.4s, v1.4s
115+
; CHECK-NEXT: xtn v0.4h, v0.4s
116+
; CHECK-NEXT: ret
117+
%as = sext <4 x i16> %a to <4 x i32>
118+
%bs = sext <4 x i16> %b to <4 x i32>
119+
%m = mul <4 x i32> %bs, %as
120+
%sh = ashr <4 x i32> %m, splat (i32 15)
121+
%ma = tail call <4 x i32> @llvm.smin.v4i32(<4 x i32> %sh, <4 x i32> splat (i32 42))
122+
%t = trunc <4 x i32> %ma to <4 x i16>
123+
ret <4 x i16> %t
124+
}
125+
126+
define <4 x i16> @unsupported_shift_value_v4i16(<4 x i16> %a, <4 x i16> %b) {
127+
; CHECK-LABEL: unsupported_shift_value_v4i16:
128+
; CHECK: // %bb.0:
129+
; CHECK-NEXT: smull v0.4s, v1.4h, v0.4h
130+
; CHECK-NEXT: movi v1.4s, #127, msl #8
131+
; CHECK-NEXT: sshr v0.4s, v0.4s, #3
132+
; CHECK-NEXT: smin v0.4s, v0.4s, v1.4s
133+
; CHECK-NEXT: xtn v0.4h, v0.4s
134+
; CHECK-NEXT: ret
135+
%as = sext <4 x i16> %a to <4 x i32>
136+
%bs = sext <4 x i16> %b to <4 x i32>
137+
%m = mul <4 x i32> %bs, %as
138+
%sh = ashr <4 x i32> %m, splat (i32 3)
139+
%ma = tail call <4 x i32> @llvm.smin.v4i32(<4 x i32> %sh, <4 x i32> splat (i32 32767))
140+
%t = trunc <4 x i32> %ma to <4 x i16>
141+
ret <4 x i16> %t
142+
}

0 commit comments

Comments
 (0)