1414
1515#include " mlir/Dialect/Arith/IR/Arith.h"
1616#include " mlir/Dialect/Bufferization/IR/Bufferization.h"
17+ #include " mlir/Dialect/DLTI/DLTI.h"
18+ #include " mlir/Dialect/Func/IR/FuncOps.h"
19+ #include " mlir/Dialect/Func/Transforms/FuncConversions.h"
20+ #include " mlir/Dialect/Linalg/IR/Linalg.h"
1721#include " mlir/Dialect/MPI/IR/MPI.h"
1822#include " mlir/Dialect/MemRef/IR/MemRef.h"
23+ #include " mlir/Dialect/Mesh/IR/MeshDialect.h"
1924#include " mlir/Dialect/Mesh/IR/MeshOps.h"
2025#include " mlir/Dialect/SCF/IR/SCF.h"
2126#include " mlir/Dialect/Tensor/IR/Tensor.h"
2530#include " mlir/IR/BuiltinTypes.h"
2631#include " mlir/IR/PatternMatch.h"
2732#include " mlir/IR/SymbolTable.h"
33+ #include " mlir/Transforms/DialectConversion.h"
2834#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
2935
3036#define DEBUG_TYPE " mesh-to-mpi"
@@ -36,10 +42,32 @@ namespace mlir {
3642} // namespace mlir
3743
3844using namespace mlir ;
39- using namespace mlir :: mesh;
45+ using namespace mesh ;
4046
4147namespace {
42- // Create operations converting a linear index to a multi-dimensional index
48+ // / Convert vec of OpFoldResults (ints) into vector of Values.
49+ static SmallVector<Value> getMixedAsValues (OpBuilder b, const Location &loc,
50+ llvm::ArrayRef<int64_t > statics,
51+ ValueRange dynamics,
52+ Type type = Type()) {
53+ SmallVector<Value> values;
54+ auto dyn = dynamics.begin ();
55+ Type i64 = b.getI64Type ();
56+ if (!type)
57+ type = i64 ;
58+ assert (i64 == type || b.getIndexType () == type);
59+ for (auto s : statics) {
60+ values.emplace_back (
61+ ShapedType::isDynamic (s)
62+ ? *(dyn++)
63+ : b.create <arith::ConstantOp>(loc, type,
64+ i64 == type ? b.getI64IntegerAttr (s)
65+ : b.getIndexAttr (s)));
66+ }
67+ return values;
68+ };
69+
70+ // / Create operations converting a linear index to a multi-dimensional index.
4371static SmallVector<Value> linearToMultiIndex (Location loc, OpBuilder b,
4472 Value linearIndex,
4573 ValueRange dimensions) {
@@ -72,6 +100,152 @@ Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
72100 return linearIndex;
73101}
74102
103+ // / Replace GetShardingOp with related/dependent ShardingOp.
104+ struct ConvertGetShardingOp : public OpConversionPattern <GetShardingOp> {
105+ using OpConversionPattern::OpConversionPattern;
106+
107+ LogicalResult
108+ matchAndRewrite (GetShardingOp op, OpAdaptor adaptor,
109+ ConversionPatternRewriter &rewriter) const override {
110+ auto shardOp = adaptor.getSource ().getDefiningOp <ShardOp>();
111+ if (!shardOp)
112+ return failure ();
113+ auto shardingOp = shardOp.getSharding ().getDefiningOp <ShardingOp>();
114+ if (!shardingOp)
115+ return failure ();
116+
117+ rewriter.replaceOp (op, shardingOp.getResult ());
118+ return success ();
119+ }
120+ };
121+
122+ // / Convert a sharding op to a tuple of tensors of its components
123+ // / (SplitAxes, HaloSizes, ShardedDimsOffsets)
124+ // / as defined by type converter.
125+ struct ConvertShardingOp : public OpConversionPattern <ShardingOp> {
126+ using OpConversionPattern::OpConversionPattern;
127+
128+ LogicalResult
129+ matchAndRewrite (ShardingOp op, OpAdaptor adaptor,
130+ ConversionPatternRewriter &rewriter) const override {
131+ auto splitAxes = op.getSplitAxes ().getAxes ();
132+ int64_t maxNAxes = 0 ;
133+ for (auto axes : splitAxes) {
134+ maxNAxes = std::max<int64_t >(maxNAxes, axes.size ());
135+ }
136+
137+ // To hold the split axes, create empty 2d tensor with shape
138+ // {splitAxes.size(), max-size-of-split-groups}.
139+ // Set trailing elements for smaller split-groups to -1.
140+ Location loc = op.getLoc ();
141+ auto i16 = rewriter.getI16Type ();
142+ auto i64 = rewriter.getI64Type ();
143+ int64_t shape[] = {static_cast <int64_t >(splitAxes.size ()), maxNAxes};
144+ Value resSplitAxes = rewriter.create <tensor::EmptyOp>(loc, shape, i16 );
145+ auto attr = IntegerAttr::get (i16 , 0xffff );
146+ Value fillValue = rewriter.create <arith::ConstantOp>(loc, i16 , attr);
147+ resSplitAxes = rewriter.create <linalg::FillOp>(loc, fillValue, resSplitAxes)
148+ .getResult (0 );
149+
150+ // explicitly write values into tensor row by row
151+ int64_t strides[] = {1 , 1 };
152+ int64_t nSplits = 0 ;
153+ ValueRange empty = {};
154+ for (auto [i, axes] : llvm::enumerate (splitAxes)) {
155+ int64_t size = axes.size ();
156+ if (size > 0 )
157+ ++nSplits;
158+ int64_t offs[] = {(int64_t )i, 0 };
159+ int64_t sizes[] = {1 , size};
160+ auto tensorType = RankedTensorType::get ({size}, i16 );
161+ auto attrs = DenseIntElementsAttr::get (tensorType, axes.asArrayRef ());
162+ auto vals = rewriter.create <arith::ConstantOp>(loc, tensorType, attrs);
163+ resSplitAxes = rewriter.create <tensor::InsertSliceOp>(
164+ loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides);
165+ }
166+
167+ // To hold halos sizes, create 2d Tensor with shape {nSplits, 2}.
168+ // Store the halo sizes in the tensor.
169+ auto haloSizes =
170+ getMixedAsValues (rewriter, loc, adaptor.getStaticHaloSizes (),
171+ adaptor.getDynamicHaloSizes ());
172+ auto type = RankedTensorType::get ({nSplits, 2 }, i64 );
173+ Value resHaloSizes =
174+ haloSizes.empty ()
175+ ? rewriter
176+ .create <tensor::EmptyOp>(loc, std::array<int64_t , 2 >{0 , 0 },
177+ i64 )
178+ .getResult ()
179+ : rewriter.create <tensor::FromElementsOp>(loc, type, haloSizes)
180+ .getResult ();
181+
182+ // To hold sharded dims offsets, create Tensor with shape {nSplits,
183+ // maxSplitSize+1}. Store the offsets in the tensor but set trailing
184+ // elements for smaller split-groups to -1. Computing the max size of the
185+ // split groups needs using collectiveProcessGroupSize (which needs the
186+ // MeshOp)
187+ Value resOffsets;
188+ if (adaptor.getStaticShardedDimsOffsets ().empty ()) {
189+ resOffsets = rewriter.create <tensor::EmptyOp>(
190+ loc, std::array<int64_t , 2 >{0 , 0 }, i64 );
191+ } else {
192+ SymbolTableCollection symbolTableCollection;
193+ auto meshOp = getMesh (op, symbolTableCollection);
194+ auto maxSplitSize = 0 ;
195+ for (auto axes : splitAxes) {
196+ int64_t splitSize =
197+ collectiveProcessGroupSize (axes.asArrayRef (), meshOp.getShape ());
198+ assert (splitSize != ShapedType::kDynamic );
199+ maxSplitSize = std::max<int64_t >(maxSplitSize, splitSize);
200+ }
201+ assert (maxSplitSize);
202+ ++maxSplitSize; // add one for the total size
203+
204+ resOffsets = rewriter.create <tensor::EmptyOp>(
205+ loc, std::array<int64_t , 2 >{nSplits, maxSplitSize}, i64 );
206+ Value zero = rewriter.create <arith::ConstantOp>(
207+ loc, i64 , rewriter.getI64IntegerAttr (ShapedType::kDynamic ));
208+ resOffsets =
209+ rewriter.create <linalg::FillOp>(loc, zero, resOffsets).getResult (0 );
210+ auto offsets =
211+ getMixedAsValues (rewriter, loc, adaptor.getStaticShardedDimsOffsets (),
212+ adaptor.getDynamicShardedDimsOffsets ());
213+ int64_t curr = 0 ;
214+ for (auto [i, axes] : llvm::enumerate (splitAxes)) {
215+ int64_t splitSize =
216+ collectiveProcessGroupSize (axes.asArrayRef (), meshOp.getShape ());
217+ assert (splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
218+ ++splitSize; // add one for the total size
219+ ArrayRef<Value> values (&offsets[curr], splitSize);
220+ Value vals = rewriter.create <tensor::FromElementsOp>(loc, values);
221+ int64_t offs[] = {(int64_t )i, 0 };
222+ int64_t sizes[] = {1 , splitSize};
223+ resOffsets = rewriter.create <tensor::InsertSliceOp>(
224+ loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides);
225+ curr += splitSize;
226+ }
227+ }
228+
229+ // return a tuple of tensors as defined by type converter
230+ SmallVector<Type> resTypes;
231+ if (failed (getTypeConverter ()->convertType (op.getResult ().getType (),
232+ resTypes)))
233+ return failure ();
234+
235+ resSplitAxes =
236+ rewriter.create <tensor::CastOp>(loc, resTypes[0 ], resSplitAxes);
237+ resHaloSizes =
238+ rewriter.create <tensor::CastOp>(loc, resTypes[1 ], resHaloSizes);
239+ resOffsets = rewriter.create <tensor::CastOp>(loc, resTypes[2 ], resOffsets);
240+
241+ rewriter.replaceOpWithNewOp <UnrealizedConversionCastOp>(
242+ op, TupleType::get (op.getContext (), resTypes),
243+ ValueRange{resSplitAxes, resHaloSizes, resOffsets});
244+
245+ return success ();
246+ }
247+ };
248+
75249struct ConvertProcessMultiIndexOp
76250 : public mlir::OpRewritePattern<mlir::mesh::ProcessMultiIndexOp> {
77251 using OpRewritePattern::OpRewritePattern;
@@ -419,14 +593,95 @@ struct ConvertMeshToMPIPass
419593
420594 // / Run the dialect converter on the module.
421595 void runOnOperation () override {
422- auto *ctx = &getContext ();
423- mlir::RewritePatternSet patterns (ctx);
596+ uint64_t worldRank = -1 ;
597+ // Try to get DLTI attribute for MPI:comm_world_rank
598+ // If found, set worldRank to the value of the attribute.
599+ {
600+ auto dltiAttr =
601+ dlti::query (getOperation (), {" MPI:comm_world_rank" }, false );
602+ if (succeeded (dltiAttr)) {
603+ if (!isa<IntegerAttr>(dltiAttr.value ())) {
604+ getOperation ()->emitError ()
605+ << " Expected an integer attribute for MPI:comm_world_rank" ;
606+ return signalPassFailure ();
607+ }
608+ worldRank = cast<IntegerAttr>(dltiAttr.value ()).getInt ();
609+ }
610+ }
424611
425- patterns.insert <ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
426- ConvertProcessLinearIndexOp, ConvertProcessMultiIndexOp>(
427- ctx);
612+ auto *ctxt = &getContext ();
613+ RewritePatternSet patterns (ctxt);
614+ ConversionTarget target (getContext ());
615+
616+ // Define a type converter to convert mesh::ShardingType,
617+ // mostly for use in return operations.
618+ TypeConverter typeConverter;
619+ typeConverter.addConversion ([](Type type) { return type; });
620+
621+ // convert mesh::ShardingType to a tuple of RankedTensorTypes
622+ typeConverter.addConversion (
623+ [](ShardingType type,
624+ SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
625+ auto i16 = IntegerType::get (type.getContext (), 16 );
626+ auto i64 = IntegerType::get (type.getContext (), 64 );
627+ std::array<int64_t , 2 > shp{ShapedType::kDynamic ,
628+ ShapedType::kDynamic };
629+ results.emplace_back (RankedTensorType::get (shp, i16 ));
630+ results.emplace_back (RankedTensorType::get (shp, i64 )); // actually ?x2
631+ results.emplace_back (RankedTensorType::get (shp, i64 ));
632+ return success ();
633+ });
634+
635+ // To 'extract' components, a UnrealizedConversionCastOp is expected
636+ // to define the input
637+ typeConverter.addTargetMaterialization (
638+ [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
639+ Location loc) {
640+ // Expecting a single input.
641+ if (inputs.size () != 1 || !isa<TupleType>(inputs[0 ].getType ()))
642+ return SmallVector<Value>();
643+ auto castOp = inputs[0 ].getDefiningOp <UnrealizedConversionCastOp>();
644+ // Expecting an UnrealizedConversionCastOp.
645+ if (!castOp)
646+ return SmallVector<Value>();
647+ // Fill a vector with elements of the tuple/castOp.
648+ SmallVector<Value> results;
649+ for (auto oprnd : castOp.getInputs ()) {
650+ if (!isa<RankedTensorType>(oprnd.getType ()))
651+ return SmallVector<Value>();
652+ results.emplace_back (oprnd);
653+ }
654+ return results;
655+ });
428656
429- (void )mlir::applyPatternsGreedily (getOperation (), std::move (patterns));
657+ // No mesh dialect should left after conversion...
658+ target.addIllegalDialect <mesh::MeshDialect>();
659+ // ...except the global MeshOp
660+ target.addLegalOp <mesh::MeshOp>();
661+ // Allow all the stuff that our patterns will convert to
662+ target.addLegalDialect <BuiltinDialect, mpi::MPIDialect, scf::SCFDialect,
663+ arith::ArithDialect, tensor::TensorDialect,
664+ bufferization::BufferizationDialect,
665+ linalg::LinalgDialect, memref::MemRefDialect>();
666+ // Make sure the function signature, calls etc. are legal
667+ target.addDynamicallyLegalOp <func::FuncOp>([&](func::FuncOp op) {
668+ return typeConverter.isSignatureLegal (op.getFunctionType ());
669+ });
670+ target.addDynamicallyLegalOp <func::CallOp, func::ReturnOp>(
671+ [&](Operation *op) { return typeConverter.isLegal (op); });
672+
673+ patterns.add <ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
674+ ConvertProcessMultiIndexOp, ConvertGetShardingOp,
675+ ConvertShardingOp, ConvertShardShapeOp>(typeConverter, ctxt);
676+ // ConvertProcessLinearIndexOp accepts an optional worldRank
677+ patterns.add <ConvertProcessLinearIndexOp>(typeConverter, ctxt, worldRank);
678+
679+ populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
680+ patterns, typeConverter);
681+ populateCallOpTypeConversionPattern (patterns, typeConverter);
682+ populateReturnOpTypeConversionPattern (patterns, typeConverter);
683+
684+ (void )applyPartialConversion (getOperation (), target, std::move (patterns));
430685 }
431686};
432687
0 commit comments