Skip to content

Commit eaebf5e

Browse files
committed
using virtual dispatch for MPIImplTraits; cleanup
1 parent 2a8745b commit eaebf5e

File tree

1 file changed

+98
-108
lines changed

1 file changed

+98
-108
lines changed

mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp

Lines changed: 98 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,14 @@
1919
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
2020
#include "mlir/Dialect/MPI/IR/MPI.h"
2121
#include "mlir/Transforms/DialectConversion.h"
22+
#include <memory>
2223

2324
using namespace mlir;
2425

2526
namespace {
2627

2728
template <typename Op, typename... Args>
28-
static Op getOrDefineGlobal(mlir::ModuleOp &moduleOp, const Location loc,
29+
static Op getOrDefineGlobal(ModuleOp &moduleOp, const Location loc,
2930
ConversionPatternRewriter &rewriter, StringRef name,
3031
Args &&...args) {
3132
Op ret;
@@ -46,11 +47,40 @@ static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
4647
moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
4748
}
4849

50+
/// When lowering the mpi dialect to functions calls certain details
51+
/// differ between various MPI implementations. This class will provide
52+
/// these in a generic way, depending on the MPI implementation that got
53+
/// selected by the DLTI attribute on the module.
54+
class MPIImplTraits {
55+
ModuleOp &moduleOp;
56+
57+
public:
58+
/// Instantiate a new MPIImplTraits object according to the DLTI attribute
59+
/// on the given module.
60+
static std::unique_ptr<MPIImplTraits> get(ModuleOp &moduleOp);
61+
62+
MPIImplTraits(ModuleOp &moduleOp) : moduleOp(moduleOp) {}
63+
64+
ModuleOp &getModuleOp() { return moduleOp; }
65+
66+
/// Gets or creates MPI_COMM_WORLD as a Value.
67+
virtual Value getCommWorld(const Location loc,
68+
ConversionPatternRewriter &rewriter) = 0;
69+
70+
/// Get the MPI_STATUS_IGNORE value (typically a pointer type).
71+
virtual intptr_t getStatusIgnore() = 0;
72+
73+
/// get/create MPI datatype as a Value which corresponds to the given
74+
/// Type
75+
virtual Value getDataType(const Location loc,
76+
ConversionPatternRewriter &rewriter, Type type) = 0;
77+
};
78+
4979
//===----------------------------------------------------------------------===//
5080
// Implementation details for MPICH ABI compatible MPI implementations
5181
//===----------------------------------------------------------------------===//
5282

53-
struct MPICHImplTraits {
83+
class MPICHImplTraits : public MPIImplTraits {
5484
static constexpr int MPI_FLOAT = 0x4c00040a;
5585
static constexpr int MPI_DOUBLE = 0x4c00080b;
5686
static constexpr int MPI_INT8_T = 0x4c000137;
@@ -62,20 +92,20 @@ struct MPICHImplTraits {
6292
static constexpr int MPI_UINT32_T = 0x4c00043d;
6393
static constexpr int MPI_UINT64_T = 0x4c00083e;
6494

65-
static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
66-
const mlir::Location loc,
67-
mlir::ConversionPatternRewriter &rewriter) {
95+
public:
96+
using MPIImplTraits::MPIImplTraits;
97+
98+
Value getCommWorld(const Location loc,
99+
ConversionPatternRewriter &rewriter) override {
68100
static const int MPI_COMM_WORLD = 0x44000000;
69-
return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
70-
MPI_COMM_WORLD);
101+
return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
102+
MPI_COMM_WORLD);
71103
}
72104

73-
static intptr_t getStatusIgnore() { return 1; }
105+
intptr_t getStatusIgnore() override { return 1; }
74106

75-
static mlir::Value getDataType(mlir::ModuleOp &moduleOp,
76-
const mlir::Location loc,
77-
mlir::ConversionPatternRewriter &rewriter,
78-
mlir::Type type) {
107+
Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
108+
Type type) override {
79109
int32_t mtype = 0;
80110
if (type.isF32())
81111
mtype = MPI_FLOAT;
@@ -99,53 +129,50 @@ struct MPICHImplTraits {
99129
mtype = MPI_UINT8_T;
100130
else
101131
assert(false && "unsupported type");
102-
return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
103-
mtype);
132+
return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), mtype);
104133
}
105134
};
106135

107136
//===----------------------------------------------------------------------===//
108137
// Implementation details for OpenMPI
109138
//===----------------------------------------------------------------------===//
110-
struct OMPIImplTraits {
111-
112-
static mlir::LLVM::GlobalOp
113-
getOrDefineExternalStruct(mlir::ModuleOp &moduleOp, const mlir::Location loc,
114-
mlir::ConversionPatternRewriter &rewriter,
115-
mlir::StringRef name,
116-
mlir::LLVM::LLVMStructType type) {
117-
118-
return getOrDefineGlobal<mlir::LLVM::GlobalOp>(
119-
moduleOp, loc, rewriter, name, type, /*isConstant=*/false,
120-
mlir::LLVM::Linkage::External, name,
121-
/*value=*/mlir::Attribute(), /*alignment=*/0, 0);
139+
class OMPIImplTraits : public MPIImplTraits {
140+
LLVM::GlobalOp getOrDefineExternalStruct(const Location loc,
141+
ConversionPatternRewriter &rewriter,
142+
StringRef name,
143+
LLVM::LLVMStructType type) {
144+
145+
return getOrDefineGlobal<LLVM::GlobalOp>(
146+
getModuleOp(), loc, rewriter, name, type, /*isConstant=*/false,
147+
LLVM::Linkage::External, name,
148+
/*value=*/Attribute(), /*alignment=*/0, 0);
122149
}
123150

124-
static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
125-
const mlir::Location loc,
126-
mlir::ConversionPatternRewriter &rewriter) {
151+
public:
152+
using MPIImplTraits::MPIImplTraits;
153+
154+
Value getCommWorld(const Location loc,
155+
ConversionPatternRewriter &rewriter) override {
127156
auto context = rewriter.getContext();
128157
// get external opaque struct pointer type
129158
auto commStructT =
130-
mlir::LLVM::LLVMStructType::getOpaque("ompi_communicator_t", context);
131-
mlir::StringRef name = "ompi_mpi_comm_world";
159+
LLVM::LLVMStructType::getOpaque("ompi_communicator_t", context);
160+
StringRef name = "ompi_mpi_comm_world";
132161

133162
// make sure global op definition exists
134-
(void)getOrDefineExternalStruct(moduleOp, loc, rewriter, name, commStructT);
163+
(void)getOrDefineExternalStruct(loc, rewriter, name, commStructT);
135164

136165
// get address of symbol
137-
return rewriter.create<mlir::LLVM::AddressOfOp>(
138-
loc, mlir::LLVM::LLVMPointerType::get(context),
139-
mlir::SymbolRefAttr::get(context, name));
166+
return rewriter.create<LLVM::AddressOfOp>(
167+
loc, LLVM::LLVMPointerType::get(context),
168+
SymbolRefAttr::get(context, name));
140169
}
141170

142-
static intptr_t getStatusIgnore() { return 0; }
171+
intptr_t getStatusIgnore() override { return 0; }
143172

144-
static mlir::Value getDataType(mlir::ModuleOp &moduleOp,
145-
const mlir::Location loc,
146-
mlir::ConversionPatternRewriter &rewriter,
147-
mlir::Type type) {
148-
mlir::StringRef mtype;
173+
Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
174+
Type type) override {
175+
StringRef mtype;
149176
if (type.isF32())
150177
mtype = "ompi_mpi_float";
151178
else if (type.isF64())
@@ -171,67 +198,29 @@ struct OMPIImplTraits {
171198

172199
auto context = rewriter.getContext();
173200
// get external opaque struct pointer type
174-
auto commStructT = mlir::LLVM::LLVMStructType::getOpaque(
175-
"ompi_predefined_datatype_t", context);
201+
auto commStructT =
202+
LLVM::LLVMStructType::getOpaque("ompi_predefined_datatype_t", context);
176203
// make sure global op definition exists
177-
(void)getOrDefineExternalStruct(moduleOp, loc, rewriter, mtype,
178-
commStructT);
204+
(void)getOrDefineExternalStruct(loc, rewriter, mtype, commStructT);
179205
// get address of symbol
180-
return rewriter.create<mlir::LLVM::AddressOfOp>(
181-
loc, mlir::LLVM::LLVMPointerType::get(context),
182-
mlir::SymbolRefAttr::get(context, mtype));
206+
return rewriter.create<LLVM::AddressOfOp>(
207+
loc, LLVM::LLVMPointerType::get(context),
208+
SymbolRefAttr::get(context, mtype));
183209
}
184210
};
185211

186-
/// When lowering the mpi dialect to functions calls certain details
187-
/// differ between various MPI implementations. This class will provide
188-
/// these in a generic way, depending on the MPI implementation that got
189-
/// selected by the DLTI attribute on the module.
190-
struct MPIImplTraits {
191-
enum MPIImpl { MPICH, OMPI };
192-
193-
/// Gets the MPI implementation from a DLTI attribute on the module.
194-
/// Defaults to MPICH (and ABI compatible).
195-
static MPIImpl getMPIImpl(mlir::ModuleOp &moduleOp) {
196-
auto attr = dlti::query(*&moduleOp, {"MPI:Implementation"}, true);
197-
if (failed(attr))
198-
return MPICH;
199-
auto strAttr = dyn_cast<StringAttr>(attr.value());
200-
if (strAttr && strAttr.getValue() == "OpenMPI")
201-
return OMPI;
202-
if (!strAttr || strAttr.getValue() != "MPICH")
203-
moduleOp.emitWarning() << "Unknown \"MPI:Implementation\" value in DLTI ("
204-
<< strAttr.getValue() << "), defaulting to MPICH";
205-
return MPICH;
206-
}
207-
208-
/// Gets or creates MPI_COMM_WORLD as a mlir::Value.
209-
static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
210-
const mlir::Location loc,
211-
mlir::ConversionPatternRewriter &rewriter) {
212-
if (MPIImplTraits::getMPIImpl(moduleOp) == OMPI)
213-
return OMPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
214-
return MPICHImplTraits::getCommWorld(moduleOp, loc, rewriter);
215-
}
216-
217-
/// Get the MPI_STATUS_IGNORE value (typically a pointer type).
218-
static intptr_t getStatusIgnore(mlir::ModuleOp &moduleOp) {
219-
if (MPIImplTraits::getMPIImpl(moduleOp) == OMPI)
220-
return OMPIImplTraits::getStatusIgnore();
221-
return MPICHImplTraits::getStatusIgnore();
222-
}
223-
224-
/// get/create MPI datatype as a mlir::Value which corresponds to the given
225-
/// mlir::Type
226-
static mlir::Value getDataType(mlir::ModuleOp &moduleOp,
227-
const mlir::Location loc,
228-
mlir::ConversionPatternRewriter &rewriter,
229-
mlir::Type type) {
230-
if (MPIImplTraits::getMPIImpl(moduleOp) == OMPI)
231-
return OMPIImplTraits::getDataType(moduleOp, loc, rewriter, type);
232-
return MPICHImplTraits::getDataType(moduleOp, loc, rewriter, type);
233-
}
234-
};
212+
std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
213+
auto attr = dlti::query(*&moduleOp, {"MPI:Implementation"}, true);
214+
if (failed(attr))
215+
return std::make_unique<MPICHImplTraits>(moduleOp);
216+
auto strAttr = dyn_cast<StringAttr>(attr.value());
217+
if (strAttr && strAttr.getValue() == "OpenMPI")
218+
return std::make_unique<OMPIImplTraits>(moduleOp);
219+
if (!strAttr || strAttr.getValue() != "MPICH")
220+
moduleOp.emitWarning() << "Unknown \"MPI:Implementation\" value in DLTI ("
221+
<< strAttr.getValue() << "), defaulting to MPICH";
222+
return std::make_unique<MPICHImplTraits>(moduleOp);
223+
}
235224

236225
//===----------------------------------------------------------------------===//
237226
// InitOpLowering
@@ -320,8 +309,9 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
320309
// grab a reference to the global module op:
321310
auto moduleOp = op->getParentOfType<ModuleOp>();
322311

312+
auto mpiTraits = MPIImplTraits::get(moduleOp);
323313
// get MPI_COMM_WORLD
324-
Value commWorld = MPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
314+
Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
325315

326316
// LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
327317
auto rankFuncType =
@@ -387,9 +377,9 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
387377
Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
388378
ArrayRef<int64_t>{3, 0});
389379
size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
390-
Value dataType =
391-
MPIImplTraits::getDataType(moduleOp, loc, rewriter, elemType);
392-
Value commWorld = MPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
380+
auto mpiTraits = MPIImplTraits::get(moduleOp);
381+
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
382+
Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
393383

394384
// LLVM Function type representing `i32 MPI_send(data, count, datatype, dst,
395385
// tag, comm)`
@@ -446,11 +436,11 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
446436
Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
447437
ArrayRef<int64_t>{3, 0});
448438
size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
449-
Value dataType =
450-
MPIImplTraits::getDataType(moduleOp, loc, rewriter, elemType);
451-
Value commWorld = MPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
439+
auto mpiTraits = MPIImplTraits::get(moduleOp);
440+
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
441+
Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
452442
Value statusIgnore = rewriter.create<LLVM::ConstantOp>(
453-
loc, i64, MPIImplTraits::getStatusIgnore(moduleOp));
443+
loc, i64, mpiTraits->getStatusIgnore());
454444
statusIgnore =
455445
rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, statusIgnore);
456446

@@ -498,13 +488,13 @@ struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
498488
// Pattern Population
499489
//===----------------------------------------------------------------------===//
500490

501-
void mlir::mpi::populateMPIToLLVMConversionPatterns(
502-
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
491+
void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
492+
RewritePatternSet &patterns) {
503493
patterns.add<CommRankOpLowering, FinalizeOpLowering, InitOpLowering,
504494
SendOpLowering, RecvOpLowering>(converter);
505495
}
506496

507-
void mlir::mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
497+
void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
508498
registry.addExtension(+[](MLIRContext *ctx, mpi::MPIDialect *dialect) {
509499
dialect->addInterfaces<FuncToLLVMDialectInterface>();
510500
});

0 commit comments

Comments
 (0)