1313#include " flang/Optimizer/Builder/Complex.h"
1414#include " flang/Optimizer/Builder/FIRBuilder.h"
1515#include " flang/Optimizer/Builder/HLFIRTools.h"
16+ #include " flang/Optimizer/Builder/IntrinsicCall.h"
1617#include " flang/Optimizer/Dialect/FIRDialect.h"
1718#include " flang/Optimizer/HLFIR/HLFIRDialect.h"
1819#include " flang/Optimizer/HLFIR/HLFIROps.h"
@@ -331,6 +332,107 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
331332 }
332333};
333334
335+ class CShiftAsElementalConversion
336+ : public mlir::OpRewritePattern<hlfir::CShiftOp> {
337+ public:
338+ using mlir::OpRewritePattern<hlfir::CShiftOp>::OpRewritePattern;
339+
340+ explicit CShiftAsElementalConversion (mlir::MLIRContext *ctx)
341+ : OpRewritePattern(ctx) {
342+ setHasBoundedRewriteRecursion ();
343+ }
344+
345+ llvm::LogicalResult
346+ matchAndRewrite (hlfir::CShiftOp cshift,
347+ mlir::PatternRewriter &rewriter) const override {
348+ using Fortran::common::maxRank;
349+
350+ mlir::Location loc = cshift.getLoc ();
351+ fir::FirOpBuilder builder{rewriter, cshift.getOperation ()};
352+ hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(cshift.getType ());
353+ assert (expr && " expected an expression type for the result of hlfir.sum" );
354+ mlir::Type elementType = expr.getElementType ();
355+ hlfir::Entity array = hlfir::Entity{cshift.getArray ()};
356+ mlir::Value arrayShape = hlfir::genShape (loc, builder, array);
357+ llvm::SmallVector<mlir::Value> arrayExtents =
358+ hlfir::getExplicitExtentsFromShape (arrayShape, builder);
359+ unsigned arrayRank = expr.getRank ();
360+ llvm::SmallVector<mlir::Value, 1 > typeParams;
361+ hlfir::genLengthParameters (loc, builder, array, typeParams);
362+ hlfir::Entity shift = hlfir::Entity{cshift.getShift ()};
363+ // The new index computation involves MODULO, which is not implemented
364+ // for IndexType, so use I64 instead.
365+ mlir::Type calcType = builder.getI64Type ();
366+
367+ mlir::Value one = builder.createIntegerConstant (loc, calcType, 1 );
368+ mlir::Value shiftVal;
369+ if (shift.isScalar ()) {
370+ shiftVal = hlfir::loadTrivialScalar (loc, builder, shift);
371+ shiftVal = builder.createConvert (loc, calcType, shiftVal);
372+ }
373+
374+ int64_t dimVal = 1 ;
375+ if (arrayRank == 1 ) {
376+ // When it is a 1D CSHIFT, we may assume that the DIM argument
377+ // (whether it is present or absent) is equal to 1, otherwise,
378+ // the program is illegal.
379+ assert (shiftVal && " SHIFT must be scalar" );
380+ } else {
381+ if (mlir::Value dim = cshift.getDim ())
382+ dimVal = fir::getIntIfConstant (dim).value_or (0 );
383+ assert (dimVal > 0 && dimVal <= arrayRank &&
384+ " DIM must be present and a positive constant not exceeding "
385+ " the array's rank" );
386+ }
387+
388+ auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
389+ mlir::ValueRange inputIndices) -> hlfir::Entity {
390+ llvm::SmallVector<mlir::Value, maxRank> indices{inputIndices};
391+ if (!shift.isScalar ()) {
392+ // When the array is not a vector, section
393+ // (s(1), s(2), ..., s(dim-1), :, s(dim+1), ..., s(n)
394+ // of the result has a value equal to:
395+ // CSHIFT(ARRAY(s(1), s(2), ..., s(dim-1), :, s(dim+1), ..., s(n)),
396+ // SH, 1),
397+ // where SH is either SHIFT (if scalar) or
398+ // SHIFT(s(1), s(2), ..., s(dim-1), s(dim+1), ..., s(n)).
399+ llvm::SmallVector<mlir::Value, maxRank> shiftIndices{indices};
400+ shiftIndices.erase (shiftIndices.begin () + dimVal - 1 );
401+ hlfir::Entity shiftElement =
402+ hlfir::getElementAt (loc, builder, shift, shiftIndices);
403+ shiftVal = hlfir::loadTrivialScalar (loc, builder, shiftElement);
404+ shiftVal = builder.createConvert (loc, calcType, shiftVal);
405+ }
406+
407+ // Element i of the result (1-based) is element
408+ // 'MODULO(i + SH - 1, SIZE(ARRAY)) + 1' (1-based) of the original
409+ // ARRAY (or its section, when ARRAY is not a vector).
410+ mlir::Value index =
411+ builder.createConvert (loc, calcType, inputIndices[dimVal - 1 ]);
412+ mlir::Value extent = arrayExtents[dimVal - 1 ];
413+ mlir::Value newIndex =
414+ builder.create <mlir::arith::AddIOp>(loc, index, shiftVal);
415+ newIndex = builder.create <mlir::arith::SubIOp>(loc, newIndex, one);
416+ newIndex = fir::IntrinsicLibrary{builder, loc}.genModulo (
417+ calcType, {newIndex, builder.createConvert (loc, calcType, extent)});
418+ newIndex = builder.create <mlir::arith::AddIOp>(loc, newIndex, one);
419+ newIndex = builder.createConvert (loc, builder.getIndexType (), newIndex);
420+
421+ indices[dimVal - 1 ] = newIndex;
422+ hlfir::Entity element = hlfir::getElementAt (loc, builder, array, indices);
423+ return hlfir::loadTrivialScalar (loc, builder, element);
424+ };
425+
426+ hlfir::ElementalOp elementalOp = hlfir::genElementalOp (
427+ loc, builder, elementType, arrayShape, typeParams, genKernel,
428+ /* isUnordered=*/ true ,
429+ array.isPolymorphic () ? static_cast <mlir::Value>(array) : nullptr ,
430+ cshift.getResult ().getType ());
431+ rewriter.replaceOp (cshift, elementalOp);
432+ return mlir::success ();
433+ }
434+ };
435+
334436class SimplifyHLFIRIntrinsics
335437 : public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
336438public:
@@ -339,6 +441,7 @@ class SimplifyHLFIRIntrinsics
339441 mlir::RewritePatternSet patterns (context);
340442 patterns.insert <TransposeAsElementalConversion>(context);
341443 patterns.insert <SumAsElementalConversion>(context);
444+ patterns.insert <CShiftAsElementalConversion>(context);
342445 mlir::ConversionTarget target (*context);
343446 // don't transform transpose of polymorphic arrays (not currently supported
344447 // by hlfir.elemental)
@@ -375,6 +478,24 @@ class SimplifyHLFIRIntrinsics
375478 }
376479 return true ;
377480 });
481+ target.addDynamicallyLegalOp <hlfir::CShiftOp>([](hlfir::CShiftOp cshift) {
482+ unsigned resultRank = hlfir::Entity{cshift}.getRank ();
483+ if (resultRank == 1 )
484+ return false ;
485+
486+ mlir::Value dim = cshift.getDim ();
487+ if (!dim)
488+ return false ;
489+
490+ // If DIM is present, then it must be constant to please
491+ // the conversion. In addition, ignore cases with
492+ // illegal DIM values.
493+ if (auto dimVal = fir::getIntIfConstant (dim))
494+ if (*dimVal > 0 && *dimVal <= resultRank)
495+ return false ;
496+
497+ return true ;
498+ });
378499 target.markUnknownOpDynamicallyLegal (
379500 [](mlir::Operation *) { return true ; });
380501 if (mlir::failed (mlir::applyFullConversion (getOperation (), target,
0 commit comments