Skip to content

Commit 15ede6a

Browse files
committed
[flang] Inline hlfir.cshift as hlfir.elemental.
1 parent e0f3410 commit 15ede6a

File tree

2 files changed

+425
-0
lines changed

2 files changed

+425
-0
lines changed

flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "flang/Optimizer/Builder/Complex.h"
1414
#include "flang/Optimizer/Builder/FIRBuilder.h"
1515
#include "flang/Optimizer/Builder/HLFIRTools.h"
16+
#include "flang/Optimizer/Builder/IntrinsicCall.h"
1617
#include "flang/Optimizer/Dialect/FIRDialect.h"
1718
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
1819
#include "flang/Optimizer/HLFIR/HLFIROps.h"
@@ -331,6 +332,107 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
331332
}
332333
};
333334

335+
class CShiftAsElementalConversion
336+
: public mlir::OpRewritePattern<hlfir::CShiftOp> {
337+
public:
338+
using mlir::OpRewritePattern<hlfir::CShiftOp>::OpRewritePattern;
339+
340+
explicit CShiftAsElementalConversion(mlir::MLIRContext *ctx)
341+
: OpRewritePattern(ctx) {
342+
setHasBoundedRewriteRecursion();
343+
}
344+
345+
llvm::LogicalResult
346+
matchAndRewrite(hlfir::CShiftOp cshift,
347+
mlir::PatternRewriter &rewriter) const override {
348+
using Fortran::common::maxRank;
349+
350+
mlir::Location loc = cshift.getLoc();
351+
fir::FirOpBuilder builder{rewriter, cshift.getOperation()};
352+
hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(cshift.getType());
353+
assert(expr && "expected an expression type for the result of hlfir.sum");
354+
mlir::Type elementType = expr.getElementType();
355+
hlfir::Entity array = hlfir::Entity{cshift.getArray()};
356+
mlir::Value arrayShape = hlfir::genShape(loc, builder, array);
357+
llvm::SmallVector<mlir::Value> arrayExtents =
358+
hlfir::getExplicitExtentsFromShape(arrayShape, builder);
359+
unsigned arrayRank = expr.getRank();
360+
llvm::SmallVector<mlir::Value, 1> typeParams;
361+
hlfir::genLengthParameters(loc, builder, array, typeParams);
362+
hlfir::Entity shift = hlfir::Entity{cshift.getShift()};
363+
// The new index computation involves MODULO, which is not implemented
364+
// for IndexType, so use I64 instead.
365+
mlir::Type calcType = builder.getI64Type();
366+
367+
mlir::Value one = builder.createIntegerConstant(loc, calcType, 1);
368+
mlir::Value shiftVal;
369+
if (shift.isScalar()) {
370+
shiftVal = hlfir::loadTrivialScalar(loc, builder, shift);
371+
shiftVal = builder.createConvert(loc, calcType, shiftVal);
372+
}
373+
374+
int64_t dimVal = 1;
375+
if (arrayRank == 1) {
376+
// When it is a 1D CSHIFT, we may assume that the DIM argument
377+
// (whether it is present or absent) is equal to 1, otherwise,
378+
// the program is illegal.
379+
assert(shiftVal && "SHIFT must be scalar");
380+
} else {
381+
if (mlir::Value dim = cshift.getDim())
382+
dimVal = fir::getIntIfConstant(dim).value_or(0);
383+
assert(dimVal > 0 && dimVal <= arrayRank &&
384+
"DIM must be present and a positive constant not exceeding "
385+
"the array's rank");
386+
}
387+
388+
auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
389+
mlir::ValueRange inputIndices) -> hlfir::Entity {
390+
llvm::SmallVector<mlir::Value, maxRank> indices{inputIndices};
391+
if (!shift.isScalar()) {
392+
// When the array is not a vector, section
393+
// (s(1), s(2), ..., s(dim-1), :, s(dim+1), ..., s(n)
394+
// of the result has a value equal to:
395+
// CSHIFT(ARRAY(s(1), s(2), ..., s(dim-1), :, s(dim+1), ..., s(n)),
396+
// SH, 1),
397+
// where SH is either SHIFT (if scalar) or
398+
// SHIFT(s(1), s(2), ..., s(dim-1), s(dim+1), ..., s(n)).
399+
llvm::SmallVector<mlir::Value, maxRank> shiftIndices{indices};
400+
shiftIndices.erase(shiftIndices.begin() + dimVal - 1);
401+
hlfir::Entity shiftElement =
402+
hlfir::getElementAt(loc, builder, shift, shiftIndices);
403+
shiftVal = hlfir::loadTrivialScalar(loc, builder, shiftElement);
404+
shiftVal = builder.createConvert(loc, calcType, shiftVal);
405+
}
406+
407+
// Element i of the result (1-based) is element
408+
// 'MODULO(i + SH - 1, SIZE(ARRAY)) + 1' (1-based) of the original
409+
// ARRAY (or its section, when ARRAY is not a vector).
410+
mlir::Value index =
411+
builder.createConvert(loc, calcType, inputIndices[dimVal - 1]);
412+
mlir::Value extent = arrayExtents[dimVal - 1];
413+
mlir::Value newIndex =
414+
builder.create<mlir::arith::AddIOp>(loc, index, shiftVal);
415+
newIndex = builder.create<mlir::arith::SubIOp>(loc, newIndex, one);
416+
newIndex = fir::IntrinsicLibrary{builder, loc}.genModulo(
417+
calcType, {newIndex, builder.createConvert(loc, calcType, extent)});
418+
newIndex = builder.create<mlir::arith::AddIOp>(loc, newIndex, one);
419+
newIndex = builder.createConvert(loc, builder.getIndexType(), newIndex);
420+
421+
indices[dimVal - 1] = newIndex;
422+
hlfir::Entity element = hlfir::getElementAt(loc, builder, array, indices);
423+
return hlfir::loadTrivialScalar(loc, builder, element);
424+
};
425+
426+
hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
427+
loc, builder, elementType, arrayShape, typeParams, genKernel,
428+
/*isUnordered=*/true,
429+
array.isPolymorphic() ? static_cast<mlir::Value>(array) : nullptr,
430+
cshift.getResult().getType());
431+
rewriter.replaceOp(cshift, elementalOp);
432+
return mlir::success();
433+
}
434+
};
435+
334436
class SimplifyHLFIRIntrinsics
335437
: public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
336438
public:
@@ -339,6 +441,7 @@ class SimplifyHLFIRIntrinsics
339441
mlir::RewritePatternSet patterns(context);
340442
patterns.insert<TransposeAsElementalConversion>(context);
341443
patterns.insert<SumAsElementalConversion>(context);
444+
patterns.insert<CShiftAsElementalConversion>(context);
342445
mlir::ConversionTarget target(*context);
343446
// don't transform transpose of polymorphic arrays (not currently supported
344447
// by hlfir.elemental)
@@ -375,6 +478,24 @@ class SimplifyHLFIRIntrinsics
375478
}
376479
return true;
377480
});
481+
target.addDynamicallyLegalOp<hlfir::CShiftOp>([](hlfir::CShiftOp cshift) {
482+
unsigned resultRank = hlfir::Entity{cshift}.getRank();
483+
if (resultRank == 1)
484+
return false;
485+
486+
mlir::Value dim = cshift.getDim();
487+
if (!dim)
488+
return false;
489+
490+
// If DIM is present, then it must be constant to please
491+
// the conversion. In addition, ignore cases with
492+
// illegal DIM values.
493+
if (auto dimVal = fir::getIntIfConstant(dim))
494+
if (*dimVal > 0 && *dimVal <= resultRank)
495+
return false;
496+
497+
return true;
498+
});
378499
target.markUnknownOpDynamicallyLegal(
379500
[](mlir::Operation *) { return true; });
380501
if (mlir::failed(mlir::applyFullConversion(getOperation(), target,

0 commit comments

Comments
 (0)