@@ -769,7 +769,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction({ISD::LOAD, ISD::STORE}, {MVT::i128, MVT::f128}, Custom);
   for (MVT VT : MVT::fixedlen_vector_valuetypes())
     if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256)
-      setOperationAction({ISD::STORE, ISD::LOAD, ISD::MSTORE}, VT, Custom);
+      setOperationAction({ISD::STORE, ISD::LOAD, ISD::MSTORE, ISD::MLOAD}, VT,
+                         Custom);
 
   // Custom legalization for LDU intrinsics.
   // TODO: The logic to lower these is not very robust and we should rewrite it.
@@ -1130,6 +1131,7 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::LoadV2)
     MAKE_CASE(NVPTXISD::LoadV4)
     MAKE_CASE(NVPTXISD::LoadV8)
+    MAKE_CASE(NVPTXISD::MLoadV1)
     MAKE_CASE(NVPTXISD::LDUV2)
     MAKE_CASE(NVPTXISD::LDUV4)
     MAKE_CASE(NVPTXISD::StoreV2)
@@ -3306,6 +3308,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   }
   case ISD::LOAD:
     return LowerLOAD(Op, DAG);
+  case ISD::MLOAD:
+    return LowerMLOAD(Op, DAG);
   case ISD::SHL_PARTS:
     return LowerShiftLeftParts(Op, DAG);
   case ISD::SRA_PARTS:
@@ -3497,10 +3501,58 @@ SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
                       MachinePointerInfo(SV));
 }
 
+static std::tuple<MemSDNode *, uint32_t>
+convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG) {
+  SDValue Chain = N->getOperand(0);
+  SDValue BasePtr = N->getOperand(1);
+  SDValue Mask = N->getOperand(3);
+  SDValue Passthru = N->getOperand(4);
+
+  SDLoc DL(N);
+  EVT ResVT = N->getValueType(0);
+  assert(ResVT.isVector() && "Masked vector load must have vector type");
+  // While we only expect poison passthru vectors as an input to the backend,
+  // when the legalization framework splits a poison vector in half, it creates
+  // two undef vectors, so we can technically expect those too.
+  assert((Passthru.getOpcode() == ISD::POISON ||
+          Passthru.getOpcode() == ISD::UNDEF) &&
+         "Passthru operand expected to be poison or undef");
+
+  // Extract the mask and convert it to a uint32_t representing the used bytes
+  // of the entire vector load.
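+  // For example, a v4i32 masked load with mask <1, 0, 1, 1> produces the
+  // used-bytes mask 0xFF0F (element 0 maps to the low four bits).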
+  uint32_t UsedBytesMask = 0;
+  uint32_t ElementSizeInBits = ResVT.getVectorElementType().getSizeInBits();
+  assert(ElementSizeInBits % 8 == 0 && "Unexpected element size");
+  uint32_t ElementSizeInBytes = ElementSizeInBits / 8;
+  uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;
+
+  for (unsigned I :
+       llvm::reverse(llvm::seq<unsigned>(0, ResVT.getVectorNumElements()))) {
+    assert(isa<ConstantSDNode>(Mask.getOperand(I)) &&
+           "Mask elements must be constants");
+    // We technically only want to do this shift for every iteration *but* the
+    // first, but in the first iteration UsedBytesMask is 0, so this shift is a
+    // no-op.
+    UsedBytesMask <<= ElementSizeInBytes;
+
+    if (Mask->getConstantOperandVal(I) != 0)
+      UsedBytesMask |= ElementMask;
+  }
+
+  assert(UsedBytesMask != 0 && UsedBytesMask != UINT32_MAX &&
+         "Unexpected masked load with elements masked all on or all off");
+
+  // Create a new load SDNode to be handled normally by replaceLoadVector.
+  MemSDNode *NewLD = cast<MemSDNode>(
+      DAG.getLoad(ResVT, DL, Chain, BasePtr, N->getMemOperand()).getNode());
+
+  return {NewLD, UsedBytesMask};
+}
+
 /// replaceLoadVector - Convert vector loads into multi-output scalar loads.
 static std::optional<std::pair<SDValue, SDValue>>
 replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
-  LoadSDNode *LD = cast<LoadSDNode>(N);
+  MemSDNode *LD = cast<MemSDNode>(N);
   const EVT ResVT = LD->getValueType(0);
   const EVT MemVT = LD->getMemoryVT();
 
@@ -3527,6 +3579,14 @@ replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
     return std::nullopt;
   }
 
+  // If we have a masked load, convert it to a normal load now.
+  std::optional<uint32_t> UsedBytesMask = std::nullopt;
+  if (LD->getOpcode() == ISD::MLOAD) {
+    auto Result = convertMLOADToLoadWithUsedBytesMask(LD, DAG);
+    LD = std::get<0>(Result);
+    UsedBytesMask = std::get<1>(Result);
+  }
+
   // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
   // Therefore, we must ensure the type is legal. For i1 and i8, we set the
   // loaded type to i16 and propagate the "real" type as the memory type.
@@ -3555,9 +3615,13 @@ replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
   // Copy regular operands
   SmallVector<SDValue, 8> OtherOps(LD->ops());
 
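+  // Plain (non-masked) loads pass the all-ones mask, i.e. every byte of the
+  // loaded vector is used.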
+  OtherOps.push_back(
+      DAG.getConstant(UsedBytesMask.value_or(UINT32_MAX), DL, MVT::i32));
+
   // The select routine does not have access to the LoadSDNode instance, so
   // pass along the extension information
-  OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
+  OtherOps.push_back(
+      DAG.getIntPtrConstant(cast<LoadSDNode>(LD)->getExtensionType(), DL));
 
   SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps, MemVT,
                                           LD->getMemOperand());
@@ -3645,6 +3709,43 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
   llvm_unreachable("Unexpected custom lowering for load");
 }
 
+SDValue NVPTXTargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
+  // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on the legalizer to
+  // handle masked loads of these types and have to handle them here.
+  // v2f32 also needs to be handled here if the subtarget has f32x2
+  // instructions, making it legal.
+  //
+  // Note: misaligned masked loads should never reach this point
+  // because the override of isLegalMaskedLoad in NVPTXTargetTransformInfo.cpp
+  // validates alignment. Therefore, we do not need to special-case
+  // them here.
+  EVT VT = Op.getValueType();
+  if (NVPTX::isPackedVectorTy(VT) &&
+      (VT != MVT::v2f32 || STI.hasF32x2Instructions())) {
+    auto Result =
+        convertMLOADToLoadWithUsedBytesMask(cast<MemSDNode>(Op.getNode()), DAG);
+    MemSDNode *LD = std::get<0>(Result);
+    uint32_t UsedBytesMask = std::get<1>(Result);
+
+    SDLoc DL(LD);
+
+    // Copy regular operands
+    SmallVector<SDValue, 8> OtherOps(LD->ops());
+
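+    // Append the used-bytes mask and then the extension type, matching the
+    // operand order replaceLoadVector produces for the NVPTXISD load nodes.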
+    OtherOps.push_back(DAG.getConstant(UsedBytesMask, DL, MVT::i32));
+
+    // The select routine does not have access to the LoadSDNode instance, so
+    // pass along the extension information
+    OtherOps.push_back(
+        DAG.getIntPtrConstant(cast<LoadSDNode>(LD)->getExtensionType(), DL));
+    SDValue NewLD = DAG.getMemIntrinsicNode(
+        NVPTXISD::MLoadV1, DL, LD->getVTList(), OtherOps, LD->getMemoryVT(),
+        LD->getMemOperand());
+    return NewLD;
+  }
+  return SDValue();
+}
+
 static SDValue lowerSTOREVector(SDValue Op, SelectionDAG &DAG,
                                 const NVPTXSubtarget &STI) {
   MemSDNode *N = cast<MemSDNode>(Op.getNode());
@@ -5555,9 +5656,13 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
     // ISD::LOAD -> NVPTXISD::Load (unless it's under-aligned). We have to do it
     // here.
     Opcode = NVPTXISD::LoadV2;
+    // Append a "full" used-bytes mask operand right before the extension type
+    // operand, signifying that all bytes are used.
+    Operands.push_back(DCI.DAG.getConstant(UINT32_MAX, DL, MVT::i32));
     Operands.push_back(DCI.DAG.getIntPtrConstant(
         cast<LoadSDNode>(LD)->getExtensionType(), DL));
     break;
+  // TODO: Do we need to support MLoadV1 here?
   case NVPTXISD::LoadV2:
     OldNumOutputs = 2;
     Opcode = NVPTXISD::LoadV4;
@@ -6793,6 +6898,7 @@ void NVPTXTargetLowering::ReplaceNodeResults(
     ReplaceBITCAST(N, DAG, Results);
     return;
   case ISD::LOAD:
+  case ISD::MLOAD:
     replaceLoadVector(N, DAG, Results, STI);
     return;
   case ISD::INTRINSIC_W_CHAIN: