Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISD.def
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ HANDLE_NODETYPE(BR_IF)
HANDLE_NODETYPE(BR_TABLE)
HANDLE_NODETYPE(DOT)
HANDLE_NODETYPE(EXT_ADD_PAIRWISE_U)
HANDLE_NODETYPE(EXT_ADD_PAIRWISE_S)
HANDLE_NODETYPE(SHUFFLE)
HANDLE_NODETYPE(SWIZZLE)
HANDLE_NODETYPE(VEC_SHL)
Expand Down
180 changes: 101 additions & 79 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,24 +422,30 @@ bool WebAssemblyTargetLowering::shouldExpandPartialReductionIntrinsic(
return true;

EVT VT = EVT::getEVT(I->getType());
if (VT.getSizeInBits() > 128)
return true;

auto Op1 = I->getOperand(1);

if (auto *InputInst = dyn_cast<Instruction>(Op1)) {
if (InstructionOpcodeToISD(InputInst->getOpcode()) != ISD::MUL)
return true;

if (isa<Instruction>(InputInst->getOperand(0)) &&
isa<Instruction>(InputInst->getOperand(1))) {
// dot only supports signed inputs but also support lowering unsigned.
if (cast<Instruction>(InputInst->getOperand(0))->getOpcode() !=
cast<Instruction>(InputInst->getOperand(1))->getOpcode())
return true;

EVT Op1VT = EVT::getEVT(Op1->getType());
if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
((VT.getVectorElementCount() * 2 == Op1VT.getVectorElementCount()) ||
(VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount())))
return false;
unsigned Opcode = InstructionOpcodeToISD(InputInst->getOpcode());
if (Opcode == ISD::MUL) {
if (isa<Instruction>(InputInst->getOperand(0)) &&
isa<Instruction>(InputInst->getOperand(1))) {
// dot only supports signed inputs but also support lowering unsigned.
if (cast<Instruction>(InputInst->getOperand(0))->getOpcode() !=
cast<Instruction>(InputInst->getOperand(1))->getOpcode())
return true;

EVT Op1VT = EVT::getEVT(Op1->getType());
if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
((VT.getVectorElementCount() * 2 ==
Op1VT.getVectorElementCount()) ||
(VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount())))
return false;
}
} else if (ISD::isExtOpcode(Opcode)) {
return false;
}
}
return true;
Expand Down Expand Up @@ -2117,77 +2123,93 @@ SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {

assert(N->getValueType(0) == MVT::v4i32 && "can only support v4i32");
SDLoc DL(N);
SDValue Mul = N->getOperand(2);
assert(Mul->getOpcode() == ISD::MUL && "expected mul input");

SDValue ExtendLHS = Mul->getOperand(0);
SDValue ExtendRHS = Mul->getOperand(1);
assert((ISD::isExtOpcode(ExtendLHS.getOpcode()) &&
ISD::isExtOpcode(ExtendRHS.getOpcode())) &&
"expected widening mul");
assert(ExtendLHS.getOpcode() == ExtendRHS.getOpcode() &&
"expected mul to use the same extend for both operands");

SDValue ExtendInLHS = ExtendLHS->getOperand(0);
SDValue ExtendInRHS = ExtendRHS->getOperand(0);
bool IsSigned = ExtendLHS->getOpcode() == ISD::SIGN_EXTEND;

if (ExtendInLHS->getValueType(0) == MVT::v8i16) {
if (IsSigned) {
// i32x4.dot_i16x8_s
SDValue Dot = DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32,
ExtendInLHS, ExtendInRHS);
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Dot);
}

unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_U;
unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_U;
SDValue Input = N->getOperand(2);
if (Input->getOpcode() == ISD::MUL) {
SDValue ExtendLHS = Input->getOperand(0);
SDValue ExtendRHS = Input->getOperand(1);
assert((ISD::isExtOpcode(ExtendLHS.getOpcode()) &&
ISD::isExtOpcode(ExtendRHS.getOpcode())) &&
"expected widening mul or add");
assert(ExtendLHS.getOpcode() == ExtendRHS.getOpcode() &&
"expected binop to use the same extend for both operands");

SDValue ExtendInLHS = ExtendLHS->getOperand(0);
SDValue ExtendInRHS = ExtendRHS->getOperand(0);
bool IsSigned = ExtendLHS->getOpcode() == ISD::SIGN_EXTEND;
unsigned LowOpc =
IsSigned ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U;
unsigned HighOpc = IsSigned ? WebAssemblyISD::EXTEND_HIGH_S
: WebAssemblyISD::EXTEND_HIGH_U;
SDValue LowLHS;
SDValue LowRHS;
SDValue HighLHS;
SDValue HighRHS;

auto AssignInputs = [&](MVT VT) {
LowLHS = DAG.getNode(LowOpc, DL, VT, ExtendInLHS);
LowRHS = DAG.getNode(LowOpc, DL, VT, ExtendInRHS);
HighLHS = DAG.getNode(HighOpc, DL, VT, ExtendInLHS);
HighRHS = DAG.getNode(HighOpc, DL, VT, ExtendInRHS);
};

// (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs)))
SDValue LowLHS = DAG.getNode(LowOpc, DL, MVT::v4i32, ExtendInLHS);
SDValue LowRHS = DAG.getNode(LowOpc, DL, MVT::v4i32, ExtendInRHS);
SDValue HighLHS = DAG.getNode(HighOpc, DL, MVT::v4i32, ExtendInLHS);
SDValue HighRHS = DAG.getNode(HighOpc, DL, MVT::v4i32, ExtendInRHS);
if (ExtendInLHS->getValueType(0) == MVT::v8i16) {
if (IsSigned) {
// i32x4.dot_i16x8_s
SDValue Dot = DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32,
ExtendInLHS, ExtendInRHS);
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Dot);
}

SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v4i32, LowLHS, LowRHS);
SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v4i32, HighLHS, HighRHS);
SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, MulLow, MulHigh);
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
// (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs)))
MVT VT = MVT::v4i32;
AssignInputs(VT);
SDValue MulLow = DAG.getNode(ISD::MUL, DL, VT, LowLHS, LowRHS);
SDValue MulHigh = DAG.getNode(ISD::MUL, DL, VT, HighLHS, HighRHS);
SDValue Add = DAG.getNode(ISD::ADD, DL, VT, MulLow, MulHigh);
return DAG.getNode(ISD::ADD, DL, VT, N->getOperand(1), Add);
} else {
assert(ExtendInLHS->getValueType(0) == MVT::v16i8 &&
"expected v16i8 input types");
AssignInputs(MVT::v8i16);
// Lower to a wider tree, using twice the operations compared to above.
if (IsSigned) {
// Use two dots
SDValue DotLHS =
DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS);
SDValue DotRHS =
DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS);
SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS);
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
}

SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);

SDValue AddLow = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
MVT::v4i32, MulLow);
SDValue AddHigh = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
MVT::v4i32, MulHigh);
SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
}
} else {
assert(ExtendInLHS->getValueType(0) == MVT::v16i8 &&
"expected v16i8 input types");
// Lower to a wider tree, using twice the operations compared to above.
if (IsSigned) {
// Use two dots
unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_S;
unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_S;
SDValue LowLHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInLHS);
SDValue LowRHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInRHS);
SDValue HighLHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInLHS);
SDValue HighRHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInRHS);
SDValue DotLHS =
DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS);
SDValue DotRHS =
DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS);
SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS);
// Accumulate the input using extadd_pairwise.
assert(ISD::isExtOpcode(Input.getOpcode()) && "expected extend");
bool IsSigned = Input->getOpcode() == ISD::SIGN_EXTEND;
unsigned PairwiseOpc = IsSigned ? WebAssemblyISD::EXT_ADD_PAIRWISE_S
: WebAssemblyISD::EXT_ADD_PAIRWISE_U;
SDValue ExtendIn = Input->getOperand(0);
if (ExtendIn->getValueType(0) == MVT::v8i16) {
SDValue Add = DAG.getNode(PairwiseOpc, DL, MVT::v4i32, ExtendIn);
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
}

unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_U;
unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_U;
SDValue LowLHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInLHS);
SDValue LowRHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInRHS);
SDValue HighLHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInLHS);
SDValue HighRHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInRHS);

SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);

SDValue AddLow =
DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL, MVT::v4i32, MulLow);
SDValue AddHigh = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
MVT::v4i32, MulHigh);
SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
assert(ExtendIn->getValueType(0) == MVT::v16i8 &&
"expected v16i8 input types");
SDValue Add =
DAG.getNode(PairwiseOpc, DL, MVT::v4i32,
DAG.getNode(PairwiseOpc, DL, MVT::v8i16, ExtendIn));
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
}
}
Expand Down
9 changes: 7 additions & 2 deletions llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
Original file line number Diff line number Diff line change
Expand Up @@ -1454,12 +1454,13 @@ def : Pat<(t1.vt (bitconvert (t2.vt V128:$v))), (t1.vt V128:$v)>;

// Extended pairwise addition
def extadd_pairwise_u : SDNode<"WebAssemblyISD::EXT_ADD_PAIRWISE_U", extend_t>;
def extadd_pairwise_s : SDNode<"WebAssemblyISD::EXT_ADD_PAIRWISE_S", extend_t>;

defm "" : SIMDConvert<I16x8, I8x16, int_wasm_extadd_pairwise_signed,
defm "" : SIMDConvert<I16x8, I8x16, extadd_pairwise_s,
"extadd_pairwise_i8x16_s", 0x7c>;
defm "" : SIMDConvert<I16x8, I8x16, extadd_pairwise_u,
"extadd_pairwise_i8x16_u", 0x7d>;
defm "" : SIMDConvert<I32x4, I16x8, int_wasm_extadd_pairwise_signed,
defm "" : SIMDConvert<I32x4, I16x8, extadd_pairwise_s,
"extadd_pairwise_i16x8_s", 0x7e>;
defm "" : SIMDConvert<I32x4, I16x8, extadd_pairwise_u,
"extadd_pairwise_i16x8_u", 0x7f>;
Expand All @@ -1468,6 +1469,10 @@ def : Pat<(v4i32 (int_wasm_extadd_pairwise_unsigned (v8i16 V128:$in))),
(extadd_pairwise_u_I32x4 V128:$in)>;
def : Pat<(v8i16 (int_wasm_extadd_pairwise_unsigned (v16i8 V128:$in))),
(extadd_pairwise_u_I16x8 V128:$in)>;
def : Pat<(v4i32 (int_wasm_extadd_pairwise_signed (v8i16 V128:$in))),
(extadd_pairwise_s_I32x4 V128:$in)>;
def : Pat<(v8i16 (int_wasm_extadd_pairwise_signed (v16i8 V128:$in))),
(extadd_pairwise_s_I16x8 V128:$in)>;

// f64x2 <-> f32x4 conversions
def demote_t : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>;
Expand Down
27 changes: 18 additions & 9 deletions llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,31 +316,40 @@ InstructionCost WebAssemblyTTIImpl::getPartialReductionCost(
if (CostKind != TTI::TCK_RecipThroughput)
return Invalid;

InstructionCost Cost(TTI::TCC_Basic);
if (Opcode != Instruction::Add)
return Invalid;

EVT AccumEVT = EVT::getEVT(AccumType);
// TODO: Add i64 accumulator.
if (AccumEVT != MVT::i32)
return Invalid;

// Possible options:
// - i16x8.extadd_pairwise_i8x16_sx
// - i32x4.extadd_pairwise_i16x8_sx
// - i32x4.dot_i16x8_s
// Only try to support dot, for now.

if (Opcode != Instruction::Add)
EVT InputEVT = EVT::getEVT(InputTypeA);
if (!((InputEVT == MVT::i16 && VF.getFixedValue() == 8) ||
(InputEVT == MVT::i8 && VF.getFixedValue() == 16))) {
return Invalid;
}

if (!BinOp || *BinOp != Instruction::Mul)
if (OpAExtend == TTI::PR_None)
return Invalid;

if (InputTypeA != InputTypeB)
return Invalid;
InstructionCost Cost(TTI::TCC_Basic);
if (!BinOp)
return Cost;

if (OpAExtend != OpBExtend)
return Invalid;

EVT InputEVT = EVT::getEVT(InputTypeA);
EVT AccumEVT = EVT::getEVT(AccumType);
if (*BinOp != Instruction::Mul)
return Invalid;

// TODO: Add i64 accumulator.
if (AccumEVT != MVT::i32)
if (InputTypeA != InputTypeB)
return Invalid;

// Signed inputs can lower to dot
Expand Down
Loading