@@ -45,7 +45,8 @@ using namespace mlir;
 using namespace mesh;
 
 namespace {
-/// Convert vec of OpFoldResults (ints) into vector of Values.
+/// Converts a vector of OpFoldResults (ints) into a vector of Values of the
+/// provided type.
 static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
                                            llvm::ArrayRef<int64_t> statics,
                                            ValueRange dynamics,
@@ -55,14 +56,15 @@ static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
   Type i64 = b.getI64Type();
   if (!type)
     type = i64;
-  assert(i64 == type || b.getIndexType() == type);
+  assert((i64 == type || b.getIndexType() == type) &&
+         "expected an i64 or an index type");
   for (auto s : statics) {
-    values.emplace_back(
-        ShapedType::isDynamic(s)
-            ? *(dyn++)
-            : b.create<arith::ConstantOp>(loc, type,
-                                          i64 == type ? b.getI64IntegerAttr(s)
-                                                      : b.getIndexAttr(s)));
+    if (s == ShapedType::kDynamic) {
+      values.emplace_back(*(dyn++));
+    } else {
+      TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
+      values.emplace_back(b.create<arith::ConstantOp>(loc, type, val));
+    }
   }
   return values;
 };
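
As a rough stand-alone sketch of the merge behavior this helper implements (plain C++ rather than the builder API, with strings standing in for SSA Values and -1 for ShapedType::kDynamic; all names are made up for illustration):

#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Each static entry becomes a "constant"; each dynamic sentinel (-1 here)
// consumes the next dynamic value, mirroring getMixedAsValues above.
static std::vector<std::string>
mixedAsStrings(const std::vector<int64_t> &statics,
               const std::vector<std::string> &dynamics) {
  std::vector<std::string> values;
  auto dyn = dynamics.begin();
  for (int64_t s : statics)
    values.emplace_back(s == -1 ? *(dyn++) : "const " + std::to_string(s));
  assert(dyn == dynamics.end() && "all dynamic values consumed");
  return values;
}

int main() {
  // statics {4, dynamic, 8} plus one dynamic value "%d" yields
  // {"const 4", "%d", "const 8"}.
  return mixedAsStrings({4, -1, 8}, {"%d"}).size() == 3 ? 0 : 1;
}
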
@@ -129,33 +131,33 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto splitAxes = op.getSplitAxes().getAxes();
     int64_t maxNAxes = 0;
-    for (auto axes : splitAxes) {
+    for (auto axes : splitAxes)
       maxNAxes = std::max<int64_t>(maxNAxes, axes.size());
-    }
 
     // To hold the split axes, create empty 2d tensor with shape
     // {splitAxes.size(), max-size-of-split-groups}.
     // Set trailing elements for smaller split-groups to -1.
     Location loc = op.getLoc();
     auto i16 = rewriter.getI16Type();
     auto i64 = rewriter.getI64Type();
-    int64_t shape[] = {static_cast<int64_t>(splitAxes.size()), maxNAxes};
+    std::array<int64_t, 2> shape = {static_cast<int64_t>(splitAxes.size()),
+                                    maxNAxes};
     Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
-    auto attr = IntegerAttr::get(i16, 0xffff);
+    auto attr = IntegerAttr::get(i16, -1);
     Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
     resSplitAxes = rewriter.create<linalg::FillOp>(loc, fillValue, resSplitAxes)
                        .getResult(0);
 
     // explicitly write values into tensor row by row
-    int64_t strides[] = {1, 1};
+    std::array<int64_t, 2> strides = {1, 1};
     int64_t nSplits = 0;
     ValueRange empty = {};
     for (auto [i, axes] : llvm::enumerate(splitAxes)) {
       int64_t size = axes.size();
       if (size > 0)
         ++nSplits;
-      int64_t offs[] = {(int64_t)i, 0};
-      int64_t sizes[] = {1, size};
+      std::array<int64_t, 2> offs = {(int64_t)i, 0};
+      std::array<int64_t, 2> sizes = {1, size};
       auto tensorType = RankedTensorType::get({size}, i16);
       auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
       auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs);
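
The row-by-row writes above amount to padding each split group out to maxNAxes with -1. A small stand-alone sketch of that layout (plain C++ on std::vector, not the tensor/linalg builder calls; purely illustrative):

#include <algorithm>
#include <cstdint>
#include <vector>

// Pad every split group to the longest group's length with -1, mirroring the
// {splitAxes.size(), maxNAxes} tensor filled by the pattern above.
static std::vector<std::vector<int16_t>>
padSplitAxes(const std::vector<std::vector<int16_t>> &splitAxes) {
  std::size_t maxNAxes = 0;
  for (const auto &axes : splitAxes)
    maxNAxes = std::max(maxNAxes, axes.size());
  std::vector<std::vector<int16_t>> rows;
  for (const auto &axes : splitAxes) {
    std::vector<int16_t> row(maxNAxes, -1);
    std::copy(axes.begin(), axes.end(), row.begin());
    rows.push_back(row);
  }
  return rows;
}

int main() {
  // splitAxes {{0}, {1, 2}} becomes the 2x2 rows {{0, -1}, {1, 2}}.
  auto rows = padSplitAxes({{0}, {1, 2}});
  return rows[0][1] == -1 && rows[1][1] == 2 ? 0 : 1;
}
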
@@ -165,7 +167,7 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
 
     // To hold halos sizes, create 2d Tensor with shape {nSplits, 2}.
     // Store the halo sizes in the tensor.
-    auto haloSizes =
+    SmallVector<Value> haloSizes =
         getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
                          adaptor.getDynamicHaloSizes());
     auto type = RankedTensorType::get({nSplits, 2}, i64);
@@ -190,7 +192,7 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
     } else {
       SymbolTableCollection symbolTableCollection;
       auto meshOp = getMesh(op, symbolTableCollection);
-      auto maxSplitSize = 0;
+      int64_t maxSplitSize = 0;
       for (auto axes : splitAxes) {
         int64_t splitSize =
             collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
@@ -206,7 +208,7 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
           loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
       resOffsets =
           rewriter.create<linalg::FillOp>(loc, zero, resOffsets).getResult(0);
-      auto offsets =
+      SmallVector<Value> offsets =
           getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
                            adaptor.getDynamicShardedDimsOffsets());
       int64_t curr = 0;
@@ -217,8 +219,8 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
         ++splitSize; // add one for the total size
         ArrayRef<Value> values(&offsets[curr], splitSize);
         Value vals = rewriter.create<tensor::FromElementsOp>(loc, values);
-        int64_t offs[] = {(int64_t)i, 0};
-        int64_t sizes[] = {1, splitSize};
+        std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0};
+        std::array<int64_t, 2> sizes = {1, splitSize};
         resOffsets = rewriter.create<tensor::InsertSliceOp>(
             loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides);
         curr += splitSize;
@@ -275,9 +277,9 @@ struct ConvertProcessMultiIndexOp
     if (!axes.empty()) {
       SmallVector<Value> subIndex;
       for (auto axis : axes) {
-        subIndex.push_back(mIdx[axis]);
+        subIndex.emplace_back(mIdx[axis]);
       }
-      mIdx = subIndex;
+      mIdx = std::move(subIndex);
     }
 
     rewriter.replaceOp(op, mIdx);
@@ -294,8 +296,8 @@ class ConvertProcessLinearIndexOp
 
   // Constructor accepting worldRank
   ConvertProcessLinearIndexOp(const TypeConverter &typeConverter,
-                              MLIRContext *context, int64_t worldRank_ = -1)
-      : OpConversionPattern(typeConverter, context), worldRank(worldRank_) {}
+                              MLIRContext *context, int64_t worldRank = -1)
+      : OpConversionPattern(typeConverter, context), worldRank(worldRank) {}
 
   LogicalResult
   matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
@@ -308,12 +310,11 @@ class ConvertProcessLinearIndexOp
     }
 
     // Otherwise call create mpi::CommRankOp
-    auto rank =
-        rewriter
-            .create<mpi::CommRankOp>(
-                op.getLoc(), TypeRange{mpi::RetvalType::get(op->getContext()),
+    auto rank = rewriter
+                    .create<mpi::CommRankOp>(
+                        loc, TypeRange{mpi::RetvalType::get(op->getContext()),
                                        rewriter.getI32Type()})
-            .getRank();
+                    .getRank();
     rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
                                                     rank);
     return success();
@@ -400,11 +401,11 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
     }
 
     // Compute the sharded shape by applying the sharding to the input shape.
-    // Without shardedDimsOffsets in the sharding, the shard shape is computed
-    // by dividing the dimension size by the number of shards in that dimension
-    // (which is given by the size of the mesh axes provided in split-axes).
-    // Odd elements get distributed to trailing shards.
-    // If a shardedDimsOffsets is provided, the shard shape is computed by
+    // If shardedDimsOffsets is not defined in the sharding, the shard shape is
+    // computed by dividing the dimension size by the number of shards in that
+    // dimension (which is given by the size of the mesh axes provided in
+    // split-axes). Odd elements get distributed to trailing shards. If a
+    // shardedDimsOffsets is provided, the shard shape is computed by
     // subtracting the offset of the current shard from the offset of the next
     // shard.
 
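
To make the offsets-based branch of that comment concrete (numbers invented for illustration, not taken from the patch): with sharded dims offsets [0, 3, 7, 10] for a dimension split across three shards, shard i gets offsets[i+1] - offsets[i] elements.

#include <cstdint>
#include <vector>

int main() {
  // Hypothetical offsets for one dimension split over 3 shards; the last
  // entry is the total size (compare the "add one for the total size"
  // bookkeeping further down).
  std::vector<int64_t> offsets = {0, 3, 7, 10};
  std::vector<int64_t> shardSizes;
  for (std::size_t i = 0; i + 1 < offsets.size(); ++i)
    shardSizes.push_back(offsets[i + 1] - offsets[i]); // 3, 4, 3
  return shardSizes == std::vector<int64_t>{3, 4, 3} ? 0 : 1;
}
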
@@ -429,8 +430,9 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
 
     // To keep the code simple, convert dims/device to values when they are
     // attributes. Count on canonicalization to fold static values.
-    auto shape = getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
-    auto multiIdx =
+    SmallVector<Value> shape =
+        getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
+    SmallVector<Value> multiIdx =
         getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);
 
     // Get the MeshOp, the mesh shape is needed to compute the sharded shape.
@@ -448,7 +450,7 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
     // local shard-size.
     Value shardedDimsOffs;
     {
-      auto tmp = getMixedAsValues(
+      SmallVector<Value> tmp = getMixedAsValues(
           rewriter, loc, sharding.getStaticShardedDimsOffsets(),
           sharding.getDynamicShardedDimsOffsets(), index);
       if (!tmp.empty())
@@ -478,7 +480,7 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
             rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(pos));
         // Get the index of the local shard in the mesh axis.
         Value idx = multiIdx[axes[0]];
-        auto _numShards =
+        auto numShards =
             collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
         if (shardedDimsOffs) {
           // If sharded dims offsets are provided, use them to compute the
@@ -497,22 +499,22 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
           Value sz = rewriter.create<arith::SubIOp>(loc, nextOff, off);
           shardShape.emplace_back(sz);
         } else {
-          auto numShards = rewriter.create<arith::ConstantOp>(
-              loc, rewriter.getIndexAttr(_numShards));
+          Value numShardsVal = rewriter.create<arith::ConstantOp>(
+              loc, rewriter.getIndexAttr(numShards));
           // Compute shard dim size by distributing odd elements to trailing
           // shards:
           // sz = dim / numShards
           //      + (idx >= (numShards - (dim % numShards)) ? 1 : 0)
-          Value sz = rewriter.create<arith::DivSIOp>(loc, dim, numShards);
-          Value sz1 = rewriter.create<arith::RemSIOp>(loc, dim, numShards);
-          sz1 = rewriter.create<arith::SubIOp>(loc, numShards, sz1);
+          Value sz = rewriter.create<arith::DivSIOp>(loc, dim, numShardsVal);
+          Value sz1 = rewriter.create<arith::RemSIOp>(loc, dim, numShardsVal);
+          sz1 = rewriter.create<arith::SubIOp>(loc, numShardsVal, sz1);
           auto cond = rewriter.create<arith::CmpIOp>(
               loc, arith::CmpIPredicate::sge, idx, sz1);
           Value odd = rewriter.create<arith::SelectOp>(loc, cond, one, zero);
           sz = rewriter.create<arith::AddIOp>(loc, sz, odd);
           shardShape.emplace_back(sz);
         }
-        pos += _numShards + 1; // add one for the total size.
+        pos += numShards + 1; // add one for the total size.
       } // else no sharding if split axis is empty or no split axis
       // If no size was added -> no sharding in this dimension.
       if (shardShape.size() <= i)
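
A quick stand-alone check of the distribution formula in the comment above (invented numbers, not pass code): for dim = 10 over numShards = 4, the remainder 2 goes to the trailing shards, giving sizes 2, 2, 3, 3.

#include <cassert>
#include <cstdint>

// Same formula as the comment:
// sz = dim / numShards + (idx >= (numShards - dim % numShards) ? 1 : 0)
static int64_t shardDimSize(int64_t dim, int64_t numShards, int64_t idx) {
  return dim / numShards + (idx >= (numShards - dim % numShards) ? 1 : 0);
}

int main() {
  // 10 / 4 = 2 with remainder 2, so shards 2 and 3 each get one extra
  // element: sizes {2, 2, 3, 3}, summing back to 10.
  assert(shardDimSize(10, 4, 0) == 2 && shardDimSize(10, 4, 1) == 2);
  assert(shardDimSize(10, 4, 2) == 3 && shardDimSize(10, 4, 3) == 3);
  return 0;
}
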
@@ -698,25 +700,24 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
       offsets[dim] = orgOffset;
     };
 
-    auto get_i32val = [&](OpFoldResult &v) {
-      return isa<Value>(v)
-                 ? cast<Value>(v)
-                 : rewriter.create<arith::ConstantOp>(
-                       loc,
-                       rewriter.getI32IntegerAttr(
-                           cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
-    };
-
-    for (int i = 0; i < 2; ++i) {
-      Value haloSz = get_i32val(haloSizes[currHaloDim * 2 + i]);
+    auto doSendRecv = [&](int upOrDown) {
+      OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
+      Value haloSz = dyn_cast<Value>(v);
+      if (!haloSz)
+        haloSz = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getI32IntegerAttr(
+                     cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
       auto hasSize = rewriter.create<arith::CmpIOp>(
           loc, arith::CmpIPredicate::sgt, haloSz, zero);
       rewriter.create<scf::IfOp>(loc, hasSize,
                                  [&](OpBuilder &builder, Location loc) {
-                                   genSendRecv(i > 0);
+                                   genSendRecv(upOrDown > 0);
                                    builder.create<scf::YieldOp>(loc);
                                  });
-    }
+    };
+
+    doSendRecv(0);
+    doSendRecv(1);
 
     // the shape for lower dims include higher dims' halos
     dimSizes[dim] = shape[dim];
@@ -775,8 +776,8 @@ struct ConvertMeshToMPIPass
         SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
       auto i16 = IntegerType::get(type.getContext(), 16);
       auto i64 = IntegerType::get(type.getContext(), 64);
-      std::array<int64_t, 2> shp{ShapedType::kDynamic,
-                                 ShapedType::kDynamic};
+      std::array<int64_t, 2> shp = {ShapedType::kDynamic,
+                                    ShapedType::kDynamic};
       results.emplace_back(RankedTensorType::get(shp, i16));
       results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
       results.emplace_back(RankedTensorType::get(shp, i64));