@@ -3453,7 +3453,8 @@ SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
34533453}
34543454
34553455static std::pair<MemSDNode *, uint32_t >
3456- convertMLOADToLoadWithUsedBytesMask (MemSDNode *N, SelectionDAG &DAG) {
3456+ convertMLOADToLoadWithUsedBytesMask (MemSDNode *N, SelectionDAG &DAG,
3457+ const NVPTXSubtarget &STI) {
34573458 SDValue Chain = N->getOperand (0 );
34583459 SDValue BasePtr = N->getOperand (1 );
34593460 SDValue Mask = N->getOperand (3 );
@@ -3495,6 +3496,11 @@ convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG) {
34953496 MemSDNode *NewLD = cast<MemSDNode>(
34963497 DAG.getLoad (ResVT, DL, Chain, BasePtr, N->getMemOperand ()).getNode ());
34973498
3499+ // If our subtarget does not support the used bytes mask pragma, "drop" the
3500+ // mask by setting it to UINT32_MAX.
3501+ if (!STI.hasUsedBytesMaskPragma ())
3502+ UsedBytesMask = UINT32_MAX;
3503+
34983504 return {NewLD, UsedBytesMask};
34993505}
35003506
@@ -3531,7 +3537,8 @@ replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
35313537 // If we have a masked load, convert it to a normal load now
35323538 std::optional<uint32_t > UsedBytesMask = std::nullopt ;
35333539 if (LD->getOpcode () == ISD::MLOAD)
3534- std::tie (LD, UsedBytesMask) = convertMLOADToLoadWithUsedBytesMask (LD, DAG);
3540+ std::tie (LD, UsedBytesMask) =
3541+ convertMLOADToLoadWithUsedBytesMask (LD, DAG, STI);
35353542
35363543 // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
35373544 // Therefore, we must ensure the type is legal. For i1 and i8, we set the
@@ -3667,8 +3674,8 @@ SDValue NVPTXTargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
36673674 // them here.
36683675 EVT VT = Op.getValueType ();
36693676 if (NVPTX::isPackedVectorTy (VT)) {
3670- auto Result =
3671- convertMLOADToLoadWithUsedBytesMask ( cast<MemSDNode>(Op.getNode ()), DAG);
3677+ auto Result = convertMLOADToLoadWithUsedBytesMask (
3678+ cast<MemSDNode>(Op.getNode ()), DAG, STI );
36723679 MemSDNode *LD = std::get<0 >(Result);
36733680 uint32_t UsedBytesMask = std::get<1 >(Result);
36743681
0 commit comments