@@ -76,23 +76,22 @@ static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
7676
7777 for (int i = n - 1 ; i >= 0 ; --i) {
7878 multiIndex[i] = b.create <arith::RemSIOp>(loc, linearIndex, dimensions[i]);
79- if (i > 0 ) {
79+ if (i > 0 )
8080 linearIndex = b.create <arith::DivSIOp>(loc, linearIndex, dimensions[i]);
81- }
8281 }
8382
8483 return multiIndex;
8584}
8685
87- // Create operations converting a multi-dimensional index to a linear index
86+ // / Create operations converting a multi-dimensional index to a linear index.
8887Value multiToLinearIndex (Location loc, OpBuilder b, ValueRange multiIndex,
8988 ValueRange dimensions) {
9089
91- auto linearIndex = b.create <arith::ConstantIndexOp>(loc, 0 ). getResult ( );
92- auto stride = b.create <arith::ConstantIndexOp>(loc, 1 ). getResult ( );
90+ Value linearIndex = b.create <arith::ConstantIndexOp>(loc, 0 );
91+ Value stride = b.create <arith::ConstantIndexOp>(loc, 1 );
9392
9493 for (int i = multiIndex.size () - 1 ; i >= 0 ; --i) {
95- auto off = b.create <arith::MulIOp>(loc, multiIndex[i], stride);
94+ Value off = b.create <arith::MulIOp>(loc, multiIndex[i], stride);
9695 linearIndex = b.create <arith::AddIOp>(loc, linearIndex, off);
9796 stride = b.create <arith::MulIOp>(loc, stride, dimensions[i]);
9897 }
@@ -247,34 +246,32 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
247246};
248247
249248struct ConvertProcessMultiIndexOp
250- : public mlir::OpRewritePattern<mlir::mesh:: ProcessMultiIndexOp> {
251- using OpRewritePattern::OpRewritePattern ;
249+ : public OpConversionPattern< ProcessMultiIndexOp> {
250+ using OpConversionPattern::OpConversionPattern ;
252251
253- mlir:: LogicalResult
254- matchAndRewrite (mlir::mesh:: ProcessMultiIndexOp op,
255- mlir::PatternRewriter &rewriter) const override {
252+ LogicalResult
253+ matchAndRewrite (ProcessMultiIndexOp op, OpAdaptor adaptor ,
254+ ConversionPatternRewriter &rewriter) const override {
256255
257256 // Currently converts its linear index to a multi-dimensional index.
258257
259258 SymbolTableCollection symbolTableCollection;
260- auto loc = op.getLoc ();
259+ Location loc = op.getLoc ();
261260 auto meshOp = getMesh (op, symbolTableCollection);
262261 // For now we only support static mesh shapes
263- if (ShapedType::isDynamicShape (meshOp.getShape ())) {
264- return mlir::failure ();
265- }
262+ if (ShapedType::isDynamicShape (meshOp.getShape ()))
263+ return failure ();
266264
267265 SmallVector<Value> dims;
268266 llvm::transform (
269267 meshOp.getShape (), std::back_inserter (dims), [&](int64_t i) {
270268 return rewriter.create <arith::ConstantIndexOp>(loc, i).getResult ();
271269 });
272- auto rank =
273- rewriter.create <ProcessLinearIndexOp>(op.getLoc (), meshOp).getResult ();
270+ Value rank = rewriter.create <ProcessLinearIndexOp>(op.getLoc (), meshOp);
274271 auto mIdx = linearToMultiIndex (loc, rewriter, rank, dims);
275272
276273 // optionally extract subset of mesh axes
277- auto axes = op .getAxes ();
274+ auto axes = adaptor .getAxes ();
278275 if (!axes.empty ()) {
279276 SmallVector<Value> subIndex;
280277 for (auto axis : axes) {
@@ -319,44 +316,43 @@ class ConvertProcessLinearIndexOp
319316 .getRank ();
320317 rewriter.replaceOpWithNewOp <arith::IndexCastOp>(op, rewriter.getIndexType (),
321318 rank);
322- return mlir:: success ();
319+ return success ();
323320 }
324321};
325322
326323struct ConvertNeighborsLinearIndicesOp
327- : public mlir::OpRewritePattern<mlir::mesh:: NeighborsLinearIndicesOp> {
328- using OpRewritePattern::OpRewritePattern ;
324+ : public OpConversionPattern< NeighborsLinearIndicesOp> {
325+ using OpConversionPattern::OpConversionPattern ;
329326
330- mlir:: LogicalResult
331- matchAndRewrite (mlir::mesh:: NeighborsLinearIndicesOp op,
332- mlir::PatternRewriter &rewriter) const override {
327+ LogicalResult
328+ matchAndRewrite (NeighborsLinearIndicesOp op, OpAdaptor adaptor ,
329+ ConversionPatternRewriter &rewriter) const override {
333330
334331 // Computes the neighbors indices along a split axis by simply
335332 // adding/subtracting 1 to the current index in that dimension.
336333 // Assigns -1 if neighbor is out of bounds.
337334
338- auto axes = op .getSplitAxes ();
335+ auto axes = adaptor .getSplitAxes ();
339336 // For now only single axis sharding is supported
340- if (axes.size () != 1 ) {
341- return mlir::failure ();
342- }
337+ if (axes.size () != 1 )
338+ return failure ();
343339
344- auto loc = op.getLoc ();
340+ Location loc = op.getLoc ();
345341 SymbolTableCollection symbolTableCollection;
346342 auto meshOp = getMesh (op, symbolTableCollection);
347- auto mIdx = op .getDevice ();
343+ auto mIdx = adaptor .getDevice ();
348344 auto orgIdx = mIdx [axes[0 ]];
349345 SmallVector<Value> dims;
350346 llvm::transform (
351347 meshOp.getShape (), std::back_inserter (dims), [&](int64_t i) {
352348 return rewriter.create <arith::ConstantIndexOp>(loc, i).getResult ();
353349 });
354- auto dimSz = dims[axes[0 ]];
355- auto one = rewriter.create <arith::ConstantIndexOp>(loc, 1 ). getResult ( );
356- auto minus1 = rewriter.create <arith::ConstantIndexOp>(loc, -1 ). getResult ( );
357- auto atBorder = rewriter.create <arith::CmpIOp>(
350+ Value dimSz = dims[axes[0 ]];
351+ Value one = rewriter.create <arith::ConstantIndexOp>(loc, 1 );
352+ Value minus1 = rewriter.create <arith::ConstantIndexOp>(loc, -1 );
353+ Value atBorder = rewriter.create <arith::CmpIOp>(
358354 loc, arith::CmpIPredicate::sle, orgIdx,
359- rewriter.create <arith::ConstantIndexOp>(loc, 0 ). getResult () );
355+ rewriter.create <arith::ConstantIndexOp>(loc, 0 ));
360356 auto down = rewriter.create <scf::IfOp>(
361357 loc, atBorder,
362358 [&](OpBuilder &builder, Location loc) {
@@ -598,23 +594,20 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
598594 // we need the actual shape to compute offsets and sizes
599595 for (auto i = 0 ; i < rank; ++i) {
600596 auto s = dstShape[i];
601- if (ShapedType::isDynamic (s)) {
597+ if (ShapedType::isDynamic (s))
602598 shape[i] = rewriter.create <memref::DimOp>(loc, array, s).getResult ();
603- } else {
599+ else
604600 shape[i] = rewriter.getIndexAttr (s);
605- }
606601
607602 if ((size_t )i < opSplitAxes.size () && !opSplitAxes[i].empty ()) {
608603 ++currHaloDim;
609604 // the offsets for lower dim sstarts after their down halo
610605 offsets[i] = haloSizes[currHaloDim * 2 ];
611606
612607 // prepare shape and offsets of highest dim's halo exchange
613- auto _haloSz =
614- rewriter
615- .create <arith::AddIOp>(loc, toValue (haloSizes[currHaloDim * 2 ]),
616- toValue (haloSizes[currHaloDim * 2 + 1 ]))
617- .getResult ();
608+ Value _haloSz = rewriter.create <arith::AddIOp>(
609+ loc, toValue (haloSizes[currHaloDim * 2 ]),
610+ toValue (haloSizes[currHaloDim * 2 + 1 ]));
618611 // the halo shape of lower dims exlude the halos
619612 dimSizes[i] =
620613 rewriter.create <arith::SubIOp>(loc, toValue (shape[i]), _haloSz)
@@ -625,9 +618,9 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
625618 }
626619
627620 auto tagAttr = rewriter.getI32IntegerAttr (91 ); // we just pick something
628- auto tag = rewriter.create <::mlir:: arith::ConstantOp>(loc, tagAttr);
621+ auto tag = rewriter.create <arith::ConstantOp>(loc, tagAttr);
629622 auto zeroAttr = rewriter.getI32IntegerAttr (0 ); // for detecting v<0
630- auto zero = rewriter.create <::mlir:: arith::ConstantOp>(loc, zeroAttr);
623+ auto zero = rewriter.create <arith::ConstantOp>(loc, zeroAttr);
631624
632625 SmallVector<Type> indexResultTypes (meshOp.getShape ().size (),
633626 rewriter.getIndexType ());
@@ -637,9 +630,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
637630 // traverse all split axes from high to low dim
638631 for (ssize_t dim = opSplitAxes.size () - 1 ; dim >= 0 ; --dim) {
639632 auto splitAxes = opSplitAxes[dim];
640- if (splitAxes.empty ()) {
633+ if (splitAxes.empty ())
641634 continue ;
642- }
643635 assert (currHaloDim >= 0 && (size_t )currHaloDim < haloSizes.size () / 2 );
644636 // Get the linearized ids of the neighbors (down and up) for the
645637 // given split
0 commit comments