Skip to content

Commit 9e6a31f

Browse files
authored
[WebAssembly] vf32 to vi8, vi16 lowering (#164644)
Avoid scalarizing the conversion and use trunc_sat and narrow instead.
1 parent 9f5811e commit 9e6a31f

File tree

3 files changed

+869
-643
lines changed

3 files changed

+869
-643
lines changed

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,8 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
216216
// Combine fp_to_{s,u}int_sat or fp_round of concat_vectors or vice versa
217217
// into conversion ops
218218
setTargetDAGCombine({ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT,
219-
ISD::FP_ROUND, ISD::CONCAT_VECTORS});
219+
ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::FP_ROUND,
220+
ISD::CONCAT_VECTORS});
220221

221222
setTargetDAGCombine(ISD::TRUNCATE);
222223

@@ -3580,6 +3581,64 @@ static SDValue performMulCombine(SDNode *N,
35803581
}
35813582
}
35823583

3584+
SDValue DoubleVectorWidth(SDValue In, unsigned RequiredNumElems,
3585+
SelectionDAG &DAG) {
3586+
SDLoc DL(In);
3587+
LLVMContext &Ctx = *DAG.getContext();
3588+
EVT InVT = In.getValueType();
3589+
unsigned NumElems = InVT.getVectorNumElements() * 2;
3590+
EVT OutVT = EVT::getVectorVT(Ctx, InVT.getVectorElementType(), NumElems);
3591+
SDValue Concat =
3592+
DAG.getNode(ISD::CONCAT_VECTORS, DL, OutVT, In, DAG.getPOISON(InVT));
3593+
if (NumElems < RequiredNumElems) {
3594+
return DoubleVectorWidth(Concat, RequiredNumElems, DAG);
3595+
}
3596+
return Concat;
3597+
}
3598+
3599+
SDValue performConvertFPCombine(SDNode *N, SelectionDAG &DAG) {
3600+
EVT OutVT = N->getValueType(0);
3601+
if (!OutVT.isVector())
3602+
return SDValue();
3603+
3604+
EVT OutElTy = OutVT.getVectorElementType();
3605+
if (OutElTy != MVT::i8 && OutElTy != MVT::i16)
3606+
return SDValue();
3607+
3608+
unsigned NumElems = OutVT.getVectorNumElements();
3609+
if (!isPowerOf2_32(NumElems))
3610+
return SDValue();
3611+
3612+
EVT FPVT = N->getOperand(0)->getValueType(0);
3613+
if (FPVT.getVectorElementType() != MVT::f32)
3614+
return SDValue();
3615+
3616+
SDLoc DL(N);
3617+
3618+
// First, convert to i32.
3619+
LLVMContext &Ctx = *DAG.getContext();
3620+
EVT IntVT = EVT::getVectorVT(Ctx, MVT::i32, NumElems);
3621+
SDValue ToInt = DAG.getNode(N->getOpcode(), DL, IntVT, N->getOperand(0));
3622+
APInt Mask = APInt::getLowBitsSet(IntVT.getScalarSizeInBits(),
3623+
OutVT.getScalarSizeInBits());
3624+
// Mask out the top MSBs.
3625+
SDValue Masked =
3626+
DAG.getNode(ISD::AND, DL, IntVT, ToInt, DAG.getConstant(Mask, DL, IntVT));
3627+
3628+
if (OutVT.getSizeInBits() < 128) {
3629+
// Create a wide enough vector that we can use narrow.
3630+
EVT NarrowedVT = OutElTy == MVT::i8 ? MVT::v16i8 : MVT::v8i16;
3631+
unsigned NumRequiredElems = NarrowedVT.getVectorNumElements();
3632+
SDValue WideVector = DoubleVectorWidth(Masked, NumRequiredElems, DAG);
3633+
SDValue Trunc = truncateVectorWithNARROW(NarrowedVT, WideVector, DL, DAG);
3634+
return DAG.getBitcast(
3635+
OutVT, extractSubVector(Trunc, 0, DAG, DL, OutVT.getSizeInBits()));
3636+
} else {
3637+
return truncateVectorWithNARROW(OutVT, Masked, DL, DAG);
3638+
}
3639+
return SDValue();
3640+
}
3641+
35833642
SDValue
35843643
WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
35853644
DAGCombinerInfo &DCI) const {
@@ -3606,6 +3665,9 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
36063665
case ISD::FP_ROUND:
36073666
case ISD::CONCAT_VECTORS:
36083667
return performVectorTruncZeroCombine(N, DCI);
3668+
case ISD::FP_TO_SINT:
3669+
case ISD::FP_TO_UINT:
3670+
return performConvertFPCombine(N, DCI.DAG);
36093671
case ISD::TRUNCATE:
36103672
return performTruncateCombine(N, DCI);
36113673
case ISD::INTRINSIC_WO_CHAIN:

0 commit comments

Comments
 (0)