Skip to content

Conversation

@vzakhari
Copy link
Contributor

This patch inlines hlfir.reshape for simple cases, such as
when there is no ORDER argument; and when PAD is present,
only the trivial types are handled.

This patch inlines hlfir.reshape for simple cases, such as
when there is no ORDER argument; and when PAD is present,
only the trivial types are handled.
@vzakhari vzakhari requested review from jeanPerier and tblah January 28, 2025 03:21
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Jan 28, 2025
@llvmbot
Copy link
Member

llvmbot commented Jan 28, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Slava Zakharin (vzakhari)

Changes

This patch inlines hlfir.reshape for simple cases, such as
when there is no ORDER argument; and when PAD is present,
only the trivial types are handled.


Patch is 29.36 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/124683.diff

2 Files Affected:

  • (modified) flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp (+208)
  • (added) flang/test/HLFIR/simplify-hlfir-intrinsics-reshape.fir (+216)
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index fe7ae0eeed3cc3..35071361fa16b8 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -951,6 +951,213 @@ class DotProductConversion
   }
 };
 
+class ReshapeAsElementalConversion
+    : public mlir::OpRewritePattern<hlfir::ReshapeOp> {
+public:
+  using mlir::OpRewritePattern<hlfir::ReshapeOp>::OpRewritePattern;
+
+  llvm::LogicalResult
+  matchAndRewrite(hlfir::ReshapeOp reshape,
+                  mlir::PatternRewriter &rewriter) const override {
+    // Do not inline RESHAPE with ORDER yet. The runtime implementation
+    // may be good enough, unless the temporary creation overhead
+    // is high.
+    // TODO: If ORDER is constant, then we can still easily inline.
+    // TODO: If the result's rank is 1, then we can assume ORDER == (/1/).
+    if (reshape.getOrder())
+      return rewriter.notifyMatchFailure(reshape,
+                                         "RESHAPE with ORDER argument");
+
+    // Verify that the element types of ARRAY, PAD and the result
+    // match before doing any transformations.
+    hlfir::Entity result = hlfir::Entity{reshape};
+    hlfir::Entity array = hlfir::Entity{reshape.getArray()};
+    mlir::Type elementType = array.getFortranElementType();
+    if (result.getFortranElementType() != elementType)
+      return rewriter.notifyMatchFailure(
+          reshape, "ARRAY and result have different types");
+    mlir::Value pad = reshape.getPad();
+    if (pad && hlfir::getFortranElementType(pad.getType()) != elementType)
+      return rewriter.notifyMatchFailure(reshape,
+                                         "ARRAY and PAD have different types");
+
+    // TODO: selecting between ARRAY and PAD of non-trivial element types
+    // requires more work. We have to select between two references
+    // to elements in ARRAY and PAD. This requires conditional
+    // bufferization of the element, if ARRAY/PAD is an expression.
+    if (pad && !fir::isa_trivial(elementType))
+      return rewriter.notifyMatchFailure(reshape,
+                                         "PAD present with non-trivial type");
+
+    mlir::Location loc = reshape.getLoc();
+    fir::FirOpBuilder builder{rewriter, reshape.getOperation()};
+    // Assume that all the indices arithmetic does not overflow
+    // the IndexType.
+    builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nuw);
+
+    llvm::SmallVector<mlir::Value, 1> typeParams;
+    hlfir::genLengthParameters(loc, builder, array, typeParams);
+
+    // Fetch the extents of ARRAY, PAD and result beforehand.
+    llvm::SmallVector<mlir::Value, Fortran::common::maxRank> arrayExtents =
+        hlfir::genExtentsVector(loc, builder, array);
+
+    mlir::Value arraySize, padSize;
+    llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padExtents;
+    if (pad) {
+      // If PAD is present, we have to use array size to start taking
+      // elements from the PAD array.
+      arraySize = computeArraySize(loc, builder, arrayExtents);
+
+      padExtents = hlfir::genExtentsVector(loc, builder, hlfir::Entity{pad});
+      // PAD size is needed to wrap around the linear index addressing
+      // the PAD array.
+      padSize = computeArraySize(loc, builder, padExtents);
+    }
+    hlfir::Entity shape = hlfir::Entity{reshape.getShape()};
+    llvm::SmallVector<mlir::Value, Fortran::common::maxRank> resultExtents;
+    mlir::Type indexType = builder.getIndexType();
+    for (int idx = 0; idx < result.getRank(); ++idx)
+      resultExtents.push_back(hlfir::loadElementAt(
+          loc, builder, shape,
+          builder.createIntegerConstant(loc, indexType, idx + 1)));
+    auto resultShape = builder.create<fir::ShapeOp>(loc, resultExtents);
+
+    auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+                         mlir::ValueRange inputIndices) -> hlfir::Entity {
+      mlir::Value linearIndex =
+          computeLinearIndex(loc, builder, resultExtents, inputIndices);
+      fir::IfOp ifOp;
+      if (pad) {
+        // PAD is present. Check if this element comes from the PAD array.
+        mlir::Value isInsideArray = builder.create<mlir::arith::CmpIOp>(
+            loc, mlir::arith::CmpIPredicate::ult, linearIndex, arraySize);
+        ifOp = builder.create<fir::IfOp>(loc, elementType, isInsideArray,
+                                         /*withElseRegion=*/true);
+
+        // In the 'else' block, return an element from the PAD.
+        builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+        // Subtract the ARRAY size from the zero-based linear index
+        // to get the zero-based linear index into PAD.
+        mlir::Value padLinearIndex =
+            builder.create<mlir::arith::SubIOp>(loc, linearIndex, arraySize);
+        // PAD wraps around, when additional elements are needed.
+        padLinearIndex =
+            builder.create<mlir::arith::RemUIOp>(loc, padLinearIndex, padSize);
+        llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padIndices =
+            delinearizeIndex(loc, builder, padExtents, padLinearIndex);
+        mlir::Value padElement =
+            hlfir::loadElementAt(loc, builder, hlfir::Entity{pad}, padIndices);
+        builder.create<fir::ResultOp>(loc, padElement);
+
+        // In the 'then' block, return an element from the ARRAY.
+        builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+      }
+
+      llvm::SmallVector<mlir::Value, Fortran::common::maxRank> arrayIndices =
+          delinearizeIndex(loc, builder, arrayExtents, linearIndex);
+      mlir::Value arrayElement =
+          hlfir::loadElementAt(loc, builder, array, arrayIndices);
+
+      if (ifOp) {
+        builder.create<fir::ResultOp>(loc, arrayElement);
+        builder.setInsertionPointAfter(ifOp);
+        arrayElement = ifOp.getResult(0);
+      }
+
+      return hlfir::Entity{arrayElement};
+    };
+    hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
+        loc, builder, elementType, resultShape, typeParams, genKernel,
+        /*isUnordered=*/true,
+        /*polymorphicMold=*/result.isPolymorphic() ? array : mlir::Value{},
+        reshape.getResult().getType());
+    assert(elementalOp.getResult().getType() == reshape.getResult().getType());
+    rewriter.replaceOp(reshape, elementalOp);
+    return mlir::success();
+  }
+
+private:
+  /// Compute zero-based linear index given an array extents
+  /// and one-based indices:
+  ///   \p extents: [e0, e1, ..., en]
+  ///   \p indices: [i0, i1, ..., in]
+  ///
+  /// linear-index :=
+  ///   (...((in-1)*e(n-1)+(i(n-1)-1))*e(n-2)+...)*e0+(i0-1)
+  static mlir::Value computeLinearIndex(mlir::Location loc,
+                                        fir::FirOpBuilder &builder,
+                                        mlir::ValueRange extents,
+                                        mlir::ValueRange indices) {
+    std::size_t rank = extents.size();
+    assert(rank = indices.size());
+    mlir::Type indexType = builder.getIndexType();
+    mlir::Value zero = builder.createIntegerConstant(loc, indexType, 0);
+    mlir::Value one = builder.createIntegerConstant(loc, indexType, 1);
+    mlir::Value linearIndex = zero;
+    for (auto idx : llvm::enumerate(llvm::reverse(indices))) {
+      mlir::Value tmp = builder.create<mlir::arith::SubIOp>(
+          loc, builder.createConvert(loc, indexType, idx.value()), one);
+      tmp = builder.create<mlir::arith::AddIOp>(loc, linearIndex, tmp);
+      if (idx.index() + 1 < rank)
+        tmp = builder.create<mlir::arith::MulIOp>(
+            loc, tmp,
+            builder.createConvert(loc, indexType,
+                                  extents[rank - idx.index() - 2]));
+
+      linearIndex = tmp;
+    }
+    return linearIndex;
+  }
+
+  /// Compute one-based array indices from the given zero-based \p linearIndex
+  /// and the array \p extents [e0, e1, ..., en].
+  ///   i0 := linearIndex % e0 + 1
+  ///   linearIndex := linearIndex / e0
+  ///   i1 := linearIndex % e1 + 1
+  ///   linearIndex := linearIndex / e1
+  ///   ...
+  ///   i(n-1) := linearIndex % e(n-1) + 1
+  ///   linearIndex := linearIndex / e(n-1)
+  ///   in := linearIndex + 1
+  static llvm::SmallVector<mlir::Value, Fortran::common::maxRank>
+  delinearizeIndex(mlir::Location loc, fir::FirOpBuilder &builder,
+                   mlir::ValueRange extents, mlir::Value linearIndex) {
+    llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
+    mlir::Type indexType = builder.getIndexType();
+    mlir::Value one = builder.createIntegerConstant(loc, indexType, 1);
+    linearIndex = builder.createConvert(loc, indexType, linearIndex);
+
+    for (std::size_t dim = 0; dim < extents.size(); ++dim) {
+      mlir::Value currentIndex;
+      if (dim == extents.size() - 1) {
+        currentIndex = linearIndex;
+      } else {
+        mlir::Value extent =
+            builder.createConvert(loc, indexType, extents[dim]);
+        currentIndex =
+            builder.create<mlir::arith::RemUIOp>(loc, linearIndex, extent);
+        linearIndex =
+            builder.create<mlir::arith::DivUIOp>(loc, linearIndex, extent);
+      }
+      indices.push_back(
+          builder.create<mlir::arith::AddIOp>(loc, currentIndex, one));
+    }
+    return indices;
+  }
+
+  static mlir::Value computeArraySize(mlir::Location loc,
+                                      fir::FirOpBuilder &builder,
+                                      mlir::ValueRange extents) {
+    mlir::Type indexType = builder.getIndexType();
+    mlir::Value size = builder.createIntegerConstant(loc, indexType, 1);
+    for (auto extent : extents)
+      size = builder.create<mlir::arith::MulIOp>(
+          loc, size, builder.createConvert(loc, indexType, extent));
+    return size;
+  }
+};
+
 class SimplifyHLFIRIntrinsics
     : public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
 public:
@@ -987,6 +1194,7 @@ class SimplifyHLFIRIntrinsics
       patterns.insert<MatmulConversion<hlfir::MatmulOp>>(context);
 
     patterns.insert<DotProductConversion>(context);
+    patterns.insert<ReshapeAsElementalConversion>(context);
 
     if (mlir::failed(mlir::applyPatternsGreedily(
             getOperation(), std::move(patterns), config))) {
diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-reshape.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-reshape.fir
new file mode 100644
index 00000000000000..ad8093335556c0
--- /dev/null
+++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-reshape.fir
@@ -0,0 +1,216 @@
+// Test hlfir.reshape simplification to hlfir.elemental:
+// RUN: fir-opt --simplify-hlfir-intrinsics %s | FileCheck %s
+
+func.func @reshape_simple(%arg0: !fir.box<!fir.array<?xf32>>, %arg1: !fir.ref<!fir.array<1xi32>>) -> !hlfir.expr<?xf32> {
+  %res = hlfir.reshape %arg0 %arg1 : (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<1xi32>>) -> !hlfir.expr<?xf32>
+  return %res : !hlfir.expr<?xf32>
+}
+// CHECK-LABEL:   func.func @reshape_simple(
+// CHECK-SAME:                              %[[VAL_0:.*]]: !fir.box<!fir.array<?xf32>>,
+// CHECK-SAME:                              %[[VAL_1:.*]]: !fir.ref<!fir.array<1xi32>>) -> !hlfir.expr<?xf32> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_2]])  : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
+// CHECK:           %[[VAL_5:.*]] = fir.load %[[VAL_4]] : !fir.ref<i32>
+// CHECK:           %[[VAL_6:.*]] = fir.shape %[[VAL_5]] : (i32) -> !fir.shape<1>
+// CHECK:           %[[VAL_7:.*]] = hlfir.elemental %[[VAL_6]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+// CHECK:           ^bb0(%[[VAL_8:.*]]: index):
+// CHECK:             %[[VAL_9:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_3]] : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
+// CHECK:             %[[VAL_10:.*]] = arith.subi %[[VAL_9]]#0, %[[VAL_2]] overflow<nuw> : index
+// CHECK:             %[[VAL_11:.*]] = arith.addi %[[VAL_8]], %[[VAL_10]] overflow<nuw> : index
+// CHECK:             %[[VAL_12:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_11]])  : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
+// CHECK:             %[[VAL_13:.*]] = fir.load %[[VAL_12]] : !fir.ref<f32>
+// CHECK:             hlfir.yield_element %[[VAL_13]] : f32
+// CHECK:           }
+// CHECK:           return %[[VAL_7]] : !hlfir.expr<?xf32>
+// CHECK:         }
+
+func.func @reshape_with_pad(%arg0: !fir.box<!fir.array<?x?x?xf32>>, %arg1: !fir.ref<!fir.array<2xi32>>, %arg2: !fir.box<!fir.array<?x?x?xf32>>) -> !hlfir.expr<?x?xf32> {
+  %res = hlfir.reshape %arg0 %arg1 pad %arg2 : (!fir.box<!fir.array<?x?x?xf32>>, !fir.ref<!fir.array<2xi32>>, !fir.box<!fir.array<?x?x?xf32>>) -> !hlfir.expr<?x?xf32>
+  return %res : !hlfir.expr<?x?xf32>
+}
+// CHECK-LABEL:   func.func @reshape_with_pad(
+// CHECK-SAME:                                %[[VAL_0:.*]]: !fir.box<!fir.array<?x?x?xf32>>,
+// CHECK-SAME:                                %[[VAL_1:.*]]: !fir.ref<!fir.array<2xi32>>,
+// CHECK-SAME:                                %[[VAL_2:.*]]: !fir.box<!fir.array<?x?x?xf32>>) -> !hlfir.expr<?x?xf32> {
+// CHECK:           %[[VAL_3:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK:           %[[ARRAY_DIM0:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_5]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
+// CHECK:           %[[ARRAY_DIM1:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_4]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
+// CHECK:           %[[ARRAY_DIM2:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_3]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
+// CHECK:           %[[VAL_9:.*]] = arith.muli %[[ARRAY_DIM0]]#1, %[[ARRAY_DIM1]]#1 overflow<nuw> : index
+// CHECK:           %[[ARRAY_SIZE:.*]] = arith.muli %[[VAL_9]], %[[ARRAY_DIM2]]#1 overflow<nuw> : index
+// CHECK:           %[[PAD_DIM0:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_5]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
+// CHECK:           %[[PAD_DIM1:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_4]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
+// CHECK:           %[[PAD_DIM2:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_3]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
+// CHECK:           %[[VAL_14:.*]] = arith.muli %[[PAD_DIM0]]#1, %[[PAD_DIM1]]#1 overflow<nuw> : index
+// CHECK:           %[[PAD_SIZE:.*]] = arith.muli %[[VAL_14]], %[[PAD_DIM2]]#1 overflow<nuw> : index
+// CHECK:           %[[VAL_16:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_4]])  : (!fir.ref<!fir.array<2xi32>>, index) -> !fir.ref<i32>
+// CHECK:           %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
+// CHECK:           %[[VAL_18:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_3]])  : (!fir.ref<!fir.array<2xi32>>, index) -> !fir.ref<i32>
+// CHECK:           %[[VAL_19:.*]] = fir.load %[[VAL_18]] : !fir.ref<i32>
+// CHECK:           %[[VAL_20:.*]] = fir.shape %[[VAL_17]], %[[VAL_19]] : (i32, i32) -> !fir.shape<2>
+// CHECK:           %[[VAL_21:.*]] = hlfir.elemental %[[VAL_20]] unordered : (!fir.shape<2>) -> !hlfir.expr<?x?xf32> {
+// CHECK:           ^bb0(%[[VAL_22:.*]]: index, %[[VAL_23:.*]]: index):
+// CHECK:             %[[VAL_24:.*]] = arith.subi %[[VAL_23]], %[[VAL_4]] overflow<nuw> : index
+// CHECK:             %[[VAL_25:.*]] = fir.convert %[[VAL_17]] : (i32) -> index
+// CHECK:             %[[VAL_26:.*]] = arith.muli %[[VAL_24]], %[[VAL_25]] overflow<nuw> : index
+// CHECK:             %[[VAL_27:.*]] = arith.subi %[[VAL_22]], %[[VAL_4]] overflow<nuw> : index
+// CHECK:             %[[LINEAR_INDEX:.*]] = arith.addi %[[VAL_26]], %[[VAL_27]] overflow<nuw> : index
+// CHECK:             %[[IS_WITHIN_ARRAY:.*]] = arith.cmpi ult, %[[LINEAR_INDEX]], %[[ARRAY_SIZE]] : index
+// CHECK:             %[[VAL_30:.*]] = fir.if %[[IS_WITHIN_ARRAY]] -> (f32) {
+// CHECK:               %[[VAL_31:.*]] = arith.remui %[[LINEAR_INDEX]], %[[ARRAY_DIM0]]#1 : index
+// CHECK:               %[[VAL_32:.*]] = arith.divui %[[LINEAR_INDEX]], %[[ARRAY_DIM0]]#1 : index
+// CHECK:               %[[ARRAY_IDX0:.*]] = arith.addi %[[VAL_31]], %[[VAL_4]] overflow<nuw> : index
+// CHECK:               %[[VAL_34:.*]] = arith.remui %[[VAL_32]], %[[ARRAY_DIM1]]#1 : index
+// CHECK:               %[[VAL_35:.*]] = arith.divui %[[VAL_32]], %[[ARRAY_DIM1]]#1 : index
+// CHECK:               %[[ARRAY_IDX1:.*]] = arith.addi %[[VAL_34]], %[[VAL_4]] overflow<nuw> : index
+// CHECK:               %[[ARRAY_IDX2:.*]] = arith.addi %[[VAL_35]], %[[VAL_4]] overflow<nuw> : index
+// CHECK:               %[[VAL_38:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_5]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
+// CHECK:               %[[VAL_39:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_4]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
+// CHECK:               %[[VAL_40:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_3]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
+// CHECK:               %[[VAL_41:.*]] = arith.subi %[[VAL_38]]#0, %[[VAL_4]] overflow<nuw> : index
+// CHECK:               %[[VAL_42:.*]] = arith.addi %[[ARRAY_IDX0]], %[[VAL_41]] overflow<nuw> : index
+// CHECK:               %[[VAL_43:.*]] = arith.subi %[[VAL_39]]#0, %[[VAL_4]] overflow<nuw> : index
+// CHECK:               %[[VAL_44:.*]] = arith.addi %[[ARRAY_IDX1]], %[[VAL_43]] overflow<nuw> : index
+// CHECK:               %[[VAL_45:.*]] = arith.subi %[[VAL_40]]#0, %[[VAL_4]] overflow<nuw> : index
+// CHECK:               %[[VAL_46:.*]] = arith.addi %[[ARRAY_IDX2]], %[[VAL_45]] overflow<nuw> : index
+// CHECK:               %[[VAL_47:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_42]], %[[VAL_44]], %[[VAL_46]])  : (!fir.box<!fir.array<?x?x?xf32>>, index, index, index) -> !fir.ref<f32>
+// CHECK:               %[[VAL_48:.*]] = fir.load %[[VAL_47]] : !fir.ref<f32>
+// CHECK:               fir.result %[[VAL_48]] : f32
+// CHECK:             } else {
+// CHECK:               %[[PAD_LINEAR_INDEX:.*]] = arith.subi %[[LINEAR_INDEX]], %[[ARRAY_SIZE]] overflow<nuw> : index
+// CHECK:               %[[PAD_LINEAR_INDEX_MOD:.*]] = arith.remui %[[PAD_LINEAR_INDEX]], %[[PAD_SIZE]] : index
+// CHECK:               %[[VAL_51:.*]] = arith.remui %[[PAD_LINEAR_INDEX_MOD]], %[[PAD_DIM0]]#1 : index
+// CHECK:               %[[VAL_52:.*]] = arith.divui %[[PAD_LINEAR_INDEX_MOD]], %[[PAD_DIM0]]#1 : index
+// CHECK:               %[[PAD_IDX0:.*]] = arith.addi %[[VAL_51]], %[[VAL_4]] overflow<nuw> : index
+// CHECK:               %[[VAL_54:.*]] = arith.remui %[[VAL_52]], %[[PAD_DIM1]]#1 : index
+// CHECK:               %[[VAL_55:.*]] = arith.divui %[[VAL_52]], %[[PAD_DIM1]]#1 : index
+// CHECK:               %[[PAD_IDX1:.*]] = arith.addi %[[VAL_54]], %[[VAL_4]] overflow<nuw> : index
+// CHECK:               %[[PAD_IDX2:.*]] = arith.addi %[[VAL_55]], %[[VAL_4]] overflow<nuw> : index
+// CHECK:               %[[VAL_58:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_5]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
+// CHECK:               %[[VAL_59:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_4]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
+// CHECK:               %[[VAL_60:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_3]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
+// CHECK:               %[[VAL_61:.*]] = arith.subi %[[VAL_58]]#0, %[[VAL_4]] overflow<nuw> : index
+// CHECK:               %[[VAL_62:.*]] = arith.addi %[[PAD_IDX0]], %[[VAL_61]] overflow<nuw> : index
+// CHECK:               %[[VAL_63:.*]] = arith.subi %[[VAL_59]]#0, %[[VAL_4]] overflow<nuw> : index
+// CHECK:               %[[VAL_64:.*]] = arith.addi %[[PAD_IDX1]], %[[VAL_63]] overflow<nuw> : index
+// CHECK:               %[[VAL_65:.*]] = arith.subi %[[VAL_60]]#0, %[[VAL_4]] overflow<nuw> : index
+// CHECK:               %[[VAL_66:.*]] = arith.addi %[[PAD_IDX2]], %[[VAL_65]] overflow<nuw> : index
+// CHECK:               %[[VAL_67:.*]] = hlfir.designate %[[VAL_2]] (%[[VAL_62]], %[[...
[truncated]

Copy link
Contributor

@jeanPerier jeanPerier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Slava, the logic looks good. One point about PAD dynamic optionality may need to be taken into account.

if (result.getFortranElementType() != elementType)
return rewriter.notifyMatchFailure(
reshape, "ARRAY and result have different types");
mlir::Value pad = reshape.getPad();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PAD is dynamically optional. If its actual argument is an OPTIONAL/POINTER/ALLOCATABLE, its presence should be checked at runtime.

You probably need to do something about that here (or at least to detect and do not do the transformation for now).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I moved the reads from PAD under the check of whether we have to read from it or not.

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with Jean's comments.

Just some minor comments from me.

I don't see any clear regressions in spec2017 on aarch64. The results from my machine were a bit noisy today but I can rule out anything serious. I'll update if our post-commit CI finds anything.

@vzakhari vzakhari merged commit 6160a67 into llvm:main Jan 30, 2025
8 checks passed
mlir::ValueRange extents,
mlir::ValueRange indices) {
std::size_t rank = extents.size();
assert(rank = indices.size());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Should be ==? GCC was giving a warning.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! == indeed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed by 381416a

vzakhari added a commit to vzakhari/llvm-project that referenced this pull request Jan 31, 2025
`llvm::enumerate(llvm::reverse(ValueRange))` added in llvm#124683 does not work
on Windows: https://lab.llvm.org/buildbot/#/builders/124/builds/322
omjavaid pushed a commit that referenced this pull request Feb 3, 2025
`llvm::enumerate(llvm::reverse(ValueRange))` added in #124683 does not work
on Windows: https://lab.llvm.org/buildbot/#/builders/124/builds/322
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
`llvm::enumerate(llvm::reverse(ValueRange))` added in llvm#124683 does not work
on Windows: https://lab.llvm.org/buildbot/#/builders/124/builds/322
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

flang:fir-hlfir flang Flang issues not falling into any other category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants