@@ -47,6 +47,22 @@ static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
4747 moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
4848}
4949
50+ std::pair<Value, Value> getRawPtrAndSize (const Location loc,
51+ ConversionPatternRewriter &rewriter,
52+ Value memRef, Type elType) {
53+ Type ptrType = LLVM::LLVMPointerType::get (rewriter.getContext ());
54+ Value dataPtr =
55+ rewriter.create <LLVM::ExtractValueOp>(loc, ptrType, memRef, 1 );
56+ Value offset = rewriter.create <LLVM::ExtractValueOp>(
57+ loc, rewriter.getI64Type (), memRef, 2 );
58+ Value resPtr =
59+ rewriter.create <LLVM::GEPOp>(loc, ptrType, elType, dataPtr, offset);
60+ Value size = rewriter.create <LLVM::ExtractValueOp>(loc, memRef,
61+ ArrayRef<int64_t >{3 , 0 });
62+ size = rewriter.create <LLVM::TruncOp>(loc, rewriter.getI32Type (), size);
63+ return {resPtr, size};
64+ }
65+
5066// / When lowering the mpi dialect to function calls, certain details
5167// / differ between various MPI implementations. This class will provide
5268// / these in a generic way, depending on the MPI implementation that got
@@ -77,6 +93,12 @@ class MPIImplTraits {
7793 // / type.
7894 virtual Value getDataType (const Location loc,
7995 ConversionPatternRewriter &rewriter, Type type) = 0;
96+
97+ // / Gets or creates an MPI_Op value which corresponds to the given
98+ // / enum value.
99+ virtual Value getMPIOp (const Location loc,
100+ ConversionPatternRewriter &rewriter,
101+ mpi::MPI_OpClassEnum opAttr) = 0;
80102};
81103
82104// ===----------------------------------------------------------------------===//
@@ -94,6 +116,20 @@ class MPICHImplTraits : public MPIImplTraits {
94116 static constexpr int MPI_UINT16_T = 0x4c00023c ;
95117 static constexpr int MPI_UINT32_T = 0x4c00043d ;
96118 static constexpr int MPI_UINT64_T = 0x4c00083e ;
119+ static constexpr int MPI_MAX = 0x58000001 ;
120+ static constexpr int MPI_MIN = 0x58000002 ;
121+ static constexpr int MPI_SUM = 0x58000003 ;
122+ static constexpr int MPI_PROD = 0x58000004 ;
123+ static constexpr int MPI_LAND = 0x58000005 ;
124+ static constexpr int MPI_BAND = 0x58000006 ;
125+ static constexpr int MPI_LOR = 0x58000007 ;
126+ static constexpr int MPI_BOR = 0x58000008 ;
127+ static constexpr int MPI_LXOR = 0x58000009 ;
128+ static constexpr int MPI_BXOR = 0x5800000a ;
129+ static constexpr int MPI_MINLOC = 0x5800000b ;
130+ static constexpr int MPI_MAXLOC = 0x5800000c ;
131+ static constexpr int MPI_REPLACE = 0x5800000d ;
132+ static constexpr int MPI_NO_OP = 0x5800000e ;
97133
98134public:
99135 using MPIImplTraits::MPIImplTraits;
@@ -136,6 +172,56 @@ class MPICHImplTraits : public MPIImplTraits {
136172 assert (false && " unsupported type" );
137173 return rewriter.create <LLVM::ConstantOp>(loc, rewriter.getI32Type (), mtype);
138174 }
175+
176+ Value getMPIOp (const Location loc, ConversionPatternRewriter &rewriter,
177+ mpi::MPI_OpClassEnum opAttr) override {
178+ int32_t op = MPI_NO_OP;
179+ switch (opAttr) {
180+ case mpi::MPI_OpClassEnum::MPI_OP_NULL:
181+ op = MPI_NO_OP;
182+ break ;
183+ case mpi::MPI_OpClassEnum::MPI_MAX:
184+ op = MPI_MAX;
185+ break ;
186+ case mpi::MPI_OpClassEnum::MPI_MIN:
187+ op = MPI_MIN;
188+ break ;
189+ case mpi::MPI_OpClassEnum::MPI_SUM:
190+ op = MPI_SUM;
191+ break ;
192+ case mpi::MPI_OpClassEnum::MPI_PROD:
193+ op = MPI_PROD;
194+ break ;
195+ case mpi::MPI_OpClassEnum::MPI_LAND:
196+ op = MPI_LAND;
197+ break ;
198+ case mpi::MPI_OpClassEnum::MPI_BAND:
199+ op = MPI_BAND;
200+ break ;
201+ case mpi::MPI_OpClassEnum::MPI_LOR:
202+ op = MPI_LOR;
203+ break ;
204+ case mpi::MPI_OpClassEnum::MPI_BOR:
205+ op = MPI_BOR;
206+ break ;
207+ case mpi::MPI_OpClassEnum::MPI_LXOR:
208+ op = MPI_LXOR;
209+ break ;
210+ case mpi::MPI_OpClassEnum::MPI_BXOR:
211+ op = MPI_BXOR;
212+ break ;
213+ case mpi::MPI_OpClassEnum::MPI_MINLOC:
214+ op = MPI_MINLOC;
215+ break ;
216+ case mpi::MPI_OpClassEnum::MPI_MAXLOC:
217+ op = MPI_MAXLOC;
218+ break ;
219+ case mpi::MPI_OpClassEnum::MPI_REPLACE:
220+ op = MPI_REPLACE;
221+ break ;
222+ }
223+ return rewriter.create <LLVM::ConstantOp>(loc, rewriter.getI32Type (), op);
224+ }
139225};
140226
141227// ===----------------------------------------------------------------------===//
@@ -205,15 +291,74 @@ class OMPIImplTraits : public MPIImplTraits {
205291
206292 auto context = rewriter.getContext ();
207293 // get external opaque struct pointer type
208- auto commStructT =
294+ auto typeStructT =
209295 LLVM::LLVMStructType::getOpaque (" ompi_predefined_datatype_t" , context);
210296 // make sure global op definition exists
211- getOrDefineExternalStruct (loc, rewriter, mtype, commStructT );
297+ getOrDefineExternalStruct (loc, rewriter, mtype, typeStructT );
212298 // get address of symbol
213299 return rewriter.create <LLVM::AddressOfOp>(
214300 loc, LLVM::LLVMPointerType::get (context),
215301 SymbolRefAttr::get (context, mtype));
216302 }
303+
304+ Value getMPIOp (const Location loc, ConversionPatternRewriter &rewriter,
305+ mpi::MPI_OpClassEnum opAttr) override {
306+ StringRef op;
307+ switch (opAttr) {
308+ case mpi::MPI_OpClassEnum::MPI_OP_NULL:
309+ op = " ompi_mpi_no_op" ;
310+ break ;
311+ case mpi::MPI_OpClassEnum::MPI_MAX:
312+ op = " ompi_mpi_max" ;
313+ break ;
314+ case mpi::MPI_OpClassEnum::MPI_MIN:
315+ op = " ompi_mpi_min" ;
316+ break ;
317+ case mpi::MPI_OpClassEnum::MPI_SUM:
318+ op = " ompi_mpi_sum" ;
319+ break ;
320+ case mpi::MPI_OpClassEnum::MPI_PROD:
321+ op = " ompi_mpi_prod" ;
322+ break ;
323+ case mpi::MPI_OpClassEnum::MPI_LAND:
324+ op = " ompi_mpi_land" ;
325+ break ;
326+ case mpi::MPI_OpClassEnum::MPI_BAND:
327+ op = " ompi_mpi_band" ;
328+ break ;
329+ case mpi::MPI_OpClassEnum::MPI_LOR:
330+ op = " ompi_mpi_lor" ;
331+ break ;
332+ case mpi::MPI_OpClassEnum::MPI_BOR:
333+ op = " ompi_mpi_bor" ;
334+ break ;
335+ case mpi::MPI_OpClassEnum::MPI_LXOR:
336+ op = " ompi_mpi_lxor" ;
337+ break ;
338+ case mpi::MPI_OpClassEnum::MPI_BXOR:
339+ op = " ompi_mpi_bxor" ;
340+ break ;
341+ case mpi::MPI_OpClassEnum::MPI_MINLOC:
342+ op = " ompi_mpi_minloc" ;
343+ break ;
344+ case mpi::MPI_OpClassEnum::MPI_MAXLOC:
345+ op = " ompi_mpi_maxloc" ;
346+ break ;
347+ case mpi::MPI_OpClassEnum::MPI_REPLACE:
348+ op = " ompi_mpi_replace" ;
349+ break ;
350+ }
351+ auto context = rewriter.getContext ();
352+ // get external opaque struct pointer type
353+ auto opStructT =
354+ LLVM::LLVMStructType::getOpaque (" ompi_predefined_op_t" , context);
355+ // make sure global op definition exists
356+ getOrDefineExternalStruct (loc, rewriter, op, opStructT);
357+ // get address of symbol
358+ return rewriter.create <LLVM::AddressOfOp>(
359+ loc, LLVM::LLVMPointerType::get (context),
360+ SymbolRefAttr::get (context, op));
361+ }
217362};
218363
219364std::unique_ptr<MPIImplTraits> MPIImplTraits::get (ModuleOp &moduleOp) {
@@ -365,8 +510,6 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
365510 Location loc = op.getLoc ();
366511 MLIRContext *context = rewriter.getContext ();
367512 Type i32 = rewriter.getI32Type ();
368- Type i64 = rewriter.getI64Type ();
369- Value memRef = adaptor.getRef ();
370513 Type elemType = op.getRef ().getType ().getElementType ();
371514
372515 // ptrType `!llvm.ptr`
@@ -376,14 +519,8 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
376519 auto moduleOp = op->getParentOfType <ModuleOp>();
377520
378521 // get MPI_COMM_WORLD, dataType and pointer
379- Value dataPtr =
380- rewriter.create <LLVM::ExtractValueOp>(loc, ptrType, memRef, 1 );
381- Value offset = rewriter.create <LLVM::ExtractValueOp>(loc, i64 , memRef, 2 );
382- dataPtr =
383- rewriter.create <LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset);
384- Value size = rewriter.create <LLVM::ExtractValueOp>(loc, memRef,
385- ArrayRef<int64_t >{3 , 0 });
386- size = rewriter.create <LLVM::TruncOp>(loc, i32 , size);
522+ auto [dataPtr, size] =
523+ getRawPtrAndSize (loc, rewriter, adaptor.getRef (), elemType);
387524 auto mpiTraits = MPIImplTraits::get (moduleOp);
388525 Value dataType = mpiTraits->getDataType (loc, rewriter, elemType);
389526 Value commWorld = mpiTraits->getCommWorld (loc, rewriter);
@@ -425,7 +562,6 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
425562 MLIRContext *context = rewriter.getContext ();
426563 Type i32 = rewriter.getI32Type ();
427564 Type i64 = rewriter.getI64Type ();
428- Value memRef = adaptor.getRef ();
429565 Type elemType = op.getRef ().getType ().getElementType ();
430566
431567 // ptrType `!llvm.ptr`
@@ -435,14 +571,8 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
435571 auto moduleOp = op->getParentOfType <ModuleOp>();
436572
437573 // get MPI_COMM_WORLD, dataType, status_ignore and pointer
438- Value dataPtr =
439- rewriter.create <LLVM::ExtractValueOp>(loc, ptrType, memRef, 1 );
440- Value offset = rewriter.create <LLVM::ExtractValueOp>(loc, i64 , memRef, 2 );
441- dataPtr =
442- rewriter.create <LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset);
443- Value size = rewriter.create <LLVM::ExtractValueOp>(loc, memRef,
444- ArrayRef<int64_t >{3 , 0 });
445- size = rewriter.create <LLVM::TruncOp>(loc, i32 , size);
574+ auto [dataPtr, size] =
575+ getRawPtrAndSize (loc, rewriter, adaptor.getRef (), elemType);
446576 auto mpiTraits = MPIImplTraits::get (moduleOp);
447577 Value dataType = mpiTraits->getDataType (loc, rewriter, elemType);
448578 Value commWorld = mpiTraits->getCommWorld (loc, rewriter);
@@ -474,6 +604,55 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
474604 }
475605};
476606
607+ // ===----------------------------------------------------------------------===//
608+ // AllReduceOpLowering
609+ // ===----------------------------------------------------------------------===//
610+
611+ struct AllReduceOpLowering : public ConvertOpToLLVMPattern <mpi::AllReduceOp> {
612+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
613+
614+ LogicalResult
615+ matchAndRewrite (mpi::AllReduceOp op, OpAdaptor adaptor,
616+ ConversionPatternRewriter &rewriter) const override {
617+ Location loc = op.getLoc ();
618+ MLIRContext *context = rewriter.getContext ();
619+ Type i32 = rewriter.getI32Type ();
620+ Type elemType = op.getSendbuf ().getType ().getElementType ();
621+
622+ // ptrType `!llvm.ptr`
623+ Type ptrType = LLVM::LLVMPointerType::get (context);
624+ auto moduleOp = op->getParentOfType <ModuleOp>();
625+ auto mpiTraits = MPIImplTraits::get (moduleOp);
626+ auto [sendPtr, sendSize] =
627+ getRawPtrAndSize (loc, rewriter, adaptor.getSendbuf (), elemType);
628+ auto [recvPtr, recvSize] =
629+ getRawPtrAndSize (loc, rewriter, adaptor.getRecvbuf (), elemType);
630+ Value dataType = mpiTraits->getDataType (loc, rewriter, elemType);
631+ Value mpiOp = mpiTraits->getMPIOp (loc, rewriter, op.getOp ());
632+ Value commWorld = mpiTraits->getCommWorld (loc, rewriter);
633+ // 'int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
634+ // MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
635+ auto funcType = LLVM::LLVMFunctionType::get (
636+ i32 , {ptrType, ptrType, i32 , dataType.getType (), mpiOp.getType (),
637+ commWorld.getType ()});
638+ // get or create function declaration:
639+ LLVM::LLVMFuncOp funcDecl =
640+ getOrDefineFunction (moduleOp, loc, rewriter, " MPI_Allreduce" , funcType);
641+
642+ // replace op with function call
643+ auto funcCall = rewriter.create <LLVM::CallOp>(
644+ loc, funcDecl,
645+ ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld});
646+
647+ if (op.getRetval ())
648+ rewriter.replaceOp (op, funcCall.getResult ());
649+ else
650+ rewriter.eraseOp (op);
651+
652+ return success ();
653+ }
654+ };
655+
477656// ===----------------------------------------------------------------------===//
478657// ConvertToLLVMPatternInterface implementation
479658// ===----------------------------------------------------------------------===//
@@ -498,7 +677,7 @@ struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
498677void mpi::populateMPIToLLVMConversionPatterns (LLVMTypeConverter &converter,
499678 RewritePatternSet &patterns) {
500679 patterns.add <CommRankOpLowering, FinalizeOpLowering, InitOpLowering,
501- SendOpLowering, RecvOpLowering>(converter);
680+ SendOpLowering, RecvOpLowering, AllReduceOpLowering >(converter);
502681}
503682
504683void mpi::registerConvertMPIToLLVMInterface (DialectRegistry ®istry) {
0 commit comments