1313#include " flang/Optimizer/Builder/Complex.h"
1414#include " flang/Optimizer/Builder/FIRBuilder.h"
1515#include " flang/Optimizer/Builder/HLFIRTools.h"
16+ #include " flang/Optimizer/Builder/IntrinsicCall.h"
1617#include " flang/Optimizer/Dialect/FIRDialect.h"
1718#include " flang/Optimizer/HLFIR/HLFIRDialect.h"
1819#include " flang/Optimizer/HLFIR/HLFIROps.h"
@@ -331,6 +332,108 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
331332 }
332333};
333334
335+ class CShiftAsElementalConversion
336+ : public mlir::OpRewritePattern<hlfir::CShiftOp> {
337+ public:
338+ using mlir::OpRewritePattern<hlfir::CShiftOp>::OpRewritePattern;
339+
340+ explicit CShiftAsElementalConversion (mlir::MLIRContext *ctx)
341+ : OpRewritePattern(ctx) {
342+ setHasBoundedRewriteRecursion ();
343+ }
344+
345+ llvm::LogicalResult
346+ matchAndRewrite (hlfir::CShiftOp cshift,
347+ mlir::PatternRewriter &rewriter) const override {
348+ using Fortran::common::maxRank;
349+
350+ mlir::Location loc = cshift.getLoc ();
351+ fir::FirOpBuilder builder{rewriter, cshift.getOperation ()};
352+ hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(cshift.getType ());
353+ assert (expr &&
354+ " expected an expression type for the result of hlfir.cshift" );
355+ mlir::Type elementType = expr.getElementType ();
356+ hlfir::Entity array = hlfir::Entity{cshift.getArray ()};
357+ mlir::Value arrayShape = hlfir::genShape (loc, builder, array);
358+ llvm::SmallVector<mlir::Value> arrayExtents =
359+ hlfir::getExplicitExtentsFromShape (arrayShape, builder);
360+ unsigned arrayRank = expr.getRank ();
361+ llvm::SmallVector<mlir::Value, 1 > typeParams;
362+ hlfir::genLengthParameters (loc, builder, array, typeParams);
363+ hlfir::Entity shift = hlfir::Entity{cshift.getShift ()};
364+ // The new index computation involves MODULO, which is not implemented
365+ // for IndexType, so use I64 instead.
366+ mlir::Type calcType = builder.getI64Type ();
367+
368+ mlir::Value one = builder.createIntegerConstant (loc, calcType, 1 );
369+ mlir::Value shiftVal;
370+ if (shift.isScalar ()) {
371+ shiftVal = hlfir::loadTrivialScalar (loc, builder, shift);
372+ shiftVal = builder.createConvert (loc, calcType, shiftVal);
373+ }
374+
375+ int64_t dimVal = 1 ;
376+ if (arrayRank == 1 ) {
377+ // When it is a 1D CSHIFT, we may assume that the DIM argument
378+ // (whether it is present or absent) is equal to 1, otherwise,
379+ // the program is illegal.
380+ assert (shiftVal && " SHIFT must be scalar" );
381+ } else {
382+ if (mlir::Value dim = cshift.getDim ())
383+ dimVal = fir::getIntIfConstant (dim).value_or (0 );
384+ assert (dimVal > 0 && dimVal <= arrayRank &&
385+ " DIM must be present and a positive constant not exceeding "
386+ " the array's rank" );
387+ }
388+
389+ auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
390+ mlir::ValueRange inputIndices) -> hlfir::Entity {
391+ llvm::SmallVector<mlir::Value, maxRank> indices{inputIndices};
392+ if (!shift.isScalar ()) {
393+ // When the array is not a vector, section
394+ // (s(1), s(2), ..., s(dim-1), :, s(dim+1), ..., s(n)
395+ // of the result has a value equal to:
396+ // CSHIFT(ARRAY(s(1), s(2), ..., s(dim-1), :, s(dim+1), ..., s(n)),
397+ // SH, 1),
398+ // where SH is either SHIFT (if scalar) or
399+ // SHIFT(s(1), s(2), ..., s(dim-1), s(dim+1), ..., s(n)).
400+ llvm::SmallVector<mlir::Value, maxRank> shiftIndices{indices};
401+ shiftIndices.erase (shiftIndices.begin () + dimVal - 1 );
402+ hlfir::Entity shiftElement =
403+ hlfir::getElementAt (loc, builder, shift, shiftIndices);
404+ shiftVal = hlfir::loadTrivialScalar (loc, builder, shiftElement);
405+ shiftVal = builder.createConvert (loc, calcType, shiftVal);
406+ }
407+
408+ // Element i of the result (1-based) is element
409+ // 'MODULO(i + SH - 1, SIZE(ARRAY)) + 1' (1-based) of the original
410+ // ARRAY (or its section, when ARRAY is not a vector).
411+ mlir::Value index =
412+ builder.createConvert (loc, calcType, inputIndices[dimVal - 1 ]);
413+ mlir::Value extent = arrayExtents[dimVal - 1 ];
414+ mlir::Value newIndex =
415+ builder.create <mlir::arith::AddIOp>(loc, index, shiftVal);
416+ newIndex = builder.create <mlir::arith::SubIOp>(loc, newIndex, one);
417+ newIndex = fir::IntrinsicLibrary{builder, loc}.genModulo (
418+ calcType, {newIndex, builder.createConvert (loc, calcType, extent)});
419+ newIndex = builder.create <mlir::arith::AddIOp>(loc, newIndex, one);
420+ newIndex = builder.createConvert (loc, builder.getIndexType (), newIndex);
421+
422+ indices[dimVal - 1 ] = newIndex;
423+ hlfir::Entity element = hlfir::getElementAt (loc, builder, array, indices);
424+ return hlfir::loadTrivialScalar (loc, builder, element);
425+ };
426+
427+ hlfir::ElementalOp elementalOp = hlfir::genElementalOp (
428+ loc, builder, elementType, arrayShape, typeParams, genKernel,
429+ /* isUnordered=*/ true ,
430+ array.isPolymorphic () ? static_cast <mlir::Value>(array) : nullptr ,
431+ cshift.getResult ().getType ());
432+ rewriter.replaceOp (cshift, elementalOp);
433+ return mlir::success ();
434+ }
435+ };
436+
334437class SimplifyHLFIRIntrinsics
335438 : public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
336439public:
@@ -339,6 +442,7 @@ class SimplifyHLFIRIntrinsics
339442 mlir::RewritePatternSet patterns (context);
340443 patterns.insert <TransposeAsElementalConversion>(context);
341444 patterns.insert <SumAsElementalConversion>(context);
445+ patterns.insert <CShiftAsElementalConversion>(context);
342446 mlir::ConversionTarget target (*context);
343447 // don't transform transpose of polymorphic arrays (not currently supported
344448 // by hlfir.elemental)
@@ -375,6 +479,24 @@ class SimplifyHLFIRIntrinsics
375479 }
376480 return true ;
377481 });
482+ target.addDynamicallyLegalOp <hlfir::CShiftOp>([](hlfir::CShiftOp cshift) {
483+ unsigned resultRank = hlfir::Entity{cshift}.getRank ();
484+ if (resultRank == 1 )
485+ return false ;
486+
487+ mlir::Value dim = cshift.getDim ();
488+ if (!dim)
489+ return false ;
490+
491+ // If DIM is present, then it must be constant to please
492+ // the conversion. In addition, ignore cases with
493+ // illegal DIM values.
494+ if (auto dimVal = fir::getIntIfConstant (dim))
495+ if (*dimVal > 0 && *dimVal <= resultRank)
496+ return false ;
497+
498+ return true ;
499+ });
378500 target.markUnknownOpDynamicallyLegal (
379501 [](mlir::Operation *) { return true ; });
380502 if (mlir::failed (mlir::applyFullConversion (getOperation (), target,
0 commit comments