@@ -469,33 +469,49 @@ struct MatmulTransposeOpConversion
469
469
}
470
470
};
471
471
472
- class CShiftOpConversion : public HlfirIntrinsicConversion <hlfir::CShiftOp> {
473
- using HlfirIntrinsicConversion<hlfir::CShiftOp>::HlfirIntrinsicConversion;
472
+ // A converter for hlfir.cshift and hlfir.eoshift.
473
+ template <typename T>
474
+ class ArrayShiftOpConversion : public HlfirIntrinsicConversion <T> {
475
+ using HlfirIntrinsicConversion<T>::HlfirIntrinsicConversion;
476
+ using HlfirIntrinsicConversion<T>::lowerArguments;
477
+ using HlfirIntrinsicConversion<T>::processReturnValue;
478
+ using typename HlfirIntrinsicConversion<T>::IntrinsicArgument;
474
479
475
480
llvm::LogicalResult
476
- matchAndRewrite (hlfir::CShiftOp cshift,
477
- mlir::PatternRewriter &rewriter) const override {
478
- fir::FirOpBuilder builder{rewriter, cshift.getOperation ()};
479
- const mlir::Location &loc = cshift->getLoc ();
481
+ matchAndRewrite (T op, mlir::PatternRewriter &rewriter) const override {
482
+ fir::FirOpBuilder builder{rewriter, op.getOperation ()};
483
+ const mlir::Location &loc = op->getLoc ();
480
484
481
- llvm::SmallVector<IntrinsicArgument, 3 > inArgs;
482
- mlir::Value array = cshift.getArray ();
485
+ llvm::SmallVector<IntrinsicArgument, 4 > inArgs;
486
+ llvm::StringRef intrinsicName{[]() {
487
+ if constexpr (std::is_same_v<T, hlfir::EOShiftOp>)
488
+ return " eoshift" ;
489
+ else if constexpr (std::is_same_v<T, hlfir::CShiftOp>)
490
+ return " cshift" ;
491
+ else
492
+ llvm_unreachable (" unsupported array shift" );
493
+ }()};
494
+
495
+ mlir::Value array = op.getArray ();
483
496
inArgs.push_back ({array, array.getType ()});
484
- mlir::Value shift = cshift .getShift ();
497
+ mlir::Value shift = op .getShift ();
485
498
inArgs.push_back ({shift, shift.getType ()});
486
- inArgs.push_back ({cshift.getDim (), builder.getI32Type ()});
499
+ if constexpr (std::is_same_v<T, hlfir::EOShiftOp>) {
500
+ mlir::Value boundary = op.getBoundary ();
501
+ inArgs.push_back ({boundary, boundary ? boundary.getType () : nullptr });
502
+ }
503
+ inArgs.push_back ({op.getDim (), builder.getI32Type ()});
487
504
488
- auto *argLowering = fir::getIntrinsicArgumentLowering (" cshift " );
505
+ auto *argLowering = fir::getIntrinsicArgumentLowering (intrinsicName );
489
506
llvm::SmallVector<fir::ExtendedValue, 3 > args =
490
- lowerArguments (cshift , inArgs, rewriter, argLowering);
507
+ lowerArguments (op , inArgs, rewriter, argLowering);
491
508
492
- mlir::Type scalarResultType =
493
- hlfir::getFortranElementType (cshift.getType ());
509
+ mlir::Type scalarResultType = hlfir::getFortranElementType (op.getType ());
494
510
495
- auto [resultExv, mustBeFreed] =
496
- fir::genIntrinsicCall ( builder, loc, " cshift " , scalarResultType, args);
511
+ auto [resultExv, mustBeFreed] = fir::genIntrinsicCall (
512
+ builder, loc, intrinsicName , scalarResultType, args);
497
513
498
- processReturnValue (cshift , resultExv, mustBeFreed, builder, rewriter);
514
+ processReturnValue (op , resultExv, mustBeFreed, builder, rewriter);
499
515
return mlir::success ();
500
516
}
501
517
};
@@ -547,7 +563,8 @@ class LowerHLFIRIntrinsics
547
563
AnyOpConversion, SumOpConversion, ProductOpConversion,
548
564
TransposeOpConversion, CountOpConversion, DotProductOpConversion,
549
565
MaxvalOpConversion, MinvalOpConversion, MinlocOpConversion,
550
- MaxlocOpConversion, CShiftOpConversion, ReshapeOpConversion>(context);
566
+ MaxlocOpConversion, ArrayShiftOpConversion<hlfir::CShiftOp>,
567
+ ArrayShiftOpConversion<hlfir::EOShiftOp>, ReshapeOpConversion>(context);
551
568
552
569
// While conceptually this pass is performing dialect conversion, we use
553
570
// pattern rewrites here instead of dialect conversion because this pass
0 commit comments