Skip to content

Commit 4c6afc7

Browse files
authored
[flang] Lower hlfir.eoshift to the runtime call. (#153107)
Straightforward lowering of hlfir.eoshift to the runtime call in LowerHLFIRIntrinsics pass.
1 parent e315455 commit 4c6afc7

File tree

2 files changed

+329
-18
lines changed

2 files changed

+329
-18
lines changed

flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -469,33 +469,49 @@ struct MatmulTransposeOpConversion
469469
}
470470
};
471471

472-
class CShiftOpConversion : public HlfirIntrinsicConversion<hlfir::CShiftOp> {
473-
using HlfirIntrinsicConversion<hlfir::CShiftOp>::HlfirIntrinsicConversion;
472+
// A converter for hlfir.cshift and hlfir.eoshift.
473+
template <typename T>
474+
class ArrayShiftOpConversion : public HlfirIntrinsicConversion<T> {
475+
using HlfirIntrinsicConversion<T>::HlfirIntrinsicConversion;
476+
using HlfirIntrinsicConversion<T>::lowerArguments;
477+
using HlfirIntrinsicConversion<T>::processReturnValue;
478+
using typename HlfirIntrinsicConversion<T>::IntrinsicArgument;
474479

475480
llvm::LogicalResult
476-
matchAndRewrite(hlfir::CShiftOp cshift,
477-
mlir::PatternRewriter &rewriter) const override {
478-
fir::FirOpBuilder builder{rewriter, cshift.getOperation()};
479-
const mlir::Location &loc = cshift->getLoc();
481+
matchAndRewrite(T op, mlir::PatternRewriter &rewriter) const override {
482+
fir::FirOpBuilder builder{rewriter, op.getOperation()};
483+
const mlir::Location &loc = op->getLoc();
480484

481-
llvm::SmallVector<IntrinsicArgument, 3> inArgs;
482-
mlir::Value array = cshift.getArray();
485+
llvm::SmallVector<IntrinsicArgument, 4> inArgs;
486+
llvm::StringRef intrinsicName{[]() {
487+
if constexpr (std::is_same_v<T, hlfir::EOShiftOp>)
488+
return "eoshift";
489+
else if constexpr (std::is_same_v<T, hlfir::CShiftOp>)
490+
return "cshift";
491+
else
492+
llvm_unreachable("unsupported array shift");
493+
}()};
494+
495+
mlir::Value array = op.getArray();
483496
inArgs.push_back({array, array.getType()});
484-
mlir::Value shift = cshift.getShift();
497+
mlir::Value shift = op.getShift();
485498
inArgs.push_back({shift, shift.getType()});
486-
inArgs.push_back({cshift.getDim(), builder.getI32Type()});
499+
if constexpr (std::is_same_v<T, hlfir::EOShiftOp>) {
500+
mlir::Value boundary = op.getBoundary();
501+
inArgs.push_back({boundary, boundary ? boundary.getType() : nullptr});
502+
}
503+
inArgs.push_back({op.getDim(), builder.getI32Type()});
487504

488-
auto *argLowering = fir::getIntrinsicArgumentLowering("cshift");
505+
auto *argLowering = fir::getIntrinsicArgumentLowering(intrinsicName);
489506
llvm::SmallVector<fir::ExtendedValue, 3> args =
490-
lowerArguments(cshift, inArgs, rewriter, argLowering);
507+
lowerArguments(op, inArgs, rewriter, argLowering);
491508

492-
mlir::Type scalarResultType =
493-
hlfir::getFortranElementType(cshift.getType());
509+
mlir::Type scalarResultType = hlfir::getFortranElementType(op.getType());
494510

495-
auto [resultExv, mustBeFreed] =
496-
fir::genIntrinsicCall(builder, loc, "cshift", scalarResultType, args);
511+
auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(
512+
builder, loc, intrinsicName, scalarResultType, args);
497513

498-
processReturnValue(cshift, resultExv, mustBeFreed, builder, rewriter);
514+
processReturnValue(op, resultExv, mustBeFreed, builder, rewriter);
499515
return mlir::success();
500516
}
501517
};
@@ -547,7 +563,8 @@ class LowerHLFIRIntrinsics
547563
AnyOpConversion, SumOpConversion, ProductOpConversion,
548564
TransposeOpConversion, CountOpConversion, DotProductOpConversion,
549565
MaxvalOpConversion, MinvalOpConversion, MinlocOpConversion,
550-
MaxlocOpConversion, CShiftOpConversion, ReshapeOpConversion>(context);
566+
MaxlocOpConversion, ArrayShiftOpConversion<hlfir::CShiftOp>,
567+
ArrayShiftOpConversion<hlfir::EOShiftOp>, ReshapeOpConversion>(context);
551568

552569
// While conceptually this pass is performing dialect conversion, we use
553570
// pattern rewrites here instead of dialect conversion because this pass

0 commit comments

Comments
 (0)