|
18 | 18 | #include "flang/Optimizer/HLFIR/HLFIRDialect.h" |
19 | 19 | #include "flang/Optimizer/HLFIR/HLFIROps.h" |
20 | 20 | #include "flang/Optimizer/HLFIR/Passes.h" |
21 | | -#include "flang/Optimizer/Support/DataLayout.h" |
22 | 21 | #include "mlir/Dialect/Arith/IR/Arith.h" |
23 | 22 | #include "mlir/IR/Location.h" |
24 | 23 | #include "mlir/Pass/Pass.h" |
@@ -457,8 +456,7 @@ class CShiftConversion : public mlir::OpRewritePattern<hlfir::CShiftOp> { |
457 | 456 | // representation. |
458 | 457 | hlfir::Entity array = hlfir::Entity{cshift.getArray()}; |
459 | 458 | mlir::Type elementType = array.getFortranElementType(); |
460 | | - if (dimVal == 1 && fir::isa_trivial(elementType) && |
461 | | - !array.isSimplyContiguous()) |
| 459 | + if (dimVal == 1 && fir::isa_trivial(elementType)) |
462 | 460 | rewriter.replaceOp(cshift, genInMemCShift(rewriter, cshift, dimVal)); |
463 | 461 | else |
464 | 462 | rewriter.replaceOp(cshift, genElementalCShift(rewriter, cshift, dimVal)); |
@@ -759,30 +757,18 @@ class CShiftConversion : public mlir::OpRewritePattern<hlfir::CShiftOp> { |
759 | 757 | mlir::Value elemSize; |
760 | 758 | mlir::Value stride; |
761 | 759 | mlir::Type elementType = array.getFortranElementType(); |
762 | | - if (dimVal == 1 && mlir::isa<fir::BaseBoxType>(array.getType()) && |
763 | | - fir::isa_trivial(elementType)) { |
764 | | - mlir::ModuleOp module = cshift->getParentOfType<mlir::ModuleOp>(); |
765 | | - std::optional<mlir::DataLayout> dl = |
766 | | - fir::support::getMLIRDataLayout(module); |
767 | | - if (dl) { |
768 | | - fir::KindMapping kindMap = fir::getKindMapping(module); |
769 | | - auto [size, align] = fir::getTypeSizeAndAlignmentOrCrash( |
770 | | - loc, elementType, *dl, kindMap); |
771 | | - size = llvm::alignTo(size, align); |
772 | | - if (size) { |
773 | | - mlir::Type indexType = builder.getIndexType(); |
774 | | - elemSize = builder.createIntegerConstant(loc, indexType, size); |
775 | | - |
776 | | - mlir::Value dimIdx = |
777 | | - builder.createIntegerConstant(loc, indexType, dimVal - 1); |
778 | | - auto boxDim = builder.create<fir::BoxDimsOp>( |
779 | | - loc, indexType, indexType, indexType, array.getBase(), dimIdx); |
780 | | - stride = boxDim.getByteStride(); |
781 | | - } |
782 | | - } |
| 760 | + if (dimVal == 1 && mlir::isa<fir::BaseBoxType>(array.getType())) { |
| 761 | + mlir::Type indexType = builder.getIndexType(); |
| 762 | + elemSize = |
| 763 | + builder.create<fir::BoxEleSizeOp>(loc, indexType, array.getBase()); |
| 764 | + mlir::Value dimIdx = |
| 765 | + builder.createIntegerConstant(loc, indexType, dimVal - 1); |
| 766 | + auto boxDim = builder.create<fir::BoxDimsOp>( |
| 767 | + loc, indexType, indexType, indexType, array.getBase(), dimIdx); |
| 768 | + stride = boxDim.getByteStride(); |
783 | 769 | } |
784 | 770 |
|
785 | | - if (!elemSize || !stride) { |
| 771 | + if (array.isSimplyContiguous() || !elemSize || !stride) { |
786 | 772 | genDimensionShift(loc, builder, shiftVal, /*exposeContiguity=*/false, |
787 | 773 | oneBasedIndices); |
788 | 774 | return {}; |
|
0 commit comments