@@ -753,7 +753,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
753753 setOperationAction ({ISD::LOAD, ISD::STORE}, {MVT::i128 , MVT::f128 }, Custom);
754754 for (MVT VT : MVT::fixedlen_vector_valuetypes ())
755755 if (!isTypeLegal (VT) && VT.getStoreSizeInBits () <= 256 )
756- setOperationAction ({ISD::STORE, ISD::LOAD}, VT, Custom);
756+ setOperationAction ({ISD::STORE, ISD::LOAD, ISD::MSTORE }, VT, Custom);
757757
758758 // Custom legalization for LDU intrinsics.
759759 // TODO: The logic to lower these is not very robust and we should rewrite it.
@@ -2869,6 +2869,87 @@ static SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) {
28692869 return Or;
28702870}
28712871
2872+ static SDValue lowerMSTORE (SDValue Op, SelectionDAG &DAG) {
2873+ SDNode *N = Op.getNode ();
2874+
2875+ SDValue Chain = N->getOperand (0 );
2876+ SDValue Val = N->getOperand (1 );
2877+ SDValue BasePtr = N->getOperand (2 );
2878+ SDValue Offset = N->getOperand (3 );
2879+ SDValue Mask = N->getOperand (4 );
2880+
2881+ SDLoc DL (N);
2882+ EVT ValVT = Val.getValueType ();
2883+ MemSDNode *MemSD = cast<MemSDNode>(N);
2884+ assert (ValVT.isVector () && " Masked vector store must have vector type" );
2885+ assert (MemSD->getAlign () >= DAG.getEVTAlign (ValVT) &&
2886+ " Unexpected alignment for masked store" );
2887+
2888+ unsigned Opcode = 0 ;
2889+ switch (ValVT.getSimpleVT ().SimpleTy ) {
2890+ default :
2891+ llvm_unreachable (" Unexpected masked vector store type" );
2892+ case MVT::v4i64:
2893+ case MVT::v4f64: {
2894+ Opcode = NVPTXISD::StoreV4;
2895+ break ;
2896+ }
2897+ case MVT::v8i32:
2898+ case MVT::v8f32: {
2899+ Opcode = NVPTXISD::StoreV8;
2900+ break ;
2901+ }
2902+ }
2903+
2904+ SmallVector<SDValue, 8 > Ops;
2905+
2906+ // Construct the new SDNode. First operand is the chain.
2907+ Ops.push_back (Chain);
2908+
2909+ // The next N operands are the values to store. Encode the mask into the
2910+ // values using the sentinel register 0 to represent a masked-off element.
2911+ assert (Mask.getValueType ().isVector () &&
2912+ Mask.getValueType ().getVectorElementType () == MVT::i1 &&
2913+ " Mask must be a vector of i1" );
2914+ assert (Mask.getOpcode () == ISD::BUILD_VECTOR &&
2915+ " Mask expected to be a BUILD_VECTOR" );
2916+ assert (Mask.getValueType ().getVectorNumElements () ==
2917+ ValVT.getVectorNumElements () &&
2918+ " Mask size must be the same as the vector size" );
2919+ for (unsigned I : llvm::seq (ValVT.getVectorNumElements ())) {
2920+ assert (isa<ConstantSDNode>(Mask.getOperand (I)) &&
2921+ " Mask elements must be constants" );
2922+ if (Mask->getConstantOperandVal (I) == 0 ) {
2923+ // Append a sentinel register 0 to the Ops vector to represent a masked
2924+ // off element, this will be handled in tablegen
2925+ Ops.push_back (DAG.getRegister (MCRegister::NoRegister,
2926+ ValVT.getVectorElementType ()));
2927+ } else {
2928+ // Extract the element from the vector to store
2929+ SDValue ExtVal =
2930+ DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, ValVT.getVectorElementType (),
2931+ Val, DAG.getIntPtrConstant (I, DL));
2932+ Ops.push_back (ExtVal);
2933+ }
2934+ }
2935+
2936+ // Next, the pointer operand.
2937+ Ops.push_back (BasePtr);
2938+
2939+ // Finally, the offset operand. We expect this to always be undef, and it will
2940+ // be ignored in lowering, but to mirror the handling of the other vector
2941+ // store instructions we include it in the new SDNode.
2942+ assert (Offset.getOpcode () == ISD::UNDEF &&
2943+ " Offset operand expected to be undef" );
2944+ Ops.push_back (Offset);
2945+
2946+ SDValue NewSt =
2947+ DAG.getMemIntrinsicNode (Opcode, DL, DAG.getVTList (MVT::Other), Ops,
2948+ MemSD->getMemoryVT (), MemSD->getMemOperand ());
2949+
2950+ return NewSt;
2951+ }
2952+
28722953SDValue
28732954NVPTXTargetLowering::LowerOperation (SDValue Op, SelectionDAG &DAG) const {
28742955 switch (Op.getOpcode ()) {
@@ -2905,6 +2986,12 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
29052986 return LowerVECREDUCE (Op, DAG);
29062987 case ISD::STORE:
29072988 return LowerSTORE (Op, DAG);
2989+ case ISD::MSTORE: {
2990+ assert (STI.has256BitVectorLoadStore (
2991+ cast<MemSDNode>(Op.getNode ())->getAddressSpace ()) &&
2992+ " Masked store vector not supported on subtarget." );
2993+ return lowerMSTORE (Op, DAG);
2994+ }
29082995 case ISD::LOAD:
29092996 return LowerLOAD (Op, DAG);
29102997 case ISD::SHL_PARTS:
0 commit comments