@@ -1355,12 +1355,9 @@ struct WarpOpExtractScalar : public OpRewritePattern<WarpExecuteOnLane0Op> {
13551355 // All lanes extract the scalar.
13561356 if (is0dOrVec1Extract) {
13571357 Value newExtract;
1358- if (extractSrcType.getRank () == 1 ) {
1359- newExtract = rewriter.create <vector::ExtractOp>(loc, distributedVec, 0 );
1360- } else {
1361- newExtract = rewriter.create <vector::ExtractOp>(loc, distributedVec,
1362- ArrayRef<int64_t >{});
1363- }
1358+ SmallVector<int64_t > indices (extractSrcType.getRank (), 0 );
1359+ newExtract =
1360+ rewriter.create <vector::ExtractOp>(loc, distributedVec, indices);
13641361 rewriter.replaceAllUsesWith (newWarpOp->getResult (operandNumber),
13651362 newExtract);
13661363 return success ();
@@ -1408,14 +1405,13 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
14081405 if (!operand)
14091406 return failure ();
14101407 auto extractOp = operand->get ().getDefiningOp <vector::ExtractElementOp>();
1411- rewriter. setInsertionPoint (extractOp) ;
1408+ SmallVector<OpFoldResult> indices ;
14121409 if (auto pos = extractOp.getPosition ()) {
1413- rewriter.replaceOpWithNewOp <vector::ExtractOp>(
1414- extractOp, extractOp.getVector (), pos);
1415- } else {
1416- rewriter.replaceOpWithNewOp <vector::ExtractOp>(
1417- extractOp, extractOp.getVector (), ArrayRef<int64_t >{});
1410+ indices.push_back (pos);
14181411 }
1412+ rewriter.setInsertionPoint (extractOp);
1413+ rewriter.replaceOpWithNewOp <vector::ExtractOp>(
1414+ extractOp, extractOp.getVector (), indices);
14191415 return success ();
14201416 }
14211417};
@@ -1472,13 +1468,12 @@ struct WarpOpInsertScalar : public OpRewritePattern<WarpExecuteOnLane0Op> {
14721468 // This condition is always true for 0-d vectors.
14731469 if (vecType == distrType) {
14741470 Value newInsert;
1471+ SmallVector<OpFoldResult> indices;
14751472 if (pos) {
1476- newInsert = rewriter.create <vector::InsertOp>(loc, newSource,
1477- distributedVec, pos);
1478- } else {
1479- newInsert = rewriter.create <vector::InsertOp>(
1480- loc, newSource, distributedVec, ArrayRef<int64_t >{});
1473+ indices.push_back (pos);
14811474 }
1475+ newInsert = rewriter.create <vector::InsertOp>(loc, newSource,
1476+ distributedVec, indices);
14821477 // Broadcast: Simply move the vector.insert op out.
14831478 rewriter.replaceAllUsesWith (newWarpOp->getResult (operandNumber),
14841479 newInsert);
@@ -1640,15 +1635,13 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
16401635 if (!operand)
16411636 return failure ();
16421637 auto insertOp = operand->get ().getDefiningOp <vector::InsertElementOp>();
1643- rewriter. setInsertionPoint (insertOp) ;
1638+ SmallVector<OpFoldResult> indices ;
16441639 if (auto pos = insertOp.getPosition ()) {
1645- rewriter.replaceOpWithNewOp <vector::InsertOp>(
1646- insertOp, insertOp.getSource (), insertOp.getDest (), pos);
1647- } else {
1648- rewriter.replaceOpWithNewOp <vector::InsertOp>(
1649- insertOp, insertOp.getSource (), insertOp.getDest (),
1650- ArrayRef<int64_t >{});
1640+ indices.push_back (pos);
16511641 }
1642+ rewriter.setInsertionPoint (insertOp);
1643+ rewriter.replaceOpWithNewOp <vector::InsertOp>(
1644+ insertOp, insertOp.getSource (), insertOp.getDest (), indices);
16521645 return success ();
16531646 }
16541647};
0 commit comments