1919#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
2020#include " mlir/Dialect/MPI/IR/MPI.h"
2121#include " mlir/Transforms/DialectConversion.h"
22+ #include < memory>
2223
2324using namespace mlir ;
2425
2526namespace {
2627
2728template <typename Op, typename ... Args>
28- static Op getOrDefineGlobal (mlir:: ModuleOp &moduleOp, const Location loc,
29+ static Op getOrDefineGlobal (ModuleOp &moduleOp, const Location loc,
2930 ConversionPatternRewriter &rewriter, StringRef name,
3031 Args &&...args) {
3132 Op ret;
@@ -46,11 +47,40 @@ static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
4647 moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
4748}
4849
50+ // / When lowering the mpi dialect to functions calls certain details
51+ // / differ between various MPI implementations. This class will provide
52+ // / these in a generic way, depending on the MPI implementation that got
53+ // / selected by the DLTI attribute on the module.
54+ class MPIImplTraits {
55+ ModuleOp &moduleOp;
56+
57+ public:
58+ // / Instantiate a new MPIImplTraits object according to the DLTI attribute
59+ // / on the given module.
60+ static std::unique_ptr<MPIImplTraits> get (ModuleOp &moduleOp);
61+
62+ MPIImplTraits (ModuleOp &moduleOp) : moduleOp(moduleOp) {}
63+
64+ ModuleOp &getModuleOp () { return moduleOp; }
65+
66+ // / Gets or creates MPI_COMM_WORLD as a Value.
67+ virtual Value getCommWorld (const Location loc,
68+ ConversionPatternRewriter &rewriter) = 0;
69+
70+ // / Get the MPI_STATUS_IGNORE value (typically a pointer type).
71+ virtual intptr_t getStatusIgnore () = 0;
72+
73+ // / get/create MPI datatype as a Value which corresponds to the given
74+ // / Type
75+ virtual Value getDataType (const Location loc,
76+ ConversionPatternRewriter &rewriter, Type type) = 0;
77+ };
78+
4979// ===----------------------------------------------------------------------===//
5080// Implementation details for MPICH ABI compatible MPI implementations
5181// ===----------------------------------------------------------------------===//
5282
53- struct MPICHImplTraits {
83+ class MPICHImplTraits : public MPIImplTraits {
5484 static constexpr int MPI_FLOAT = 0x4c00040a ;
5585 static constexpr int MPI_DOUBLE = 0x4c00080b ;
5686 static constexpr int MPI_INT8_T = 0x4c000137 ;
@@ -62,20 +92,20 @@ struct MPICHImplTraits {
6292 static constexpr int MPI_UINT32_T = 0x4c00043d ;
6393 static constexpr int MPI_UINT64_T = 0x4c00083e ;
6494
65- static mlir::Value getCommWorld (mlir::ModuleOp &moduleOp,
66- const mlir::Location loc,
67- mlir::ConversionPatternRewriter &rewriter) {
95+ public:
96+ using MPIImplTraits::MPIImplTraits;
97+
98+ Value getCommWorld (const Location loc,
99+ ConversionPatternRewriter &rewriter) override {
68100 static const int MPI_COMM_WORLD = 0x44000000 ;
69- return rewriter.create <mlir:: LLVM::ConstantOp>(loc, rewriter.getI32Type (),
70- MPI_COMM_WORLD);
101+ return rewriter.create <LLVM::ConstantOp>(loc, rewriter.getI32Type (),
102+ MPI_COMM_WORLD);
71103 }
72104
73- static intptr_t getStatusIgnore () { return 1 ; }
105+ intptr_t getStatusIgnore () override { return 1 ; }
74106
75- static mlir::Value getDataType (mlir::ModuleOp &moduleOp,
76- const mlir::Location loc,
77- mlir::ConversionPatternRewriter &rewriter,
78- mlir::Type type) {
107+ Value getDataType (const Location loc, ConversionPatternRewriter &rewriter,
108+ Type type) override {
79109 int32_t mtype = 0 ;
80110 if (type.isF32 ())
81111 mtype = MPI_FLOAT;
@@ -99,53 +129,50 @@ struct MPICHImplTraits {
99129 mtype = MPI_UINT8_T;
100130 else
101131 assert (false && " unsupported type" );
102- return rewriter.create <mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type (),
103- mtype);
132+ return rewriter.create <LLVM::ConstantOp>(loc, rewriter.getI32Type (), mtype);
104133 }
105134};
106135
107136// ===----------------------------------------------------------------------===//
108137// Implementation details for OpenMPI
109138// ===----------------------------------------------------------------------===//
110- struct OMPIImplTraits {
111-
112- static mlir::LLVM::GlobalOp
113- getOrDefineExternalStruct (mlir::ModuleOp &moduleOp, const mlir::Location loc,
114- mlir::ConversionPatternRewriter &rewriter,
115- mlir::StringRef name,
116- mlir::LLVM::LLVMStructType type) {
117-
118- return getOrDefineGlobal<mlir::LLVM::GlobalOp>(
119- moduleOp, loc, rewriter, name, type, /* isConstant=*/ false ,
120- mlir::LLVM::Linkage::External, name,
121- /* value=*/ mlir::Attribute (), /* alignment=*/ 0 , 0 );
139+ class OMPIImplTraits : public MPIImplTraits {
140+ LLVM::GlobalOp getOrDefineExternalStruct (const Location loc,
141+ ConversionPatternRewriter &rewriter,
142+ StringRef name,
143+ LLVM::LLVMStructType type) {
144+
145+ return getOrDefineGlobal<LLVM::GlobalOp>(
146+ getModuleOp (), loc, rewriter, name, type, /* isConstant=*/ false ,
147+ LLVM::Linkage::External, name,
148+ /* value=*/ Attribute (), /* alignment=*/ 0 , 0 );
122149 }
123150
124- static mlir::Value getCommWorld (mlir::ModuleOp &moduleOp,
125- const mlir::Location loc,
126- mlir::ConversionPatternRewriter &rewriter) {
151+ public:
152+ using MPIImplTraits::MPIImplTraits;
153+
154+ Value getCommWorld (const Location loc,
155+ ConversionPatternRewriter &rewriter) override {
127156 auto context = rewriter.getContext ();
128157 // get external opaque struct pointer type
129158 auto commStructT =
130- mlir:: LLVM::LLVMStructType::getOpaque (" ompi_communicator_t" , context);
131- mlir:: StringRef name = " ompi_mpi_comm_world" ;
159+ LLVM::LLVMStructType::getOpaque (" ompi_communicator_t" , context);
160+ StringRef name = " ompi_mpi_comm_world" ;
132161
133162 // make sure global op definition exists
134- (void )getOrDefineExternalStruct (moduleOp, loc, rewriter, name, commStructT);
163+ (void )getOrDefineExternalStruct (loc, rewriter, name, commStructT);
135164
136165 // get address of symbol
137- return rewriter.create <mlir:: LLVM::AddressOfOp>(
138- loc, mlir:: LLVM::LLVMPointerType::get (context),
139- mlir:: SymbolRefAttr::get (context, name));
166+ return rewriter.create <LLVM::AddressOfOp>(
167+ loc, LLVM::LLVMPointerType::get (context),
168+ SymbolRefAttr::get (context, name));
140169 }
141170
142- static intptr_t getStatusIgnore () { return 0 ; }
171+ intptr_t getStatusIgnore () override { return 0 ; }
143172
144- static mlir::Value getDataType (mlir::ModuleOp &moduleOp,
145- const mlir::Location loc,
146- mlir::ConversionPatternRewriter &rewriter,
147- mlir::Type type) {
148- mlir::StringRef mtype;
173+ Value getDataType (const Location loc, ConversionPatternRewriter &rewriter,
174+ Type type) override {
175+ StringRef mtype;
149176 if (type.isF32 ())
150177 mtype = " ompi_mpi_float" ;
151178 else if (type.isF64 ())
@@ -171,67 +198,29 @@ struct OMPIImplTraits {
171198
172199 auto context = rewriter.getContext ();
173200 // get external opaque struct pointer type
174- auto commStructT = mlir::LLVM::LLVMStructType::getOpaque (
175- " ompi_predefined_datatype_t" , context);
201+ auto commStructT =
202+ LLVM::LLVMStructType::getOpaque ( " ompi_predefined_datatype_t" , context);
176203 // make sure global op definition exists
177- (void )getOrDefineExternalStruct (moduleOp, loc, rewriter, mtype,
178- commStructT);
204+ (void )getOrDefineExternalStruct (loc, rewriter, mtype, commStructT);
179205 // get address of symbol
180- return rewriter.create <mlir:: LLVM::AddressOfOp>(
181- loc, mlir:: LLVM::LLVMPointerType::get (context),
182- mlir:: SymbolRefAttr::get (context, mtype));
206+ return rewriter.create <LLVM::AddressOfOp>(
207+ loc, LLVM::LLVMPointerType::get (context),
208+ SymbolRefAttr::get (context, mtype));
183209 }
184210};
185211
186- // / When lowering the mpi dialect to functions calls certain details
187- // / differ between various MPI implementations. This class will provide
188- // / these in a generic way, depending on the MPI implementation that got
189- // / selected by the DLTI attribute on the module.
190- struct MPIImplTraits {
191- enum MPIImpl { MPICH, OMPI };
192-
193- // / Gets the MPI implementation from a DLTI attribute on the module.
194- // / Defaults to MPICH (and ABI compatible).
195- static MPIImpl getMPIImpl (mlir::ModuleOp &moduleOp) {
196- auto attr = dlti::query (*&moduleOp, {" MPI:Implementation" }, true );
197- if (failed (attr))
198- return MPICH;
199- auto strAttr = dyn_cast<StringAttr>(attr.value ());
200- if (strAttr && strAttr.getValue () == " OpenMPI" )
201- return OMPI;
202- if (!strAttr || strAttr.getValue () != " MPICH" )
203- moduleOp.emitWarning () << " Unknown \" MPI:Implementation\" value in DLTI ("
204- << strAttr.getValue () << " ), defaulting to MPICH" ;
205- return MPICH;
206- }
207-
208- // / Gets or creates MPI_COMM_WORLD as a mlir::Value.
209- static mlir::Value getCommWorld (mlir::ModuleOp &moduleOp,
210- const mlir::Location loc,
211- mlir::ConversionPatternRewriter &rewriter) {
212- if (MPIImplTraits::getMPIImpl (moduleOp) == OMPI)
213- return OMPIImplTraits::getCommWorld (moduleOp, loc, rewriter);
214- return MPICHImplTraits::getCommWorld (moduleOp, loc, rewriter);
215- }
216-
217- // / Get the MPI_STATUS_IGNORE value (typically a pointer type).
218- static intptr_t getStatusIgnore (mlir::ModuleOp &moduleOp) {
219- if (MPIImplTraits::getMPIImpl (moduleOp) == OMPI)
220- return OMPIImplTraits::getStatusIgnore ();
221- return MPICHImplTraits::getStatusIgnore ();
222- }
223-
224- // / get/create MPI datatype as a mlir::Value which corresponds to the given
225- // / mlir::Type
226- static mlir::Value getDataType (mlir::ModuleOp &moduleOp,
227- const mlir::Location loc,
228- mlir::ConversionPatternRewriter &rewriter,
229- mlir::Type type) {
230- if (MPIImplTraits::getMPIImpl (moduleOp) == OMPI)
231- return OMPIImplTraits::getDataType (moduleOp, loc, rewriter, type);
232- return MPICHImplTraits::getDataType (moduleOp, loc, rewriter, type);
233- }
234- };
212+ std::unique_ptr<MPIImplTraits> MPIImplTraits::get (ModuleOp &moduleOp) {
213+ auto attr = dlti::query (*&moduleOp, {" MPI:Implementation" }, true );
214+ if (failed (attr))
215+ return std::make_unique<MPICHImplTraits>(moduleOp);
216+ auto strAttr = dyn_cast<StringAttr>(attr.value ());
217+ if (strAttr && strAttr.getValue () == " OpenMPI" )
218+ return std::make_unique<OMPIImplTraits>(moduleOp);
219+ if (!strAttr || strAttr.getValue () != " MPICH" )
220+ moduleOp.emitWarning () << " Unknown \" MPI:Implementation\" value in DLTI ("
221+ << strAttr.getValue () << " ), defaulting to MPICH" ;
222+ return std::make_unique<MPICHImplTraits>(moduleOp);
223+ }
235224
236225// ===----------------------------------------------------------------------===//
237226// InitOpLowering
@@ -320,8 +309,9 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
320309 // grab a reference to the global module op:
321310 auto moduleOp = op->getParentOfType <ModuleOp>();
322311
312+ auto mpiTraits = MPIImplTraits::get (moduleOp);
323313 // get MPI_COMM_WORLD
324- Value commWorld = MPIImplTraits:: getCommWorld (moduleOp, loc, rewriter);
314+ Value commWorld = mpiTraits-> getCommWorld (loc, rewriter);
325315
326316 // LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
327317 auto rankFuncType =
@@ -387,9 +377,9 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
387377 Value size = rewriter.create <LLVM::ExtractValueOp>(loc, memRef,
388378 ArrayRef<int64_t >{3 , 0 });
389379 size = rewriter.create <LLVM::TruncOp>(loc, i32 , size);
390- Value dataType =
391- MPIImplTraits:: getDataType (moduleOp, loc, rewriter, elemType);
392- Value commWorld = MPIImplTraits:: getCommWorld (moduleOp, loc, rewriter);
380+ auto mpiTraits = MPIImplTraits::get (moduleOp);
381+ Value dataType = mpiTraits-> getDataType (loc, rewriter, elemType);
382+ Value commWorld = mpiTraits-> getCommWorld (loc, rewriter);
393383
394384 // LLVM Function type representing `i32 MPI_send(data, count, datatype, dst,
395385 // tag, comm)`
@@ -446,11 +436,11 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
446436 Value size = rewriter.create <LLVM::ExtractValueOp>(loc, memRef,
447437 ArrayRef<int64_t >{3 , 0 });
448438 size = rewriter.create <LLVM::TruncOp>(loc, i32 , size);
449- Value dataType =
450- MPIImplTraits:: getDataType (moduleOp, loc, rewriter, elemType);
451- Value commWorld = MPIImplTraits:: getCommWorld (moduleOp, loc, rewriter);
439+ auto mpiTraits = MPIImplTraits::get (moduleOp);
440+ Value dataType = mpiTraits-> getDataType (loc, rewriter, elemType);
441+ Value commWorld = mpiTraits-> getCommWorld (loc, rewriter);
452442 Value statusIgnore = rewriter.create <LLVM::ConstantOp>(
453- loc, i64 , MPIImplTraits:: getStatusIgnore (moduleOp ));
443+ loc, i64 , mpiTraits-> getStatusIgnore ());
454444 statusIgnore =
455445 rewriter.create <LLVM::IntToPtrOp>(loc, ptrType, statusIgnore);
456446
@@ -498,13 +488,13 @@ struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
498488// Pattern Population
499489// ===----------------------------------------------------------------------===//
500490
501- void mlir:: mpi::populateMPIToLLVMConversionPatterns (
502- LLVMTypeConverter &converter, RewritePatternSet &patterns) {
491+ void mpi::populateMPIToLLVMConversionPatterns (LLVMTypeConverter &converter,
492+ RewritePatternSet &patterns) {
503493 patterns.add <CommRankOpLowering, FinalizeOpLowering, InitOpLowering,
504494 SendOpLowering, RecvOpLowering>(converter);
505495}
506496
507- void mlir:: mpi::registerConvertMPIToLLVMInterface (DialectRegistry ®istry) {
497+ void mpi::registerConvertMPIToLLVMInterface (DialectRegistry ®istry) {
508498 registry.addExtension (+[](MLIRContext *ctx, mpi::MPIDialect *dialect) {
509499 dialect->addInterfaces <FuncToLLVMDialectInterface>();
510500 });
0 commit comments