@@ -49,17 +49,18 @@ static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
4949// ===----------------------------------------------------------------------===//
5050// Implementation details for MPICH ABI compatible MPI implementations
5151// ===----------------------------------------------------------------------===//
52+
5253struct MPICHImplTraits {
5354 static constexpr int MPI_FLOAT = 0x4c00040a ;
54- static const int MPI_DOUBLE = 0x4c00080b ;
55- static const int MPI_INT8_T = 0x4c000137 ;
56- static const int MPI_INT16_T = 0x4c000238 ;
57- static const int MPI_INT32_T = 0x4c000439 ;
58- static const int MPI_INT64_T = 0x4c00083a ;
59- static const int MPI_UINT8_T = 0x4c00013b ;
60- static const int MPI_UINT16_T = 0x4c00023c ;
61- static const int MPI_UINT32_T = 0x4c00043d ;
62- static const int MPI_UINT64_T = 0x4c00083e ;
55+ static constexpr int MPI_DOUBLE = 0x4c00080b ;
56+ static constexpr int MPI_INT8_T = 0x4c000137 ;
57+ static constexpr int MPI_INT16_T = 0x4c000238 ;
58+ static constexpr int MPI_INT32_T = 0x4c000439 ;
59+ static constexpr int MPI_INT64_T = 0x4c00083a ;
60+ static constexpr int MPI_UINT8_T = 0x4c00013b ;
61+ static constexpr int MPI_UINT16_T = 0x4c00023c ;
62+ static constexpr int MPI_UINT32_T = 0x4c00043d ;
63+ static constexpr int MPI_UINT64_T = 0x4c00083e ;
6364
6465 static mlir::Value getCommWorld (mlir::ModuleOp &moduleOp,
6566 const mlir::Location loc,
@@ -196,43 +197,38 @@ struct MPIImplTraits {
196197 if (failed (attr))
197198 return MPICH;
198199 auto strAttr = dyn_cast<StringAttr>(attr.value ());
199- if (strAttr && strAttr.getValue () == " OpenMPI" ) {
200+ if (strAttr && strAttr.getValue () == " OpenMPI" )
200201 return OMPI;
201- }
202- if (!strAttr || strAttr.getValue () != " MPICH" ) {
203- moduleOp.emitWarning () << " Unknown \" MPI:Implementation\" value in DLTI (" << strAttr.getValue () << " ), "
204- " defaulting to MPICH" );
205- }
202+ if (!strAttr || strAttr.getValue () != " MPICH" )
203+ moduleOp.emitWarning () << " Unknown \" MPI:Implementation\" value in DLTI ("
204+ << strAttr.getValue () << " ), defaulting to MPICH" ;
206205 return MPICH;
207206 }
208207
209208 // / Gets or creates MPI_COMM_WORLD as a mlir::Value.
210209 static mlir::Value getCommWorld (mlir::ModuleOp &moduleOp,
211210 const mlir::Location loc,
212211 mlir::ConversionPatternRewriter &rewriter) {
213- if (MPIImplTraits::getMPIImpl (moduleOp) == OMPI) {
212+ if (MPIImplTraits::getMPIImpl (moduleOp) == OMPI)
214213 return OMPIImplTraits::getCommWorld (moduleOp, loc, rewriter);
215- }
216214 return MPICHImplTraits::getCommWorld (moduleOp, loc, rewriter);
217215 }
218216
219- // Get the MPI_STATUS_IGNORE value (typically a pointer type).
217+ // / Get the MPI_STATUS_IGNORE value (typically a pointer type).
220218 static intptr_t getStatusIgnore (mlir::ModuleOp &moduleOp) {
221- if (MPIImplTraits::getMPIImpl (moduleOp) == OMPI) {
219+ if (MPIImplTraits::getMPIImpl (moduleOp) == OMPI)
222220 return OMPIImplTraits::getStatusIgnore ();
223- }
224221 return MPICHImplTraits::getStatusIgnore ();
225222 }
226223
227- // get/create MPI datatype as a mlir::Value which corresponds to the given
228- // mlir::Type
224+ // / get/create MPI datatype as a mlir::Value which corresponds to the given
225+ // / mlir::Type
229226 static mlir::Value getDataType (mlir::ModuleOp &moduleOp,
230227 const mlir::Location loc,
231228 mlir::ConversionPatternRewriter &rewriter,
232229 mlir::Type type) {
233- if (MPIImplTraits::getMPIImpl (moduleOp) == OMPI) {
230+ if (MPIImplTraits::getMPIImpl (moduleOp) == OMPI)
234231 return OMPIImplTraits::getDataType (moduleOp, loc, rewriter, type);
235- }
236232 return MPICHImplTraits::getDataType (moduleOp, loc, rewriter, type);
237233 }
238234};
@@ -347,9 +343,9 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
347343 // if retval is checked, replace uses of retval with the results from the
348344 // call op
349345 SmallVector<Value> replacements;
350- if (op.getRetval ()) {
346+ if (op.getRetval ())
351347 replacements.push_back (callOp.getResult ());
352- }
348+
353349 // replace all uses, then erase op
354350 replacements.push_back (loadedRank.getRes ());
355351 rewriter.replaceOp (op, replacements);
@@ -408,11 +404,10 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
408404 loc, funcDecl,
409405 ValueRange{dataPtr, size, dataType, adaptor.getDest (), adaptor.getTag (),
410406 commWorld});
411- if (op.getRetval ()) {
407+ if (op.getRetval ())
412408 rewriter.replaceOp (op, funcCall.getResult ());
413- } else {
409+ else
414410 rewriter.eraseOp (op);
415- }
416411
417412 return success ();
418413 }
@@ -473,11 +468,10 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
473468 loc, funcDecl,
474469 ValueRange{dataPtr, size, dataType, adaptor.getSource (),
475470 adaptor.getTag (), commWorld, statusIgnore});
476- if (op.getRetval ()) {
471+ if (op.getRetval ())
477472 rewriter.replaceOp (op, funcCall.getResult ());
478- } else {
473+ else
479474 rewriter.eraseOp (op);
480- }
481475
482476 return success ();
483477 }
0 commit comments