Skip to content

Commit 70fd5a4

Browse files
committed
[NVPTX] Use cvt.sat to lower min/max clamping to i8 and i16 ranges
1 parent 8d86963 commit 70fd5a4

File tree

4 files changed

+278
-1
lines changed

4 files changed

+278
-1
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "llvm/CodeGen/MachineFunction.h"
2929
#include "llvm/CodeGen/MachineJumpTableInfo.h"
3030
#include "llvm/CodeGen/MachineMemOperand.h"
31+
#include "llvm/CodeGen/SDPatternMatch.h"
3132
#include "llvm/CodeGen/SelectionDAG.h"
3233
#include "llvm/CodeGen/SelectionDAGNodes.h"
3334
#include "llvm/CodeGen/TargetCallingConv.h"
@@ -74,6 +75,7 @@
7475
#define DEBUG_TYPE "nvptx-lower"
7576

7677
using namespace llvm;
78+
using namespace llvm::SDPatternMatch;
7779

7880
static cl::opt<bool> sched4reg(
7981
"nvptx-sched4reg",
@@ -659,6 +661,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
659661
setOperationAction(ISD::BR_CC, VT, Expand);
660662
}
661663

664+
setOperationAction({ISD::TRUNCATE_SSAT_S, ISD::TRUNCATE_SSAT_U}, MVT::i16,
665+
Legal);
666+
setOperationAction({ISD::TRUNCATE_SSAT_S, ISD::TRUNCATE_SSAT_U}, MVT::i8,
667+
Custom);
668+
662669
// Some SIGN_EXTEND_INREG can be done using cvt instruction.
663670
// For others we will expand to a SHL/SRA pair.
664671
setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Legal);
@@ -836,7 +843,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
836843
// We have some custom DAG combine patterns for these nodes
837844
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
838845
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
839-
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
846+
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::SMIN,
847+
ISD::SMAX});
840848

841849
// setcc for f16x2 and bf16x2 needs special handling to prevent
842850
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -1081,6 +1089,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
10811089
MAKE_CASE(NVPTXISD::PseudoUseParam)
10821090
MAKE_CASE(NVPTXISD::UNPACK_VECTOR)
10831091
MAKE_CASE(NVPTXISD::BUILD_VECTOR)
1092+
MAKE_CASE(NVPTXISD::TRUNCATE_SSAT_U_I8)
1093+
MAKE_CASE(NVPTXISD::TRUNCATE_SSAT_S_I8)
10841094
MAKE_CASE(NVPTXISD::RETURN)
10851095
MAKE_CASE(NVPTXISD::CallSeqBegin)
10861096
MAKE_CASE(NVPTXISD::CallSeqEnd)
@@ -5667,6 +5677,49 @@ static SDValue combineADDRSPACECAST(SDNode *N,
56675677
return SDValue();
56685678
}
56695679

5680+
/// Fold a signed clamp to the i8 or i16 value range into a saturating
/// truncate, so that it can be selected as a single cvt.sat instruction.
///
/// Matches smin(smax(Val, Floor), Ceil) and smax(smin(Val, Ceil), Floor) on
/// i16, i32 and i64 where [Floor, Ceil] is exactly the signed or the unsigned
/// range of i16 (i32/i64 inputs only) or of i8. On success the clamp is
/// replaced by TRUNCATE_SSAT_S / TRUNCATE_SSAT_U to the narrow type followed by
/// a sign / zero extension back to the original type.
///
/// \param N   The SMIN or SMAX node being combined.
/// \param DCI Combiner state; supplies the DAG and the current combine stage.
/// \returns The replacement value, or an empty SDValue if nothing matched.
static SDValue combineMINMAX(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
  const EVT VT = N->getValueType(0);
  if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
    return SDValue();

  // Accept the clamp in either nesting order.
  SDValue Val;
  APInt Ceil, Floor;
  if (!sd_match(N, m_SMin(m_SMax(m_Value(Val), m_ConstInt(Floor)),
                          m_ConstInt(Ceil))) &&
      !sd_match(N, m_SMax(m_SMin(m_Value(Val), m_ConstInt(Ceil)),
                          m_ConstInt(Floor))))
    return SDValue();

  const unsigned BitWidth = VT.getSizeInBits();
  SDLoc DL(N);

  // Builds the saturating truncate to DestVT when [Floor, Ceil] is exactly the
  // signed or unsigned range of DestVT; returns an empty SDValue otherwise.
  auto MatchTruncSat = [&](MVT DestVT) -> SDValue {
    const unsigned DestBitWidth = DestVT.getSizeInBits();
    const bool IsSignedRange =
        Ceil == APInt::getSignedMaxValue(DestBitWidth).sext(BitWidth) &&
        Floor == APInt::getSignedMinValue(DestBitWidth).sext(BitWidth);
    const bool IsUnsignedRange =
        Ceil == APInt::getMaxValue(DestBitWidth).zext(BitWidth) &&
        Floor == APInt::getMinValue(BitWidth);
    if (!IsSignedRange && !IsUnsignedRange)
      return SDValue();

    const unsigned Opcode =
        IsSignedRange ? ISD::TRUNCATE_SSAT_S : ISD::TRUNCATE_SSAT_U;
    SDValue Trunc = DCI.DAG.getNode(Opcode, DL, DestVT, Val);
    // Widen back with the extension that matches the saturation range.
    return DCI.DAG.getExtOrTrunc(IsSignedRange, Trunc, DL, VT);
  };

  // The i16 form is marked Legal, so it may be created at any combine stage.
  // An i16 input cannot be narrowed to i16, hence the VT check.
  if (VT != MVT::i16)
    if (SDValue Res = MatchTruncSat(MVT::i16))
      return Res;

  // i8 is not a legal NVPTX type: the i8 form is only rewritten into a legal
  // node by the type legalizer (through ReplaceNodeResults). Creating it once
  // type legalization has already run would leave an illegal-typed node in the
  // DAG, so restrict this fold to the first combine stage.
  if (DCI.isBeforeLegalize())
    if (SDValue Res = MatchTruncSat(MVT::i8))
      return Res;

  return SDValue();
}
5722+
56705723
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
56715724
DAGCombinerInfo &DCI) const {
56725725
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5685,6 +5738,9 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
56855738
case ISD::UREM:
56865739
case ISD::SREM:
56875740
return PerformREMCombine(N, DCI, OptLevel);
5741+
case ISD::SMIN:
5742+
case ISD::SMAX:
5743+
return combineMINMAX(N, DCI);
56885744
case ISD::SETCC:
56895745
return PerformSETCCCombine(N, DCI, STI.getSmVersion());
56905746
case NVPTXISD::StoreRetval:
@@ -6045,6 +6101,20 @@ static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
60456101
Results.push_back(NewValue.getValue(3));
60466102
}
60476103

6104+
/// Replace the i8 result of a TRUNCATE_SSAT_S / TRUNCATE_SSAT_U node.
///
/// The saturation is performed by the matching NVPTXISD::TRUNCATE_SSAT_*_I8
/// node, which produces its result in an i16. An AssertSext / AssertZext
/// records that only the low 8 bits of that i16 are significant, and a plain
/// TRUNCATE then yields the i8 value that is pushed onto \p Results.
///
/// \param N       The TRUNCATE_SSAT_S or TRUNCATE_SSAT_U node to replace.
/// \param DAG     The DAG in which the replacement nodes are created.
/// \param Results Receives the single i8 replacement value.
static void replaceTruncateSSat(SDNode *N, SelectionDAG &DAG,
                                SmallVectorImpl<SDValue> &Results) {
  SDLoc DL(N);

  // Default to the unsigned flavour and switch both opcodes for the signed one.
  unsigned SatOpcode = NVPTXISD::TRUNCATE_SSAT_U_I8;
  unsigned AssertOpcode = ISD::AssertZext;
  if (N->getOpcode() == ISD::TRUNCATE_SSAT_S) {
    SatOpcode = NVPTXISD::TRUNCATE_SSAT_S_I8;
    AssertOpcode = ISD::AssertSext;
  }

  SDValue Saturated = DAG.getNode(SatOpcode, DL, MVT::i16, N->getOperand(0));
  SDValue Asserted = DAG.getNode(AssertOpcode, DL, MVT::i16, Saturated,
                                 DAG.getValueType(MVT::i8));
  SDValue Narrowed = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, Asserted);
  Results.push_back(Narrowed);
}
6117+
60486118
void NVPTXTargetLowering::ReplaceNodeResults(
60496119
SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
60506120
switch (N->getOpcode()) {
@@ -6062,6 +6132,10 @@ void NVPTXTargetLowering::ReplaceNodeResults(
60626132
case ISD::CopyFromReg:
60636133
ReplaceCopyFromReg_128(N, DAG, Results);
60646134
return;
6135+
case ISD::TRUNCATE_SSAT_U:
6136+
case ISD::TRUNCATE_SSAT_S:
6137+
replaceTruncateSSat(N, DAG, Results);
6138+
return;
60656139
}
60666140
}
60676141

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,13 @@ enum NodeType : unsigned {
7272
/// converting it to a vector.
7373
UNPACK_VECTOR,
7474

75+
/// These nodes are equivalent to the corresponding ISD nodes except that
76+
/// they truncate to an i8 output and then sign or zero extend that value back
77+
/// to i16. This is a workaround for the fact that NVPTX does not consider
78+
/// i8 to be a legal type. TODO: consider making i8 legal and removing these.
79+
TRUNCATE_SSAT_U_I8,
80+
TRUNCATE_SSAT_S_I8,
81+
7582
FCOPYSIGN,
7683
DYNAMIC_STACKALLOC,
7784
STACKRESTORE,

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2649,6 +2649,25 @@ def : Pat<(i1 (trunc i32:$a)), (SETP_b32ri (ANDb32ri $a, 1), 0, CmpNE)>;
26492649
// truncate i16
26502650
def : Pat<(i1 (trunc i16:$a)), (SETP_b16ri (ANDb16ri $a, 1), 0, CmpNE)>;
26512651

2652+
// truncate ssat
//
// truncssat_s / truncssat_u (and their NVPTX i8 variants) come from a signed
// smin/smax clamp: the source is always interpreted as signed and only the
// destination range differs. Every pattern therefore selects a cvt.sat with a
// signed source type. With an unsigned source type a negative input would
// saturate to the maximum of the destination instead of to 0.
def SDTTruncSatI8Op : SDTypeProfile<1, 1, [SDTCisInt<1>, SDTCisVT<0, i16>]>;
def truncssat_s_i8 : SDNode<"NVPTXISD::TRUNCATE_SSAT_S_I8", SDTTruncSatI8Op>;
def truncssat_u_i8 : SDNode<"NVPTXISD::TRUNCATE_SSAT_U_I8", SDTTruncSatI8Op>;

// Saturate to the signed i16 range.
def : Pat<(i16 (truncssat_s i32:$a)), (CVT_s16_s32 $a, CvtSAT)>;
def : Pat<(i16 (truncssat_s i64:$a)), (CVT_s16_s64 $a, CvtSAT)>;

// Saturate to the unsigned i16 range.
def : Pat<(i16 (truncssat_u i32:$a)), (CVT_u16_s32 $a, CvtSAT)>;
def : Pat<(i16 (truncssat_u i64:$a)), (CVT_u16_s64 $a, CvtSAT)>;

// Saturate to the signed i8 range; the result is held in an i16.
def : Pat<(truncssat_s_i8 i16:$a), (CVT_s8_s16 $a, CvtSAT)>;
def : Pat<(truncssat_s_i8 i32:$a), (CVT_s8_s32 $a, CvtSAT)>;
def : Pat<(truncssat_s_i8 i64:$a), (CVT_s8_s64 $a, CvtSAT)>;

// Saturate to the unsigned i8 range; the result is held in an i16.
def : Pat<(truncssat_u_i8 i16:$a), (CVT_u8_s16 $a, CvtSAT)>;
def : Pat<(truncssat_u_i8 i32:$a), (CVT_u8_s32 $a, CvtSAT)>;
def : Pat<(truncssat_u_i8 i64:$a), (CVT_u8_s64 $a, CvtSAT)>;
2670+
26522671
// sext_inreg
26532672
def : Pat<(sext_inreg i16:$a, i8), (CVT_INREG_s16_s8 $a)>;
26542673
def : Pat<(sext_inreg i32:$a, i8), (CVT_INREG_s32_s8 $a)>;
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 -verify-machineinstrs | FileCheck %s
3+
; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_20 -verify-machineinstrs | %ptxas-verify %}
4+
5+
target triple = "nvptx-unknown-cuda"
6+
7+
8+
; Signed clamp of an i64 to the unsigned i16 range [0, 65535]; expected to
; select one saturating cvt followed by a zero-extension back to 64 bits.
define i64 @trunc_ssat_i64_u16(i64 %a) {
; CHECK-LABEL: trunc_ssat_i64_u16(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<2>;
; CHECK-NEXT: .reg .b64 %rd<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b64 %rd1, [trunc_ssat_i64_u16_param_0];
; CHECK-NEXT: cvt.sat.u16.s64 %rs1, %rd1;
; CHECK-NEXT: cvt.u64.u16 %rd2, %rs1;
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
; CHECK-NEXT: ret;
  %lower = call i64 @llvm.smax.i64(i64 %a, i64 0)
  %clamped = call i64 @llvm.smin.i64(i64 %lower, i64 65535)
  ret i64 %clamped
}
24+
25+
; Signed clamp of an i32 to the unsigned i16 range [0, 65535]; expected to
; select one saturating cvt followed by a zero-extension back to 32 bits.
define i32 @trunc_ssat_i32_u16(i32 %a) {
; CHECK-LABEL: trunc_ssat_i32_u16(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<2>;
; CHECK-NEXT: .reg .b32 %r<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [trunc_ssat_i32_u16_param_0];
; CHECK-NEXT: cvt.sat.u16.s32 %rs1, %r1;
; CHECK-NEXT: cvt.u32.u16 %r2, %rs1;
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
; CHECK-NEXT: ret;
  %lower = call i32 @llvm.smax.i32(i32 %a, i32 0)
  %clamped = call i32 @llvm.smin.i32(i32 %lower, i32 65535)
  ret i32 %clamped
}
41+
42+
; Signed clamp of an i64 to the signed i16 range [-32768, 32767]; expected to
; select one saturating cvt followed by a sign-extension back to 64 bits.
define i64 @trunc_ssat_i64_s16(i64 %a) {
; CHECK-LABEL: trunc_ssat_i64_s16(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<2>;
; CHECK-NEXT: .reg .b64 %rd<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b64 %rd1, [trunc_ssat_i64_s16_param_0];
; CHECK-NEXT: cvt.sat.s16.s64 %rs1, %rd1;
; CHECK-NEXT: cvt.s64.s16 %rd2, %rs1;
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
; CHECK-NEXT: ret;
  %lower = call i64 @llvm.smax.i64(i64 %a, i64 -32768)
  %clamped = call i64 @llvm.smin.i64(i64 %lower, i64 32767)
  ret i64 %clamped
}
58+
59+
; Signed clamp of an i32 to the signed i16 range [-32768, 32767]; expected to
; select one saturating cvt followed by a sign-extension back to 32 bits.
define i32 @trunc_ssat_i32_s16(i32 %a) {
; CHECK-LABEL: trunc_ssat_i32_s16(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<2>;
; CHECK-NEXT: .reg .b32 %r<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [trunc_ssat_i32_s16_param_0];
; CHECK-NEXT: cvt.sat.s16.s32 %rs1, %r1;
; CHECK-NEXT: cvt.s32.s16 %r2, %rs1;
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
; CHECK-NEXT: ret;
  %lower = call i32 @llvm.smax.i32(i32 %a, i32 -32768)
  %clamped = call i32 @llvm.smin.i32(i32 %lower, i32 32767)
  ret i32 %clamped
}
75+
76+
; Signed clamp of an i64 to the unsigned i8 range [0, 255].
; smax(%a, 0) sends every negative input to 0, so the saturating conversion
; has to read its source as signed (.s64). With an unsigned source type a
; negative input would be seen as a huge positive value and saturate to 255.
define i64 @trunc_ssat_i64_u8(i64 %a) {
; CHECK-LABEL: trunc_ssat_i64_u8(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<2>;
; CHECK-NEXT: .reg .b64 %rd<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b64 %rd1, [trunc_ssat_i64_u8_param_0];
; CHECK-NEXT: cvt.sat.u8.s64 %rs1, %rd1;
; CHECK-NEXT: cvt.u64.u16 %rd2, %rs1;
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
; CHECK-NEXT: ret;
  %v1 = call i64 @llvm.smax.i64(i64 %a, i64 0)
  %v2 = call i64 @llvm.smin.i64(i64 %v1, i64 255)
  ret i64 %v2
}
92+
93+
; Signed clamp of an i32 to the unsigned i8 range [0, 255].
; smax(%a, 0) sends every negative input to 0, so the saturating conversion
; has to read its source as signed (.s32). With an unsigned source type a
; negative input would be seen as a huge positive value and saturate to 255.
define i32 @trunc_ssat_i32_u8(i32 %a) {
; CHECK-LABEL: trunc_ssat_i32_u8(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<2>;
; CHECK-NEXT: .reg .b32 %r<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [trunc_ssat_i32_u8_param_0];
; CHECK-NEXT: cvt.sat.u8.s32 %rs1, %r1;
; CHECK-NEXT: cvt.u32.u16 %r2, %rs1;
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
; CHECK-NEXT: ret;
  %v1 = call i32 @llvm.smax.i32(i32 %a, i32 0)
  %v2 = call i32 @llvm.smin.i32(i32 %v1, i32 255)
  ret i32 %v2
}
109+
110+
; Signed clamp of an i16 to the unsigned i8 range [0, 255].
; smax(%a, 0) sends every negative input to 0, so the saturating conversion
; has to read its source as signed (.s16). With an unsigned source type a
; negative input would be seen as a large positive value and saturate to 255.
define i16 @trunc_ssat_i16_u8(i16 %a) {
; CHECK-LABEL: trunc_ssat_i16_u8(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<3>;
; CHECK-NEXT: .reg .b32 %r<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b16 %rs1, [trunc_ssat_i16_u8_param_0];
; CHECK-NEXT: cvt.sat.u8.s16 %rs2, %rs1;
; CHECK-NEXT: cvt.u32.u16 %r1, %rs2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
; CHECK-NEXT: ret;
  %v1 = call i16 @llvm.smax.i16(i16 %a, i16 0)
  %v2 = call i16 @llvm.smin.i16(i16 %v1, i16 255)
  ret i16 %v2
}
126+
127+
; Signed clamp of an i64 to the signed i8 range [-128, 127]; expected to
; select one saturating cvt followed by a sign-extension back to 64 bits.
define i64 @trunc_ssat_i64_s8(i64 %a) {
; CHECK-LABEL: trunc_ssat_i64_s8(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<2>;
; CHECK-NEXT: .reg .b64 %rd<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b64 %rd1, [trunc_ssat_i64_s8_param_0];
; CHECK-NEXT: cvt.sat.s8.s64 %rs1, %rd1;
; CHECK-NEXT: cvt.s64.s16 %rd2, %rs1;
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
; CHECK-NEXT: ret;
  %lower = call i64 @llvm.smax.i64(i64 %a, i64 -128)
  %clamped = call i64 @llvm.smin.i64(i64 %lower, i64 127)
  ret i64 %clamped
}
143+
144+
; Signed clamp of an i32 to the signed i8 range [-128, 127]; expected to
; select one saturating cvt followed by a sign-extension back to 32 bits.
define i32 @trunc_ssat_i32_s8(i32 %a) {
; CHECK-LABEL: trunc_ssat_i32_s8(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<2>;
; CHECK-NEXT: .reg .b32 %r<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [trunc_ssat_i32_s8_param_0];
; CHECK-NEXT: cvt.sat.s8.s32 %rs1, %r1;
; CHECK-NEXT: cvt.s32.s16 %r2, %rs1;
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
; CHECK-NEXT: ret;
  %lower = call i32 @llvm.smax.i32(i32 %a, i32 -128)
  %clamped = call i32 @llvm.smin.i32(i32 %lower, i32 127)
  ret i32 %clamped
}
160+
161+
; Signed clamp of an i16 to the signed i8 range [-128, 127]; expected to
; select one saturating cvt whose 16-bit result is returned directly.
define i16 @trunc_ssat_i16_s8(i16 %a) {
; CHECK-LABEL: trunc_ssat_i16_s8(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<3>;
; CHECK-NEXT: .reg .b32 %r<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b16 %rs1, [trunc_ssat_i16_s8_param_0];
; CHECK-NEXT: cvt.sat.s8.s16 %rs2, %rs1;
; CHECK-NEXT: cvt.u32.u16 %r1, %rs2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
; CHECK-NEXT: ret;
  %lower = call i16 @llvm.smax.i16(i16 %a, i16 -128)
  %clamped = call i16 @llvm.smin.i16(i16 %lower, i16 127)
  ret i16 %clamped
}
177+

0 commit comments

Comments
 (0)