Skip to content

Commit e2a60bb

Browse files
committed
[NVPTX] Legalize ctpop and ctlz in operation legalization
1 parent e661957 commit e2a60bb

File tree

4 files changed

+68
-94
lines changed

4 files changed

+68
-94
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -766,16 +766,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
766766
// Custom handling for i8 intrinsics
767767
setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom);
768768

769-
for (const auto& Ty : {MVT::i16, MVT::i32, MVT::i64}) {
770-
setOperationAction(ISD::ABS, Ty, Legal);
771-
setOperationAction(ISD::SMIN, Ty, Legal);
772-
setOperationAction(ISD::SMAX, Ty, Legal);
773-
setOperationAction(ISD::UMIN, Ty, Legal);
774-
setOperationAction(ISD::UMAX, Ty, Legal);
769+
setOperationAction({ISD::ABS, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX},
770+
{MVT::i16, MVT::i32, MVT::i64}, Legal);
775771

776-
setOperationAction(ISD::CTPOP, Ty, Legal);
777-
setOperationAction(ISD::CTLZ, Ty, Legal);
778-
}
772+
setOperationAction({ISD::CTPOP, ISD::CTLZ}, MVT::i32, Legal);
773+
setOperationAction({ISD::CTPOP, ISD::CTLZ}, {MVT::i16, MVT::i64}, Custom);
779774

780775
setI16x2OperationAction(ISD::ABS, MVT::v2i16, Legal, Custom);
781776
setI16x2OperationAction(ISD::SMIN, MVT::v2i16, Legal, Custom);
@@ -2750,6 +2745,42 @@ static SDValue LowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG) {
27502745
return Op;
27512746
}
27522747

2748+
static SDValue lowerCTPOP(SDValue Op, SelectionDAG &DAG) {
2749+
SDValue V = Op->getOperand(0);
2750+
SDLoc DL(Op);
2751+
2752+
if (V.getValueType() == MVT::i16) {
2753+
SDValue Zext = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, V);
2754+
SDValue CT = DAG.getNode(ISD::CTPOP, DL, MVT::i32, Zext);
2755+
return DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, CT, SDNodeFlags::NoWrap);
2756+
}
2757+
if (V.getValueType() == MVT::i64) {
2758+
SDValue CT = DAG.getNode(ISD::CTPOP, DL, MVT::i32, V);
2759+
return DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, CT);
2760+
}
2761+
llvm_unreachable("Unexpected CTPOP type to legalize");
2762+
}
2763+
2764+
static SDValue lowerCTLZ(SDValue Op, SelectionDAG &DAG) {
2765+
SDValue V = Op->getOperand(0);
2766+
SDLoc DL(Op);
2767+
2768+
if (V.getValueType() == MVT::i16) {
2769+
SDValue Zext = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, V);
2770+
SDValue CT = DAG.getNode(ISD::CTLZ, DL, MVT::i32, Zext);
2771+
SDValue Sub =
2772+
DAG.getNode(ISD::ADD, DL, MVT::i32, CT,
2773+
DAG.getConstant(APInt(32, -16, true), DL, MVT::i32),
2774+
SDNodeFlags::NoSignedWrap);
2775+
return DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Sub, SDNodeFlags::NoWrap);
2776+
}
2777+
if (V.getValueType() == MVT::i64) {
2778+
SDValue CT = DAG.getNode(ISD::CTLZ, DL, MVT::i32, V);
2779+
return DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, CT);
2780+
}
2781+
llvm_unreachable("Unexpected CTLZ type to legalize");
2782+
}
2783+
27532784
SDValue
27542785
NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
27552786
switch (Op.getOpcode()) {
@@ -2835,6 +2866,10 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
28352866
case ISD::FMUL:
28362867
// Used only for bf16 on SM80, where we select fma for non-ftz operation
28372868
return PromoteBinOpIfF32FTZ(Op, DAG);
2869+
case ISD::CTPOP:
2870+
return lowerCTPOP(Op, DAG);
2871+
case ISD::CTLZ:
2872+
return lowerCTLZ(Op, DAG);
28382873

28392874
default:
28402875
llvm_unreachable("Custom lowering not defined for operation");

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 11 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -3267,69 +3267,20 @@ def : Pat<(i32 (int_nvvm_fshr_clamp i32:$hi, i32:$lo, i32:$amt)),
32673267
def : Pat<(i32 (int_nvvm_fshr_clamp i32:$hi, i32:$lo, (i32 imm:$amt))),
32683268
(SHF_R_CLAMP_i $lo, $hi, imm:$amt)>;
32693269

3270-
// Count leading zeros
32713270
let hasSideEffects = false in {
3272-
def CLZr32 : NVPTXInst<(outs Int32Regs:$d), (ins Int32Regs:$a),
3273-
"clz.b32 \t$d, $a;", []>;
3274-
def CLZr64 : NVPTXInst<(outs Int32Regs:$d), (ins Int64Regs:$a),
3275-
"clz.b64 \t$d, $a;", []>;
3271+
foreach RT = [I32RT, I64RT] in {
3272+
// Count leading zeros
3273+
def CLZr # RT.Size : NVPTXInst<(outs Int32Regs:$d), (ins RT.RC:$a),
3274+
"clz.b" # RT.Size # " \t$d, $a;",
3275+
[(set i32:$d, (ctlz RT.Ty:$a))]>;
3276+
3277+
// Population count
3278+
def POPCr # RT.Size : NVPTXInst<(outs Int32Regs:$d), (ins RT.RC:$a),
3279+
"popc.b" # RT.Size # " \t$d, $a;",
3280+
[(set i32:$d, (ctpop RT.Ty:$a))]>;
3281+
}
32763282
}
32773283

3278-
// 32-bit has a direct PTX instruction
3279-
def : Pat<(i32 (ctlz i32:$a)), (CLZr32 $a)>;
3280-
3281-
// The return type of the ctlz ISD node is the same as its input, but the PTX
3282-
// ctz instruction always returns a 32-bit value. For ctlz.i64, convert the
3283-
// ptx value to 64 bits to match the ISD node's semantics, unless we know we're
3284-
// truncating back down to 32 bits.
3285-
def : Pat<(i64 (ctlz i64:$a)), (CVT_u64_u32 (CLZr64 $a), CvtNONE)>;
3286-
def : Pat<(i32 (trunc (i64 (ctlz i64:$a)))), (CLZr64 $a)>;
3287-
3288-
// For 16-bit ctlz, we zero-extend to 32-bit, perform the count, then trunc the
3289-
// result back to 16-bits if necessary. We also need to subtract 16 because
3290-
// the high-order 16 zeros were counted.
3291-
//
3292-
// TODO: NVPTX has a mov.b32 b32reg, {imm, b16reg} instruction, which we could
3293-
// use to save one SASS instruction (on sm_35 anyway):
3294-
//
3295-
// mov.b32 $tmp, {0xffff, $a}
3296-
// ctlz.b32 $result, $tmp
3297-
//
3298-
// That is, instead of zero-extending the input to 32 bits, we'd "one-extend"
3299-
// and then ctlz that value. This way we don't have to subtract 16 from the
3300-
// result. Unfortunately today we don't have a way to generate
3301-
// "mov b32reg, {b16imm, b16reg}", so we don't do this optimization.
3302-
def : Pat<(i16 (ctlz i16:$a)),
3303-
(SUBi16ri (CVT_u16_u32
3304-
(CLZr32 (CVT_u32_u16 $a, CvtNONE)), CvtNONE), 16)>;
3305-
def : Pat<(i32 (zext (i16 (ctlz i16:$a)))),
3306-
(SUBi32ri (CLZr32 (CVT_u32_u16 $a, CvtNONE)), 16)>;
3307-
3308-
// Population count
3309-
let hasSideEffects = false in {
3310-
def POPCr32 : NVPTXInst<(outs Int32Regs:$d), (ins Int32Regs:$a),
3311-
"popc.b32 \t$d, $a;", []>;
3312-
def POPCr64 : NVPTXInst<(outs Int32Regs:$d), (ins Int64Regs:$a),
3313-
"popc.b64 \t$d, $a;", []>;
3314-
}
3315-
3316-
// 32-bit has a direct PTX instruction
3317-
def : Pat<(i32 (ctpop i32:$a)), (POPCr32 $a)>;
3318-
3319-
// For 64-bit, the result in PTX is actually 32-bit so we zero-extend to 64-bit
3320-
// to match the LLVM semantics. Just as with ctlz.i64, we provide a second
3321-
// pattern that avoids the type conversion if we're truncating the result to
3322-
// i32 anyway.
3323-
def : Pat<(ctpop i64:$a), (CVT_u64_u32 (POPCr64 $a), CvtNONE)>;
3324-
def : Pat<(i32 (trunc (i64 (ctpop i64:$a)))), (POPCr64 $a)>;
3325-
3326-
// For 16-bit, we zero-extend to 32-bit, then trunc the result back to 16-bits.
3327-
// If we know that we're storing into an i32, we can avoid the final trunc.
3328-
def : Pat<(ctpop i16:$a),
3329-
(CVT_u16_u32 (POPCr32 (CVT_u32_u16 $a, CvtNONE)), CvtNONE)>;
3330-
def : Pat<(i32 (zext (i16 (ctpop i16:$a)))),
3331-
(POPCr32 (CVT_u32_u16 $a, CvtNONE))>;
3332-
33333284
// fpround f32 -> f16
33343285
def : Pat<(f16 (fpround f32:$a)),
33353286
(CVT_f16_f32 $a, CvtRN)>;

llvm/test/CodeGen/NVPTX/ctlz.ll

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -112,14 +112,12 @@ define i32 @myctlz64_as_32_2(i64 %a) {
112112
define i16 @myctlz_ret16(i16 %a) {
113113
; CHECK-LABEL: myctlz_ret16(
114114
; CHECK: {
115-
; CHECK-NEXT: .reg .b16 %rs<2>;
116115
; CHECK-NEXT: .reg .b32 %r<4>;
117116
; CHECK-EMPTY:
118117
; CHECK-NEXT: // %bb.0:
119-
; CHECK-NEXT: ld.param.u16 %rs1, [myctlz_ret16_param_0];
120-
; CHECK-NEXT: cvt.u32.u16 %r1, %rs1;
118+
; CHECK-NEXT: ld.param.u16 %r1, [myctlz_ret16_param_0];
121119
; CHECK-NEXT: clz.b32 %r2, %r1;
122-
; CHECK-NEXT: sub.s32 %r3, %r2, 16;
120+
; CHECK-NEXT: add.s32 %r3, %r2, -16;
123121
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
124122
; CHECK-NEXT: ret;
125123
%val = call i16 @llvm.ctlz.i16(i16 %a, i1 false) readnone
@@ -128,14 +126,12 @@ define i16 @myctlz_ret16(i16 %a) {
128126
define i16 @myctlz_ret16_2(i16 %a) {
129127
; CHECK-LABEL: myctlz_ret16_2(
130128
; CHECK: {
131-
; CHECK-NEXT: .reg .b16 %rs<2>;
132129
; CHECK-NEXT: .reg .b32 %r<4>;
133130
; CHECK-EMPTY:
134131
; CHECK-NEXT: // %bb.0:
135-
; CHECK-NEXT: ld.param.u16 %rs1, [myctlz_ret16_2_param_0];
136-
; CHECK-NEXT: cvt.u32.u16 %r1, %rs1;
132+
; CHECK-NEXT: ld.param.u16 %r1, [myctlz_ret16_2_param_0];
137133
; CHECK-NEXT: clz.b32 %r2, %r1;
138-
; CHECK-NEXT: sub.s32 %r3, %r2, 16;
134+
; CHECK-NEXT: add.s32 %r3, %r2, -16;
139135
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
140136
; CHECK-NEXT: ret;
141137
%val = call i16 @llvm.ctlz.i16(i16 %a, i1 true) readnone
@@ -147,18 +143,15 @@ define i16 @myctlz_ret16_2(i16 %a) {
147143
define void @myctlz_store16(i16 %a, ptr %b) {
148144
; CHECK-LABEL: myctlz_store16(
149145
; CHECK: {
150-
; CHECK-NEXT: .reg .b16 %rs<4>;
151-
; CHECK-NEXT: .reg .b32 %r<3>;
146+
; CHECK-NEXT: .reg .b32 %r<4>;
152147
; CHECK-NEXT: .reg .b64 %rd<2>;
153148
; CHECK-EMPTY:
154149
; CHECK-NEXT: // %bb.0:
155-
; CHECK-NEXT: ld.param.u16 %rs1, [myctlz_store16_param_0];
156-
; CHECK-NEXT: cvt.u32.u16 %r1, %rs1;
150+
; CHECK-NEXT: ld.param.u16 %r1, [myctlz_store16_param_0];
157151
; CHECK-NEXT: clz.b32 %r2, %r1;
158-
; CHECK-NEXT: cvt.u16.u32 %rs2, %r2;
159-
; CHECK-NEXT: sub.s16 %rs3, %rs2, 16;
152+
; CHECK-NEXT: add.s32 %r3, %r2, -16;
160153
; CHECK-NEXT: ld.param.u64 %rd1, [myctlz_store16_param_1];
161-
; CHECK-NEXT: st.u16 [%rd1], %rs3;
154+
; CHECK-NEXT: st.u16 [%rd1], %r3;
162155
; CHECK-NEXT: ret;
163156
%val = call i16 @llvm.ctlz.i16(i16 %a, i1 false) readnone
164157
store i16 %val, ptr %b
@@ -167,18 +160,15 @@ define void @myctlz_store16(i16 %a, ptr %b) {
167160
define void @myctlz_store16_2(i16 %a, ptr %b) {
168161
; CHECK-LABEL: myctlz_store16_2(
169162
; CHECK: {
170-
; CHECK-NEXT: .reg .b16 %rs<4>;
171-
; CHECK-NEXT: .reg .b32 %r<3>;
163+
; CHECK-NEXT: .reg .b32 %r<4>;
172164
; CHECK-NEXT: .reg .b64 %rd<2>;
173165
; CHECK-EMPTY:
174166
; CHECK-NEXT: // %bb.0:
175-
; CHECK-NEXT: ld.param.u16 %rs1, [myctlz_store16_2_param_0];
176-
; CHECK-NEXT: cvt.u32.u16 %r1, %rs1;
167+
; CHECK-NEXT: ld.param.u16 %r1, [myctlz_store16_2_param_0];
177168
; CHECK-NEXT: clz.b32 %r2, %r1;
178-
; CHECK-NEXT: cvt.u16.u32 %rs2, %r2;
179-
; CHECK-NEXT: sub.s16 %rs3, %rs2, 16;
169+
; CHECK-NEXT: add.s32 %r3, %r2, -16;
180170
; CHECK-NEXT: ld.param.u64 %rd1, [myctlz_store16_2_param_1];
181-
; CHECK-NEXT: st.u16 [%rd1], %rs3;
171+
; CHECK-NEXT: st.u16 [%rd1], %r3;
182172
; CHECK-NEXT: ret;
183173
%val = call i16 @llvm.ctlz.i16(i16 %a, i1 false) readnone
184174
store i16 %val, ptr %b

llvm/test/CodeGen/NVPTX/intrinsics.ll

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,10 @@ define void @test_popc16(i16 %a, ptr %b) {
150150
define i32 @test_popc16_to_32(i16 %a) {
151151
; CHECK-LABEL: test_popc16_to_32(
152152
; CHECK: {
153-
; CHECK-NEXT: .reg .b16 %rs<2>;
154153
; CHECK-NEXT: .reg .b32 %r<3>;
155154
; CHECK-EMPTY:
156155
; CHECK-NEXT: // %bb.0:
157-
; CHECK-NEXT: ld.param.u16 %rs1, [test_popc16_to_32_param_0];
158-
; CHECK-NEXT: cvt.u32.u16 %r1, %rs1;
156+
; CHECK-NEXT: ld.param.u16 %r1, [test_popc16_to_32_param_0];
159157
; CHECK-NEXT: popc.b32 %r2, %r1;
160158
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
161159
; CHECK-NEXT: ret;

0 commit comments

Comments
 (0)