Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6537,3 +6537,102 @@ void NVPTXTargetLowering::computeKnownBitsForTargetNode(
break;
}
}

static std::pair<APInt, APInt> getPRMTDemandedBits(const APInt &SelectorVal,
const APInt &DemandedBits) {
APInt DemandedLHS = APInt(32, 0);
APInt DemandedRHS = APInt(32, 0);

for (unsigned I : llvm::seq(4)) {
if (DemandedBits.extractBits(8, I * 8).isZero())
continue;

APInt Sel = SelectorVal.extractBits(4, I * 4);
unsigned Idx = Sel.getLoBits(3).getZExtValue();
unsigned Sign = Sel.getHiBits(1).getZExtValue();

APInt &Src = Idx < 4 ? DemandedLHS : DemandedRHS;
unsigned ByteStart = (Idx % 4) * 8;
if (Sign)
Src.setBit(ByteStart + 7);
else
Src.setBits(ByteStart, ByteStart + 8);
}

return {DemandedLHS, DemandedRHS};
}

/// Normalize a (possibly empty) PRMT operand.
///
/// An empty SDValue yields an empty SDValue, an undef operand is swapped for
/// the i32 constant 0, and every other operand is returned untouched. Zero is
/// preferred over undef because it is easier for other optimizations, such as
/// known bits, to reason about.
static SDValue canonicalizePRMTInput(SDValue Op, SelectionDAG &DAG) {
  if (Op && Op.isUndef())
    return DAG.getConstant(0, SDLoc(), MVT::i32);
  return Op ? Op : SDValue();
}

/// Try to simplify the NVPTXISD::PRMT node \p PRMT given that only
/// \p DemandedBits of its result are used.
///
/// \returns one of the PRMT's own inputs when the demanded result bytes are a
/// straight copy of that input, a rebuilt PRMT when either input could be
/// replaced by something simpler, or an empty SDValue when nothing changed or
/// the selector operand is not a constant.
static SDValue simplifyDemandedBitsForPRMT(SDValue PRMT,
                                           const APInt &DemandedBits,
                                           SelectionDAG &DAG,
                                           const TargetLowering &TLI,
                                           unsigned Depth) {
  assert(PRMT.getOpcode() == NVPTXISD::PRMT);

  SDValue LHS = PRMT.getOperand(0);
  SDValue RHS = PRMT.getOperand(1);

  // Without a constant selector there is nothing to analyze.
  auto *SelectorNode = dyn_cast<ConstantSDNode>(PRMT.getOperand(2));
  if (!SelectorNode)
    return SDValue();

  const unsigned Mode = PRMT.getConstantOperandVal(3);
  const APInt Selector = getPRMTSelector(SelectorNode->getAPIntValue(), Mode);

  // Result bytes above the highest demanded bit do not matter, so only the
  // selector nibbles for the remaining low bytes are compared. If those
  // nibbles read the bytes of a single input in their original order, the
  // PRMT is equivalent to that input for every demanded bit.
  const unsigned UnusedTopBytes = DemandedBits.countLeadingZeros() / 8;
  const unsigned RelevantSelBits = (4 - UnusedTopBytes) * 4;
  const APInt RelevantSel = Selector.getLoBits(RelevantSelBits);
  if (RelevantSel == APInt(32, 0x3210).getLoBits(RelevantSelBits))
    return LHS;
  if (RelevantSel == APInt(32, 0x7654).getLoBits(RelevantSelBits))
    return RHS;

  auto [NeededFromLHS, NeededFromRHS] =
      getPRMTDemandedBits(Selector, DemandedBits);

  // Look through multi-use operands when only part of them is needed here.
  SDValue NewLHS =
      TLI.SimplifyMultipleUseDemandedBits(LHS, NeededFromLHS, DAG, Depth + 1);
  SDValue NewRHS =
      TLI.SimplifyMultipleUseDemandedBits(RHS, NeededFromRHS, DAG, Depth + 1);

  NewLHS = canonicalizePRMTInput(NewLHS, DAG);
  NewRHS = canonicalizePRMTInput(NewRHS, DAG);

  const bool LHSChanged = NewLHS && NewLHS != LHS;
  const bool RHSChanged = NewRHS && NewRHS != RHS;
  if (!LHSChanged && !RHSChanged)
    return SDValue();

  // Rebuild the node with whichever operands were simplified, keeping the
  // original operand wherever no replacement was produced.
  if (NewLHS)
    LHS = NewLHS;
  if (NewRHS)
    RHS = NewRHS;
  return getPRMT(LHS, RHS, Selector.getZExtValue(), SDLoc(PRMT), DAG);
}

/// Target hook for demanded-bits simplification of NVPTX-specific nodes.
///
/// For NVPTXISD::PRMT this attempts the PRMT-specific simplification and, on
/// success, records the replacement through \p TLO and returns true. In every
/// other case \p Known is filled in via computeKnownBitsForTargetNode and the
/// function returns false.
bool NVPTXTargetLowering::SimplifyDemandedBitsForTargetNode(
    SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
    KnownBits &Known, TargetLoweringOpt &TLO, unsigned Depth) const {
  Known.resetAll();

  if (Op.getOpcode() == NVPTXISD::PRMT) {
    SDValue Simplified =
        simplifyDemandedBitsForPRMT(Op, DemandedBits, TLO.DAG, *this, Depth);
    if (Simplified) {
      TLO.CombineTo(Op, Simplified);
      return true;
    }
  }

  // No replacement was made; fall back to reporting known bits only.
  computeKnownBitsForTargetNode(Op, Known, DemandedElts, TLO.DAG, Depth);
  return false;
}
5 changes: 5 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,11 @@ class NVPTXTargetLowering : public TargetLowering {
const APInt &DemandedElts,
const SelectionDAG &DAG,
unsigned Depth = 0) const override;
/// Simplify the target-specific node \p Op given the bits (\p DemandedBits)
/// and elements (\p DemandedElts) its users need. Returns true when \p Op was
/// replaced through \p TLO; otherwise fills in \p Known and returns false.
/// The definition in NVPTXISelLowering.cpp handles NVPTXISD::PRMT.
bool SimplifyDemandedBitsForTargetNode(SDValue Op, const APInt &DemandedBits,
const APInt &DemandedElts,
KnownBits &Known,
TargetLoweringOpt &TLO,
unsigned Depth = 0) const override;

private:
const NVPTXSubtarget &STI; // cache the subtarget here
Expand Down
102 changes: 51 additions & 51 deletions llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll
Original file line number Diff line number Diff line change
Expand Up @@ -181,32 +181,32 @@ define void @combine_v16i8(ptr noundef align 16 %ptr1, ptr noundef align 16 %ptr
; ENABLED-NEXT: prmt.b32 %r5, %r4, 0, 0x7773U;
; ENABLED-NEXT: prmt.b32 %r6, %r4, 0, 0x7772U;
; ENABLED-NEXT: prmt.b32 %r7, %r4, 0, 0x7771U;
; ENABLED-NEXT: prmt.b32 %r8, %r4, 0, 0x7770U;
; ENABLED-NEXT: prmt.b32 %r9, %r3, 0, 0x7773U;
; ENABLED-NEXT: prmt.b32 %r10, %r3, 0, 0x7772U;
; ENABLED-NEXT: prmt.b32 %r11, %r3, 0, 0x7771U;
; ENABLED-NEXT: prmt.b32 %r12, %r3, 0, 0x7770U;
; ENABLED-NEXT: prmt.b32 %r13, %r2, 0, 0x7773U;
; ENABLED-NEXT: prmt.b32 %r14, %r2, 0, 0x7772U;
; ENABLED-NEXT: prmt.b32 %r15, %r2, 0, 0x7771U;
; ENABLED-NEXT: prmt.b32 %r16, %r2, 0, 0x7770U;
; ENABLED-NEXT: prmt.b32 %r17, %r1, 0, 0x7773U;
; ENABLED-NEXT: prmt.b32 %r18, %r1, 0, 0x7772U;
; ENABLED-NEXT: prmt.b32 %r19, %r1, 0, 0x7771U;
; ENABLED-NEXT: prmt.b32 %r20, %r1, 0, 0x7770U;
; ENABLED-NEXT: prmt.b32 %r8, %r3, 0, 0x7773U;
; ENABLED-NEXT: prmt.b32 %r9, %r3, 0, 0x7772U;
; ENABLED-NEXT: prmt.b32 %r10, %r3, 0, 0x7771U;
; ENABLED-NEXT: prmt.b32 %r11, %r2, 0, 0x7773U;
; ENABLED-NEXT: prmt.b32 %r12, %r2, 0, 0x7772U;
; ENABLED-NEXT: prmt.b32 %r13, %r2, 0, 0x7771U;
; ENABLED-NEXT: prmt.b32 %r14, %r1, 0, 0x7773U;
; ENABLED-NEXT: prmt.b32 %r15, %r1, 0, 0x7772U;
; ENABLED-NEXT: prmt.b32 %r16, %r1, 0, 0x7771U;
; ENABLED-NEXT: ld.param.b64 %rd2, [combine_v16i8_param_1];
; ENABLED-NEXT: add.s32 %r21, %r20, %r19;
; ENABLED-NEXT: add.s32 %r22, %r21, %r18;
; ENABLED-NEXT: add.s32 %r23, %r22, %r17;
; ENABLED-NEXT: add.s32 %r24, %r23, %r16;
; ENABLED-NEXT: add.s32 %r25, %r24, %r15;
; ENABLED-NEXT: add.s32 %r26, %r25, %r14;
; ENABLED-NEXT: add.s32 %r27, %r26, %r13;
; ENABLED-NEXT: add.s32 %r28, %r27, %r12;
; ENABLED-NEXT: add.s32 %r29, %r28, %r11;
; ENABLED-NEXT: add.s32 %r30, %r29, %r10;
; ENABLED-NEXT: add.s32 %r31, %r30, %r9;
; ENABLED-NEXT: add.s32 %r32, %r31, %r8;
; ENABLED-NEXT: and.b32 %r17, %r1, 255;
; ENABLED-NEXT: and.b32 %r18, %r2, 255;
; ENABLED-NEXT: and.b32 %r19, %r3, 255;
; ENABLED-NEXT: and.b32 %r20, %r4, 255;
; ENABLED-NEXT: add.s32 %r21, %r17, %r16;
; ENABLED-NEXT: add.s32 %r22, %r21, %r15;
; ENABLED-NEXT: add.s32 %r23, %r22, %r14;
; ENABLED-NEXT: add.s32 %r24, %r23, %r18;
; ENABLED-NEXT: add.s32 %r25, %r24, %r13;
; ENABLED-NEXT: add.s32 %r26, %r25, %r12;
; ENABLED-NEXT: add.s32 %r27, %r26, %r11;
; ENABLED-NEXT: add.s32 %r28, %r27, %r19;
; ENABLED-NEXT: add.s32 %r29, %r28, %r10;
; ENABLED-NEXT: add.s32 %r30, %r29, %r9;
; ENABLED-NEXT: add.s32 %r31, %r30, %r8;
; ENABLED-NEXT: add.s32 %r32, %r31, %r20;
; ENABLED-NEXT: add.s32 %r33, %r32, %r7;
; ENABLED-NEXT: add.s32 %r34, %r33, %r6;
; ENABLED-NEXT: add.s32 %r35, %r34, %r5;
Expand Down Expand Up @@ -332,36 +332,36 @@ define void @combine_v16i8_unaligned(ptr noundef align 8 %ptr1, ptr noundef alig
; ENABLED-NEXT: prmt.b32 %r3, %r2, 0, 0x7773U;
; ENABLED-NEXT: prmt.b32 %r4, %r2, 0, 0x7772U;
; ENABLED-NEXT: prmt.b32 %r5, %r2, 0, 0x7771U;
; ENABLED-NEXT: prmt.b32 %r6, %r2, 0, 0x7770U;
; ENABLED-NEXT: prmt.b32 %r7, %r1, 0, 0x7773U;
; ENABLED-NEXT: prmt.b32 %r8, %r1, 0, 0x7772U;
; ENABLED-NEXT: prmt.b32 %r9, %r1, 0, 0x7771U;
; ENABLED-NEXT: prmt.b32 %r10, %r1, 0, 0x7770U;
; ENABLED-NEXT: prmt.b32 %r6, %r1, 0, 0x7773U;
; ENABLED-NEXT: prmt.b32 %r7, %r1, 0, 0x7772U;
; ENABLED-NEXT: prmt.b32 %r8, %r1, 0, 0x7771U;
; ENABLED-NEXT: ld.param.b64 %rd2, [combine_v16i8_unaligned_param_1];
; ENABLED-NEXT: ld.v2.b32 {%r11, %r12}, [%rd1+8];
; ENABLED-NEXT: prmt.b32 %r13, %r12, 0, 0x7773U;
; ENABLED-NEXT: prmt.b32 %r14, %r12, 0, 0x7772U;
; ENABLED-NEXT: prmt.b32 %r15, %r12, 0, 0x7771U;
; ENABLED-NEXT: prmt.b32 %r16, %r12, 0, 0x7770U;
; ENABLED-NEXT: prmt.b32 %r17, %r11, 0, 0x7773U;
; ENABLED-NEXT: prmt.b32 %r18, %r11, 0, 0x7772U;
; ENABLED-NEXT: prmt.b32 %r19, %r11, 0, 0x7771U;
; ENABLED-NEXT: prmt.b32 %r20, %r11, 0, 0x7770U;
; ENABLED-NEXT: add.s32 %r21, %r10, %r9;
; ENABLED-NEXT: add.s32 %r22, %r21, %r8;
; ENABLED-NEXT: add.s32 %r23, %r22, %r7;
; ENABLED-NEXT: add.s32 %r24, %r23, %r6;
; ENABLED-NEXT: ld.v2.b32 {%r9, %r10}, [%rd1+8];
; ENABLED-NEXT: prmt.b32 %r11, %r10, 0, 0x7773U;
; ENABLED-NEXT: prmt.b32 %r12, %r10, 0, 0x7772U;
; ENABLED-NEXT: prmt.b32 %r13, %r10, 0, 0x7771U;
; ENABLED-NEXT: prmt.b32 %r14, %r9, 0, 0x7773U;
; ENABLED-NEXT: prmt.b32 %r15, %r9, 0, 0x7772U;
; ENABLED-NEXT: prmt.b32 %r16, %r9, 0, 0x7771U;
; ENABLED-NEXT: and.b32 %r17, %r1, 255;
; ENABLED-NEXT: and.b32 %r18, %r2, 255;
; ENABLED-NEXT: and.b32 %r19, %r9, 255;
; ENABLED-NEXT: and.b32 %r20, %r10, 255;
; ENABLED-NEXT: add.s32 %r21, %r17, %r8;
; ENABLED-NEXT: add.s32 %r22, %r21, %r7;
; ENABLED-NEXT: add.s32 %r23, %r22, %r6;
; ENABLED-NEXT: add.s32 %r24, %r23, %r18;
; ENABLED-NEXT: add.s32 %r25, %r24, %r5;
; ENABLED-NEXT: add.s32 %r26, %r25, %r4;
; ENABLED-NEXT: add.s32 %r27, %r26, %r3;
; ENABLED-NEXT: add.s32 %r28, %r27, %r20;
; ENABLED-NEXT: add.s32 %r29, %r28, %r19;
; ENABLED-NEXT: add.s32 %r30, %r29, %r18;
; ENABLED-NEXT: add.s32 %r31, %r30, %r17;
; ENABLED-NEXT: add.s32 %r32, %r31, %r16;
; ENABLED-NEXT: add.s32 %r33, %r32, %r15;
; ENABLED-NEXT: add.s32 %r34, %r33, %r14;
; ENABLED-NEXT: add.s32 %r35, %r34, %r13;
; ENABLED-NEXT: add.s32 %r28, %r27, %r19;
; ENABLED-NEXT: add.s32 %r29, %r28, %r16;
; ENABLED-NEXT: add.s32 %r30, %r29, %r15;
; ENABLED-NEXT: add.s32 %r31, %r30, %r14;
; ENABLED-NEXT: add.s32 %r32, %r31, %r20;
; ENABLED-NEXT: add.s32 %r33, %r32, %r13;
; ENABLED-NEXT: add.s32 %r34, %r33, %r12;
; ENABLED-NEXT: add.s32 %r35, %r34, %r11;
; ENABLED-NEXT: st.b32 [%rd2], %r35;
; ENABLED-NEXT: ret;
;
Expand Down
71 changes: 34 additions & 37 deletions llvm/test/CodeGen/NVPTX/extractelement.ll
Original file line number Diff line number Diff line change
Expand Up @@ -56,23 +56,22 @@ define i16 @test_v4i8(i32 %a) {
; CHECK-LABEL: test_v4i8(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<8>;
; CHECK-NEXT: .reg .b32 %r<7>;
; CHECK-NEXT: .reg .b32 %r<6>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [test_v4i8_param_0];
; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 0x8880U;
; CHECK-NEXT: cvt.u16.u32 %rs1, %r2;
; CHECK-NEXT: prmt.b32 %r3, %r1, 0, 0x9991U;
; CHECK-NEXT: cvt.u16.u32 %rs2, %r3;
; CHECK-NEXT: prmt.b32 %r4, %r1, 0, 0xaaa2U;
; CHECK-NEXT: cvt.u16.u32 %rs3, %r4;
; CHECK-NEXT: prmt.b32 %r5, %r1, 0, 0xbbb3U;
; CHECK-NEXT: cvt.u16.u32 %rs4, %r5;
; CHECK-NEXT: cvt.s8.s32 %rs1, %r1;
; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 0x9991U;
; CHECK-NEXT: cvt.u16.u32 %rs2, %r2;
; CHECK-NEXT: prmt.b32 %r3, %r1, 0, 0xaaa2U;
; CHECK-NEXT: cvt.u16.u32 %rs3, %r3;
; CHECK-NEXT: prmt.b32 %r4, %r1, 0, 0xbbb3U;
; CHECK-NEXT: cvt.u16.u32 %rs4, %r4;
; CHECK-NEXT: add.s16 %rs5, %rs1, %rs2;
; CHECK-NEXT: add.s16 %rs6, %rs3, %rs4;
; CHECK-NEXT: add.s16 %rs7, %rs5, %rs6;
; CHECK-NEXT: cvt.u32.u16 %r6, %rs7;
; CHECK-NEXT: st.param.b32 [func_retval0], %r6;
; CHECK-NEXT: cvt.u32.u16 %r5, %rs7;
; CHECK-NEXT: st.param.b32 [func_retval0], %r5;
; CHECK-NEXT: ret;
%v = bitcast i32 %a to <4 x i8>
%r0 = extractelement <4 x i8> %v, i64 0
Expand All @@ -96,7 +95,7 @@ define i32 @test_v4i8_s32(i32 %a) {
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [test_v4i8_s32_param_0];
; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 0x8880U;
; CHECK-NEXT: cvt.s32.s8 %r2, %r1;
; CHECK-NEXT: prmt.b32 %r3, %r1, 0, 0x9991U;
; CHECK-NEXT: prmt.b32 %r4, %r1, 0, 0xaaa2U;
; CHECK-NEXT: prmt.b32 %r5, %r1, 0, 0xbbb3U;
Expand Down Expand Up @@ -127,12 +126,12 @@ define i32 @test_v4i8_u32(i32 %a) {
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [test_v4i8_u32_param_0];
; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 0x7770U;
; CHECK-NEXT: prmt.b32 %r3, %r1, 0, 0x7771U;
; CHECK-NEXT: prmt.b32 %r4, %r1, 0, 0x7772U;
; CHECK-NEXT: prmt.b32 %r5, %r1, 0, 0x7773U;
; CHECK-NEXT: add.s32 %r6, %r2, %r3;
; CHECK-NEXT: add.s32 %r7, %r4, %r5;
; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 0x7771U;
; CHECK-NEXT: prmt.b32 %r3, %r1, 0, 0x7772U;
; CHECK-NEXT: prmt.b32 %r4, %r1, 0, 0x7773U;
; CHECK-NEXT: and.b32 %r5, %r1, 255;
; CHECK-NEXT: add.s32 %r6, %r5, %r2;
; CHECK-NEXT: add.s32 %r7, %r3, %r4;
; CHECK-NEXT: add.s32 %r8, %r6, %r7;
; CHECK-NEXT: st.param.b32 [func_retval0], %r8;
; CHECK-NEXT: ret;
Expand All @@ -157,35 +156,33 @@ define i16 @test_v8i8(i64 %a) {
; CHECK-LABEL: test_v8i8(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<16>;
; CHECK-NEXT: .reg .b32 %r<12>;
; CHECK-NEXT: .reg .b32 %r<10>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.v2.b32 {%r1, %r2}, [test_v8i8_param_0];
; CHECK-NEXT: prmt.b32 %r3, %r1, 0, 0x8880U;
; CHECK-NEXT: cvt.u16.u32 %rs1, %r3;
; CHECK-NEXT: prmt.b32 %r4, %r1, 0, 0x9991U;
; CHECK-NEXT: cvt.u16.u32 %rs2, %r4;
; CHECK-NEXT: prmt.b32 %r5, %r1, 0, 0xaaa2U;
; CHECK-NEXT: cvt.u16.u32 %rs3, %r5;
; CHECK-NEXT: prmt.b32 %r6, %r1, 0, 0xbbb3U;
; CHECK-NEXT: cvt.u16.u32 %rs4, %r6;
; CHECK-NEXT: prmt.b32 %r7, %r2, 0, 0x8880U;
; CHECK-NEXT: cvt.u16.u32 %rs5, %r7;
; CHECK-NEXT: prmt.b32 %r8, %r2, 0, 0x9991U;
; CHECK-NEXT: cvt.u16.u32 %rs6, %r8;
; CHECK-NEXT: prmt.b32 %r9, %r2, 0, 0xaaa2U;
; CHECK-NEXT: cvt.u16.u32 %rs7, %r9;
; CHECK-NEXT: prmt.b32 %r10, %r2, 0, 0xbbb3U;
; CHECK-NEXT: cvt.u16.u32 %rs8, %r10;
; CHECK-NEXT: cvt.s8.s32 %rs1, %r1;
; CHECK-NEXT: prmt.b32 %r3, %r1, 0, 0x9991U;
; CHECK-NEXT: cvt.u16.u32 %rs2, %r3;
; CHECK-NEXT: prmt.b32 %r4, %r1, 0, 0xaaa2U;
; CHECK-NEXT: cvt.u16.u32 %rs3, %r4;
; CHECK-NEXT: prmt.b32 %r5, %r1, 0, 0xbbb3U;
; CHECK-NEXT: cvt.u16.u32 %rs4, %r5;
; CHECK-NEXT: cvt.s8.s32 %rs5, %r2;
; CHECK-NEXT: prmt.b32 %r6, %r2, 0, 0x9991U;
; CHECK-NEXT: cvt.u16.u32 %rs6, %r6;
; CHECK-NEXT: prmt.b32 %r7, %r2, 0, 0xaaa2U;
; CHECK-NEXT: cvt.u16.u32 %rs7, %r7;
; CHECK-NEXT: prmt.b32 %r8, %r2, 0, 0xbbb3U;
; CHECK-NEXT: cvt.u16.u32 %rs8, %r8;
; CHECK-NEXT: add.s16 %rs9, %rs1, %rs2;
; CHECK-NEXT: add.s16 %rs10, %rs3, %rs4;
; CHECK-NEXT: add.s16 %rs11, %rs5, %rs6;
; CHECK-NEXT: add.s16 %rs12, %rs7, %rs8;
; CHECK-NEXT: add.s16 %rs13, %rs9, %rs10;
; CHECK-NEXT: add.s16 %rs14, %rs11, %rs12;
; CHECK-NEXT: add.s16 %rs15, %rs13, %rs14;
; CHECK-NEXT: cvt.u32.u16 %r11, %rs15;
; CHECK-NEXT: st.param.b32 [func_retval0], %r11;
; CHECK-NEXT: cvt.u32.u16 %r9, %rs15;
; CHECK-NEXT: st.param.b32 [func_retval0], %r9;
; CHECK-NEXT: ret;
%v = bitcast i64 %a to <8 x i8>
%r0 = extractelement <8 x i8> %v, i64 0
Expand Down
Loading
Loading