@@ -469,33 +469,49 @@ struct MatmulTransposeOpConversion
469469 }
470470};
471471
472- class CShiftOpConversion : public HlfirIntrinsicConversion <hlfir::CShiftOp> {
473- using HlfirIntrinsicConversion<hlfir::CShiftOp>::HlfirIntrinsicConversion;
472+ // A converter for hlfir.cshift and hlfir.eoshift.
473+ template <typename T>
474+ class ArrayShiftOpConversion : public HlfirIntrinsicConversion <T> {
475+ using HlfirIntrinsicConversion<T>::HlfirIntrinsicConversion;
476+ using HlfirIntrinsicConversion<T>::lowerArguments;
477+ using HlfirIntrinsicConversion<T>::processReturnValue;
478+ using typename HlfirIntrinsicConversion<T>::IntrinsicArgument;
474479
475480 llvm::LogicalResult
476- matchAndRewrite (hlfir::CShiftOp cshift,
477- mlir::PatternRewriter &rewriter) const override {
478- fir::FirOpBuilder builder{rewriter, cshift.getOperation ()};
479- const mlir::Location &loc = cshift->getLoc ();
481+ matchAndRewrite (T op, mlir::PatternRewriter &rewriter) const override {
482+ fir::FirOpBuilder builder{rewriter, op.getOperation ()};
483+ const mlir::Location &loc = op->getLoc ();
480484
481- llvm::SmallVector<IntrinsicArgument, 3 > inArgs;
482- mlir::Value array = cshift.getArray ();
485+ llvm::SmallVector<IntrinsicArgument, 4 > inArgs;
486+ llvm::StringRef intrinsicName{[]() {
487+ if constexpr (std::is_same_v<T, hlfir::EOShiftOp>)
488+ return " eoshift" ;
489+ else if constexpr (std::is_same_v<T, hlfir::CShiftOp>)
490+ return " cshift" ;
491+ else
492+ llvm_unreachable (" unsupported array shift" );
493+ }()};
494+
495+ mlir::Value array = op.getArray ();
483496 inArgs.push_back ({array, array.getType ()});
484- mlir::Value shift = cshift .getShift ();
497+ mlir::Value shift = op .getShift ();
485498 inArgs.push_back ({shift, shift.getType ()});
486- inArgs.push_back ({cshift.getDim (), builder.getI32Type ()});
499+ if constexpr (std::is_same_v<T, hlfir::EOShiftOp>) {
500+ mlir::Value boundary = op.getBoundary ();
501+ inArgs.push_back ({boundary, boundary ? boundary.getType () : nullptr });
502+ }
503+ inArgs.push_back ({op.getDim (), builder.getI32Type ()});
487504
488- auto *argLowering = fir::getIntrinsicArgumentLowering (" cshift " );
505+ auto *argLowering = fir::getIntrinsicArgumentLowering (intrinsicName );
489506 llvm::SmallVector<fir::ExtendedValue, 3 > args =
490- lowerArguments (cshift , inArgs, rewriter, argLowering);
507+ lowerArguments (op , inArgs, rewriter, argLowering);
491508
492- mlir::Type scalarResultType =
493- hlfir::getFortranElementType (cshift.getType ());
509+ mlir::Type scalarResultType = hlfir::getFortranElementType (op.getType ());
494510
495- auto [resultExv, mustBeFreed] =
496- fir::genIntrinsicCall ( builder, loc, " cshift " , scalarResultType, args);
511+ auto [resultExv, mustBeFreed] = fir::genIntrinsicCall (
512+ builder, loc, intrinsicName , scalarResultType, args);
497513
498- processReturnValue (cshift , resultExv, mustBeFreed, builder, rewriter);
514+ processReturnValue (op , resultExv, mustBeFreed, builder, rewriter);
499515 return mlir::success ();
500516 }
501517};
@@ -547,7 +563,8 @@ class LowerHLFIRIntrinsics
547563 AnyOpConversion, SumOpConversion, ProductOpConversion,
548564 TransposeOpConversion, CountOpConversion, DotProductOpConversion,
549565 MaxvalOpConversion, MinvalOpConversion, MinlocOpConversion,
550- MaxlocOpConversion, CShiftOpConversion, ReshapeOpConversion>(context);
566+ MaxlocOpConversion, ArrayShiftOpConversion<hlfir::CShiftOp>,
567+ ArrayShiftOpConversion<hlfir::EOShiftOp>, ReshapeOpConversion>(context);
551568
552569 // While conceptually this pass is performing dialect conversion, we use
553570 // pattern rewrites here instead of dialect conversion because this pass
0 commit comments