@@ -308,12 +308,6 @@ struct TwoDimMultiReductionToElementWise
308308
309309 LogicalResult matchAndRewrite (vector::MultiDimReductionOp multiReductionOp,
310310 PatternRewriter &rewriter) const override {
311- auto maskableOp =
312- cast<vector::MaskableOpInterface>(multiReductionOp.getOperation ());
313- if (maskableOp.isMasked ())
314- // TODO: Support masking.
315- return failure ();
316-
317311 auto srcRank = multiReductionOp.getSourceVectorType ().getRank ();
318312 // Rank-2 ["parallel", "reduce"] or bail.
319313 if (srcRank != 2 )
@@ -330,15 +324,33 @@ struct TwoDimMultiReductionToElementWise
330324 if (!elementType.isIntOrIndexOrFloat ())
331325 return failure ();
332326
327+ OpBuilder::InsertionGuard guard (rewriter);
328+ auto maskableOp =
329+ cast<vector::MaskableOpInterface>(multiReductionOp.getOperation ());
330+ Operation *rootOp;
331+ Value mask = nullptr ;
332+ if (maskableOp.isMasked ()) {
333+ rewriter.setInsertionPoint (maskableOp.getMaskingOp ());
334+ rootOp = maskableOp.getMaskingOp ();
335+ mask = maskableOp.getMaskingOp ().getMask ();
336+ } else {
337+ rootOp = multiReductionOp;
338+ }
339+
333340 Value result = multiReductionOp.getAcc ();
334341 for (int64_t i = 0 ; i < srcShape[0 ]; i++) {
335342 auto operand = rewriter.create <vector::ExtractOp>(
336343 loc, multiReductionOp.getSource (), i);
337- result = makeArithReduction (rewriter, loc, multiReductionOp.getKind (),
338- operand, result);
344+ Value extractMask = nullptr ;
345+ if (mask) {
346+ extractMask = rewriter.create <vector::ExtractOp>(loc, mask, i);
347+ }
348+ result =
349+ makeArithReduction (rewriter, loc, multiReductionOp.getKind (), operand,
350+ result, /* fastmath=*/ nullptr , extractMask);
339351 }
340352
341- rewriter.replaceOp (multiReductionOp , result);
353+ rewriter.replaceOp (rootOp , result);
342354 return success ();
343355 }
344356};
0 commit comments