71 changes: 68 additions & 3 deletions llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp
@@ -25,6 +25,7 @@
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/Local.h"

using namespace llvm;

@@ -62,10 +63,74 @@ class RISCVCodeGenPrepare : public FunctionPass,

} // end anonymous namespace

// Try to optimize (i64 (and (zext/sext (i32 X), C1))) if C1 has bit 31 set,
// but bits 63:32 are zero. If we know that bit 31 of X is 0, we can fill
// the upper 32 bits with ones.
// InstCombinerImpl::transformZExtICmp will narrow a zext of an icmp into a
// truncation-based bit test. But RVV has no truncation instruction that
// narrows by more than a factor of two, so a wider trunc lowers to a chain
// of vnsrl instructions.
//
// E.g. trunc <vscale x 1 x i64> %x to <vscale x 1 x i8> will generate:
//
// vsetvli a0, zero, e32, m2, ta, ma
// vnsrl.wi v12, v8, 0
// vsetvli zero, zero, e16, m1, ta, ma
// vnsrl.wi v8, v12, 0
// vsetvli zero, zero, e8, mf2, ta, ma
// vnsrl.wi v8, v8, 0
//
// So reverse the combine so that we generate a vmseq/vmsne again:
//
// and (lshr (trunc X), ShAmt), 1
// -->
// zext (icmp ne (and X, (1 << ShAmt)), 0)
//
// and (lshr (not (trunc X)), ShAmt), 1
// -->
// zext (icmp eq (and X, (1 << ShAmt)), 0)
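//
// With the reversed form the bit test lowers directly to a vector compare.
// For example (taken from the reverse_zexticmp_i32 test added below; vsetvli
// bookkeeping omitted):
//
//   vand.vi    v8, v8, 4
//   vmsne.vi   v0, v8, 0
//   vmv.v.i    v8, 0
//   vmerge.vim v8, v8, 1, v0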
static bool reverseZExtICmpCombine(BinaryOperator &BO) {
  using namespace PatternMatch;

Collaborator: Should we be checking that the vector extensions are enabled?
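A minimal sketch of such a guard, assuming it lives in visitAnd where the pass's existing `ST` subtarget pointer is in scope, and that RISCVSubtarget's `hasVInstructions()` is the right query:

  // Hypothetical guard before calling reverseZExtICmpCombine in visitAnd,
  // so the rewrite only runs when RVV is actually available:
  if (ST->hasVInstructions() && reverseZExtICmpCombine(BO))
    return true;
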
  assert(BO.getOpcode() == BinaryOperator::And);

Collaborator: I think we could generalize this using demanded bits, but I'm not sure if doing so is actually worthwhile.

Contributor (author): Yeah, for what it's worth the case I'm specifically trying to catch in InstCombine seems to only match ands anyway:

  if (Cmp->isEquality()) {
    // Test if a bit is clear/set using a shifted-one mask:
    // zext (icmp eq (and X, (1 << ShAmt)), 0) --> and (lshr (not X), ShAmt), 1
    // zext (icmp ne (and X, (1 << ShAmt)), 0) --> and (lshr X, ShAmt), 1
    Value *X, *ShAmt;
    if (Cmp->hasOneUse() && match(Cmp->getOperand(1), m_ZeroInt()) &&
        match(Cmp->getOperand(0),
              m_OneUse(m_c_And(m_Shl(m_One(), m_Value(ShAmt)), m_Value(X))))) {

  if (!BO.getType()->isVectorTy())
    return false;
  const APInt *ShAmt;
  Value *Inner;
  if (!match(&BO,
             m_And(m_OneUse(m_LShr(m_OneUse(m_Value(Inner)), m_APInt(ShAmt))),
                   m_One())))
    return false;

  Value *X;
  bool IsNot;
  if (match(Inner, m_Not(m_Trunc(m_Value(X)))))
    IsNot = true;
  else if (match(Inner, m_Trunc(m_Value(X))))
    IsNot = false;
  else
    return false;

  // Bail if the trunc narrows by no more than a factor of two: a single
  // vnsrl already handles that case, so the reversed form would not help.
  if (BO.getType()->getScalarSizeInBits() >=
      X->getType()->getScalarSizeInBits() / 2)
    return false;

  IRBuilder<> Builder(&BO);
  // Shift a 64-bit one so the mask constant cannot hit signed-overflow UB
  // when ShAmt is 31 or larger.
  Value *Res = Builder.CreateAnd(
      X, ConstantInt::get(X->getType(), 1ULL << ShAmt->getZExtValue()));
  Res = Builder.CreateICmp(IsNot ? CmpInst::Predicate::ICMP_EQ
                                 : CmpInst::Predicate::ICMP_NE,
                           Res, ConstantInt::get(X->getType(), 0));
  Res = Builder.CreateZExt(Res, BO.getType());
  BO.replaceAllUsesWith(Res);
  RecursivelyDeleteTriviallyDeadInstructions(&BO);
  return true;
}

bool RISCVCodeGenPrepare::visitAnd(BinaryOperator &BO) {
  if (reverseZExtICmpCombine(BO))
    return true;

  // Try to optimize (i64 (and (zext/sext (i32 X), C1))) if C1 has bit 31 set,
  // but bits 63:32 are zero. If we know that bit 31 of X is 0, we can fill
  // the upper 32 bits with ones.
  if (!ST->is64Bit())
    return false;

81 changes: 81 additions & 0 deletions llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll
@@ -498,3 +498,84 @@ vector.body: ; preds = %vector.body, %entry
for.cond.cleanup: ; preds = %vector.body
ret float %red
}

define <vscale x 1 x i8> @reverse_zexticmp_i16(<vscale x 1 x i16> %x) {
; CHECK-LABEL: reverse_zexticmp_i16:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a0, zero, e8, mf8, ta, ma
; CHECK-NEXT: vnsrl.wi v8, v8, 0
; CHECK-NEXT: vsrl.vi v8, v8, 2
; CHECK-NEXT: vand.vi v8, v8, 1
; CHECK-NEXT: ret
%1 = trunc <vscale x 1 x i16> %x to <vscale x 1 x i8>
%2 = lshr <vscale x 1 x i8> %1, splat (i8 2)
%3 = and <vscale x 1 x i8> %2, splat (i8 1)
ret <vscale x 1 x i8> %3
}

define <vscale x 1 x i8> @reverse_zexticmp_i32(<vscale x 1 x i32> %x) {
; CHECK-LABEL: reverse_zexticmp_i32:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a0, zero, e32, mf2, ta, ma
; CHECK-NEXT: vand.vi v8, v8, 4
; CHECK-NEXT: vmsne.vi v0, v8, 0
; CHECK-NEXT: vsetvli zero, zero, e8, mf8, ta, ma
; CHECK-NEXT: vmv.v.i v8, 0
; CHECK-NEXT: vmerge.vim v8, v8, 1, v0
; CHECK-NEXT: ret
%1 = trunc <vscale x 1 x i32> %x to <vscale x 1 x i8>
%2 = lshr <vscale x 1 x i8> %1, splat (i8 2)
%3 = and <vscale x 1 x i8> %2, splat (i8 1)
ret <vscale x 1 x i8> %3
}

define <vscale x 1 x i8> @reverse_zexticmp_neg_i32(<vscale x 1 x i32> %x) {
; CHECK-LABEL: reverse_zexticmp_neg_i32:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a0, zero, e32, mf2, ta, ma
; CHECK-NEXT: vand.vi v8, v8, 4
; CHECK-NEXT: vmseq.vi v0, v8, 0
; CHECK-NEXT: vsetvli zero, zero, e8, mf8, ta, ma
; CHECK-NEXT: vmv.v.i v8, 0
; CHECK-NEXT: vmerge.vim v8, v8, 1, v0
; CHECK-NEXT: ret
%1 = trunc <vscale x 1 x i32> %x to <vscale x 1 x i8>
%2 = xor <vscale x 1 x i8> %1, splat (i8 -1)
%3 = lshr <vscale x 1 x i8> %2, splat (i8 2)
%4 = and <vscale x 1 x i8> %3, splat (i8 1)
ret <vscale x 1 x i8> %4
}

define <vscale x 1 x i8> @reverse_zexticmp_i64(<vscale x 1 x i64> %x) {
; CHECK-LABEL: reverse_zexticmp_i64:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a0, zero, e64, m1, ta, ma
; CHECK-NEXT: vand.vi v8, v8, 4
; CHECK-NEXT: vmsne.vi v0, v8, 0
; CHECK-NEXT: vsetvli zero, zero, e8, mf8, ta, ma
; CHECK-NEXT: vmv.v.i v8, 0
; CHECK-NEXT: vmerge.vim v8, v8, 1, v0
; CHECK-NEXT: ret
%1 = trunc <vscale x 1 x i64> %x to <vscale x 1 x i8>
%2 = lshr <vscale x 1 x i8> %1, splat (i8 2)
%3 = and <vscale x 1 x i8> %2, splat (i8 1)
ret <vscale x 1 x i8> %3
}

define <vscale x 1 x i8> @reverse_zexticmp_neg_i64(<vscale x 1 x i64> %x) {
; CHECK-LABEL: reverse_zexticmp_neg_i64:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a0, zero, e64, m1, ta, ma
; CHECK-NEXT: vand.vi v8, v8, 4
; CHECK-NEXT: vmseq.vi v0, v8, 0
; CHECK-NEXT: vsetvli zero, zero, e8, mf8, ta, ma
; CHECK-NEXT: vmv.v.i v8, 0
; CHECK-NEXT: vmerge.vim v8, v8, 1, v0
; CHECK-NEXT: ret
%1 = trunc <vscale x 1 x i64> %x to <vscale x 1 x i8>
%2 = xor <vscale x 1 x i8> %1, splat (i8 -1)
%3 = lshr <vscale x 1 x i8> %2, splat (i8 2)
%4 = and <vscale x 1 x i8> %3, splat (i8 1)
ret <vscale x 1 x i8> %4
}

72 changes: 72 additions & 0 deletions llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll
@@ -528,3 +528,75 @@ vector.body: ; preds = %vector.body, %entry
for.cond.cleanup: ; preds = %vector.body
ret float %red
}

define <vscale x 1 x i8> @reverse_zexticmp_i16(<vscale x 1 x i16> %x) {
; CHECK-LABEL: define <vscale x 1 x i8> @reverse_zexticmp_i16(
; CHECK-SAME: <vscale x 1 x i16> [[X:%.*]]) #[[ATTR2]] {
; CHECK-NEXT: [[TMP1:%.*]] = trunc <vscale x 1 x i16> [[X]] to <vscale x 1 x i8>
; CHECK-NEXT: [[TMP2:%.*]] = lshr <vscale x 1 x i8> [[TMP1]], splat (i8 2)
; CHECK-NEXT: [[TMP3:%.*]] = and <vscale x 1 x i8> [[TMP2]], splat (i8 1)
; CHECK-NEXT: ret <vscale x 1 x i8> [[TMP3]]
;
%1 = trunc <vscale x 1 x i16> %x to <vscale x 1 x i8>
%2 = lshr <vscale x 1 x i8> %1, splat (i8 2)
%3 = and <vscale x 1 x i8> %2, splat (i8 1)
ret <vscale x 1 x i8> %3
}

define <vscale x 1 x i8> @reverse_zexticmp_i32(<vscale x 1 x i32> %x) {
; CHECK-LABEL: define <vscale x 1 x i8> @reverse_zexticmp_i32(
; CHECK-SAME: <vscale x 1 x i32> [[X:%.*]]) #[[ATTR2]] {
; CHECK-NEXT: [[TMP1:%.*]] = and <vscale x 1 x i32> [[X]], splat (i32 4)
; CHECK-NEXT: [[TMP2:%.*]] = icmp ne <vscale x 1 x i32> [[TMP1]], zeroinitializer
; CHECK-NEXT: [[TMP3:%.*]] = zext <vscale x 1 x i1> [[TMP2]] to <vscale x 1 x i8>
; CHECK-NEXT: ret <vscale x 1 x i8> [[TMP3]]
;
%1 = trunc <vscale x 1 x i32> %x to <vscale x 1 x i8>
%2 = lshr <vscale x 1 x i8> %1, splat (i8 2)
%3 = and <vscale x 1 x i8> %2, splat (i8 1)
ret <vscale x 1 x i8> %3
}

define <vscale x 1 x i8> @reverse_zexticmp_neg_i32(<vscale x 1 x i32> %x) {
; CHECK-LABEL: define <vscale x 1 x i8> @reverse_zexticmp_neg_i32(
; CHECK-SAME: <vscale x 1 x i32> [[X:%.*]]) #[[ATTR2]] {
; CHECK-NEXT: [[TMP1:%.*]] = and <vscale x 1 x i32> [[X]], splat (i32 4)
; CHECK-NEXT: [[TMP2:%.*]] = icmp eq <vscale x 1 x i32> [[TMP1]], zeroinitializer
; CHECK-NEXT: [[TMP4:%.*]] = zext <vscale x 1 x i1> [[TMP2]] to <vscale x 1 x i8>
; CHECK-NEXT: ret <vscale x 1 x i8> [[TMP4]]
;
%1 = trunc <vscale x 1 x i32> %x to <vscale x 1 x i8>
%2 = xor <vscale x 1 x i8> %1, splat (i8 -1)
%3 = lshr <vscale x 1 x i8> %2, splat (i8 2)
%4 = and <vscale x 1 x i8> %3, splat (i8 1)
ret <vscale x 1 x i8> %4
}

define <vscale x 1 x i8> @reverse_zexticmp_i64(<vscale x 1 x i64> %x) {
; CHECK-LABEL: define <vscale x 1 x i8> @reverse_zexticmp_i64(
; CHECK-SAME: <vscale x 1 x i64> [[X:%.*]]) #[[ATTR2]] {
; CHECK-NEXT: [[TMP1:%.*]] = and <vscale x 1 x i64> [[X]], splat (i64 4)
; CHECK-NEXT: [[TMP2:%.*]] = icmp ne <vscale x 1 x i64> [[TMP1]], zeroinitializer
; CHECK-NEXT: [[TMP3:%.*]] = zext <vscale x 1 x i1> [[TMP2]] to <vscale x 1 x i8>
; CHECK-NEXT: ret <vscale x 1 x i8> [[TMP3]]
;
%1 = trunc <vscale x 1 x i64> %x to <vscale x 1 x i8>
%2 = lshr <vscale x 1 x i8> %1, splat (i8 2)
%3 = and <vscale x 1 x i8> %2, splat (i8 1)
ret <vscale x 1 x i8> %3
}

define <vscale x 1 x i8> @reverse_zexticmp_neg_i64(<vscale x 1 x i64> %x) {
; CHECK-LABEL: define <vscale x 1 x i8> @reverse_zexticmp_neg_i64(
; CHECK-SAME: <vscale x 1 x i64> [[X:%.*]]) #[[ATTR2]] {
; CHECK-NEXT: [[TMP1:%.*]] = and <vscale x 1 x i64> [[X]], splat (i64 4)
; CHECK-NEXT: [[TMP2:%.*]] = icmp eq <vscale x 1 x i64> [[TMP1]], zeroinitializer
; CHECK-NEXT: [[TMP4:%.*]] = zext <vscale x 1 x i1> [[TMP2]] to <vscale x 1 x i8>
; CHECK-NEXT: ret <vscale x 1 x i8> [[TMP4]]
;
%1 = trunc <vscale x 1 x i64> %x to <vscale x 1 x i8>
%2 = xor <vscale x 1 x i8> %1, splat (i8 -1)
%3 = lshr <vscale x 1 x i8> %2, splat (i8 2)
%4 = and <vscale x 1 x i8> %3, splat (i8 1)
ret <vscale x 1 x i8> %4
}