Skip to content

Commit 1bb0ad6

Browse files
committed
Response to review comments addressing:
- check for scalar type
- check for sign extends
- legalise vector inputs to sqdmulh
- always return sext(sqdmulh)

Change-Id: Ic58b7f267e94bc2592942fc29b829ffb6221770f
1 parent a0873ee commit 1bb0ad6

File tree

2 files changed

+61
-13
lines changed

2 files changed

+61
-13
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21005,9 +21005,9 @@ static SDValue trySQDMULHCombine(SDNode *N, SelectionDAG &DAG) {
2100521005
if (N->getOpcode() != ISD::SMIN)
2100621006
return SDValue();
2100721007

21008-
EVT VT = N->getValueType(0);
21008+
EVT DestVT = N->getValueType(0);
2100921009

21010-
if (!VT.isVector() || VT.getScalarSizeInBits() > 64)
21010+
if (!DestVT.isVector() || DestVT.getScalarSizeInBits() > 64 || DestVT.isScalableVector())
2101121011
return SDValue();
2101221012

2101321013
ConstantSDNode *Clamp = isConstOrConstSplat(N->getOperand(1));
@@ -21049,28 +21049,36 @@ static SDValue trySQDMULHCombine(SDNode *N, SelectionDAG &DAG) {
2104921049
SDValue SExt0 = Mul.getOperand(0);
2105021050
SDValue SExt1 = Mul.getOperand(1);
2105121051

21052+
if (SExt0.getOpcode() != ISD::SIGN_EXTEND ||
21053+
SExt1.getOpcode() != ISD::SIGN_EXTEND)
21054+
return SDValue();
21055+
2105221056
EVT SExt0Type = SExt0.getOperand(0).getValueType();
2105321057
EVT SExt1Type = SExt1.getOperand(0).getValueType();
2105421058

21055-
if (SExt0.getOpcode() != ISD::SIGN_EXTEND ||
21056-
SExt1.getOpcode() != ISD::SIGN_EXTEND || SExt0Type != SExt1Type ||
21059+
if (SExt0Type != SExt1Type ||
2105721060
SExt0Type.getScalarType() != ScalarType ||
2105821061
SExt0Type.getFixedSizeInBits() > 128)
2105921062
return SDValue();
2106021063

21061-
// Source vectors with width < 64 are illegal and will need to be extended
21062-
unsigned SourceVectorWidth = SExt0Type.getFixedSizeInBits();
21063-
SDValue V0 = (SourceVectorWidth < 64) ? SExt0 : SExt0.getOperand(0);
21064-
SDValue V1 = (SourceVectorWidth < 64) ? SExt1 : SExt1.getOperand(0);
21065-
2106621064
SDLoc DL(N);
21065+
SDValue V0 = SExt0.getOperand(0);
21066+
SDValue V1 = SExt1.getOperand(0);
21067+
21068+
// Ensure input vectors are extended to legal types
21069+
if (SExt0Type.getFixedSizeInBits() < 64) {
21070+
unsigned VecNumElements = SExt0Type.getVectorNumElements();
21071+
EVT ExtVecVT =
21072+
MVT::getVectorVT(MVT::getIntegerVT(64 / VecNumElements),
21073+
VecNumElements);
21074+
V0 = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVecVT, V0);
21075+
V1 = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVecVT, V1);
21076+
}
21077+
2106721078
SDValue SQDMULH =
2106821079
DAG.getNode(AArch64ISD::SQDMULH, DL, V0.getValueType(), V0, V1);
21069-
EVT DestVT = N->getValueType(0);
21070-
if (DestVT.getScalarSizeInBits() > SExt0Type.getScalarSizeInBits())
21071-
return DAG.getNode(ISD::SIGN_EXTEND, DL, DestVT, SQDMULH);
2107221080

21073-
return SQDMULH;
21081+
return DAG.getNode(ISD::SIGN_EXTEND, DL, DestVT, SQDMULH);
2107421082
}
2107521083

2107621084
static SDValue performSMINCombine(SDNode *N, SelectionDAG &DAG) {

llvm/test/CodeGen/AArch64/saturating-vec-smull.ll

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,3 +140,43 @@ define <4 x i16> @unsupported_shift_value_v4i16(<4 x i16> %a, <4 x i16> %b) {
140140
%t = trunc <4 x i32> %ma to <4 x i16>
141141
ret <4 x i16> %t
142142
}
143+
144+
define <2 x i16> @extend_to_illegal_type(<2 x i16> %a, <2 x i16> %b) {
145+
; CHECK-LABEL: extend_to_illegal_type:
146+
; CHECK: // %bb.0:
147+
; CHECK-NEXT: shl v0.2s, v0.2s, #16
148+
; CHECK-NEXT: shl v1.2s, v1.2s, #16
149+
; CHECK-NEXT: sshr v0.2s, v0.2s, #16
150+
; CHECK-NEXT: sshr v1.2s, v1.2s, #16
151+
; CHECK-NEXT: sqdmulh v0.2s, v1.2s, v0.2s
152+
; CHECK-NEXT: ret
153+
%as = sext <2 x i16> %a to <2 x i48>
154+
%bs = sext <2 x i16> %b to <2 x i48>
155+
%m = mul <2 x i48> %bs, %as
156+
%sh = ashr <2 x i48> %m, splat (i48 15)
157+
%ma = tail call <2 x i48> @llvm.smin.v4i32(<2 x i48> %sh, <2 x i48> splat (i48 32767))
158+
%t = trunc <2 x i48> %ma to <2 x i16>
159+
ret <2 x i16> %t
160+
}
161+
162+
define <2 x i11> @illegal_source(<2 x i11> %a, <2 x i11> %b) {
163+
; CHECK-LABEL: illegal_source:
164+
; CHECK: // %bb.0:
165+
; CHECK-NEXT: shl v0.2s, v0.2s, #21
166+
; CHECK-NEXT: shl v1.2s, v1.2s, #21
167+
; CHECK-NEXT: sshr v0.2s, v0.2s, #21
168+
; CHECK-NEXT: sshr v1.2s, v1.2s, #21
169+
; CHECK-NEXT: mul v0.2s, v1.2s, v0.2s
170+
; CHECK-NEXT: movi v1.2s, #127, msl #8
171+
; CHECK-NEXT: sshr v0.2s, v0.2s, #15
172+
; CHECK-NEXT: smin v0.2s, v0.2s, v1.2s
173+
; CHECK-NEXT: ret
174+
%as = sext <2 x i11> %a to <2 x i32>
175+
%bs = sext <2 x i11> %b to <2 x i32>
176+
%m = mul <2 x i32> %bs, %as
177+
%sh = ashr <2 x i32> %m, splat (i32 15)
178+
%ma = tail call <2 x i32> @llvm.smin.v4i32(<2 x i32> %sh, <2 x i32> splat (i32 32767))
179+
%t = trunc <2 x i32> %ma to <2 x i11>
180+
ret <2 x i11> %t
181+
}
182+

0 commit comments

Comments
 (0)