-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[flang] Support non-index shape/shift/slice for CG box operations. #124625
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
That is another problem uncovered during hlfir.reshape inlining, where the shape bits could be any integer type. This patch adds explicit conversions to `index` type where needed.
|
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-flang-codegen Author: Slava Zakharin (vzakhari) Changes: That is another problem uncovered during hlfir.reshape inlining, where the shape bits could be any integer type. Full diff: https://github.com/llvm/llvm-project/pull/124625.diff 2 Files Affected:
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 6ff2c20d744537..17f76faa9b0025 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -1675,22 +1675,26 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
this->attachTBAATag(storeOp, boxTy, boxTy, nullptr);
return storage;
}
-};
-/// Compute the extent of a triplet slice (lb:ub:step).
-static mlir::Value
-computeTripletExtent(mlir::ConversionPatternRewriter &rewriter,
- mlir::Location loc, mlir::Value lb, mlir::Value ub,
- mlir::Value step, mlir::Value zero, mlir::Type type) {
- mlir::Value extent = rewriter.create<mlir::LLVM::SubOp>(loc, type, ub, lb);
- extent = rewriter.create<mlir::LLVM::AddOp>(loc, type, extent, step);
- extent = rewriter.create<mlir::LLVM::SDivOp>(loc, type, extent, step);
- // If the resulting extent is negative (`ub-lb` and `step` have different
- // signs), zero must be returned instead.
- auto cmp = rewriter.create<mlir::LLVM::ICmpOp>(
- loc, mlir::LLVM::ICmpPredicate::sgt, extent, zero);
- return rewriter.create<mlir::LLVM::SelectOp>(loc, cmp, extent, zero);
-}
+ /// Compute the extent of a triplet slice (lb:ub:step).
+ mlir::Value computeTripletExtent(mlir::ConversionPatternRewriter &rewriter,
+ mlir::Location loc, mlir::Value lb,
+ mlir::Value ub, mlir::Value step,
+ mlir::Value zero, mlir::Type type) const {
+ lb = this->integerCast(loc, rewriter, type, lb);
+ ub = this->integerCast(loc, rewriter, type, ub);
+ step = this->integerCast(loc, rewriter, type, step);
+ zero = this->integerCast(loc, rewriter, type, zero);
+ mlir::Value extent = rewriter.create<mlir::LLVM::SubOp>(loc, type, ub, lb);
+ extent = rewriter.create<mlir::LLVM::AddOp>(loc, type, extent, step);
+ extent = rewriter.create<mlir::LLVM::SDivOp>(loc, type, extent, step);
+ // If the resulting extent is negative (`ub-lb` and `step` have different
+ // signs), zero must be returned instead.
+ auto cmp = rewriter.create<mlir::LLVM::ICmpOp>(
+ loc, mlir::LLVM::ICmpPredicate::sgt, extent, zero);
+ return rewriter.create<mlir::LLVM::SelectOp>(loc, cmp, extent, zero);
+ }
+};
/// Create a generic box on a memory reference. This conversions lowers the
/// abstract box to the appropriate, initialized descriptor.
@@ -1851,14 +1855,16 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
// translating everything to values in the descriptor wherever the entity
// has a dynamic array dimension.
for (unsigned di = 0, descIdx = 0; di < rank; ++di) {
- mlir::Value extent = operands[shapeOffset];
+ mlir::Value extent =
+ integerCast(loc, rewriter, i64Ty, operands[shapeOffset]);
mlir::Value outerExtent = extent;
bool skipNext = false;
if (hasSlice) {
- mlir::Value off = operands[sliceOffset];
+ mlir::Value off =
+ integerCast(loc, rewriter, i64Ty, operands[sliceOffset]);
mlir::Value adj = one;
if (hasShift)
- adj = operands[shiftOffset];
+ adj = integerCast(loc, rewriter, i64Ty, operands[shiftOffset]);
auto ao = rewriter.create<mlir::LLVM::SubOp>(loc, i64Ty, off, adj);
if (constRows > 0) {
cstInteriorIndices.push_back(ao);
@@ -1895,7 +1901,7 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
// the lower bound.
if (hasShift && !(hasSlice || hasSubcomp || hasSubstr) &&
(isaPointerOrAllocatable || !normalizedLowerBound(xbox))) {
- lb = operands[shiftOffset];
+ lb = integerCast(loc, rewriter, i64Ty, operands[shiftOffset]);
auto extentIsEmpty = rewriter.create<mlir::LLVM::ICmpOp>(
loc, mlir::LLVM::ICmpPredicate::eq, extent, zero);
lb = rewriter.create<mlir::LLVM::SelectOp>(loc, extentIsEmpty, one,
@@ -1907,9 +1913,12 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
// store step (scaled by shaped extent)
mlir::Value step = prevDimByteStride;
- if (hasSlice)
- step = rewriter.create<mlir::LLVM::MulOp>(loc, i64Ty, step,
- operands[sliceOffset + 2]);
+ if (hasSlice) {
+ mlir::Value sliceStep =
+ integerCast(loc, rewriter, i64Ty, operands[sliceOffset + 2]);
+ step =
+ rewriter.create<mlir::LLVM::MulOp>(loc, i64Ty, step, sliceStep);
+ }
dest = insertStride(rewriter, loc, dest, descIdx, step);
++descIdx;
}
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index 4c9f965e1241a0..6d7a4a09918e5a 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -1909,6 +1909,67 @@ func.func @xembox0(%arg0: !fir.ref<!fir.array<?xi32>>) {
// CHECK: %[[BOX10:.*]] = llvm.insertvalue %[[BASE_PTR]], %[[BOX9]][0] : !llvm.struct<(ptr, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, array<1 x array<3 x i64>>)>
// CHECK: llvm.store %[[BOX10]], %[[ALLOCA]] : !llvm.struct<(ptr, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, array<1 x array<3 x i64>>)>, !llvm.ptr
+// Test i32 shape/shift/slice:
+func.func @xembox0_i32(%arg0: !fir.ref<!fir.array<?xi32>>) {
+ %c0 = arith.constant 0 : i32
+ %c0_i64 = arith.constant 0 : i64
+ %0 = fircg.ext_embox %arg0(%c0) origin %c0[%c0, %c0, %c0] : (!fir.ref<!fir.array<?xi32>>, i32, i32, i32, i32, i32) -> !fir.box<!fir.array<?xi32>>
+ return
+}
+// CHECK-LABEL: llvm.func @xembox0_i32(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr
+// CHECK: %[[ALLOCA_SIZE:.*]] = llvm.mlir.constant(1 : i32) : i32
+// GENERIC: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_SIZE]] x !llvm.struct<(ptr, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, array<1 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+// AMDGPU: %[[AA:.*]] = llvm.alloca %[[ALLOCA_SIZE]] x !llvm.struct<(ptr, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, array<1 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr<5>
+// AMDGPU: %[[ALLOCA:.*]] = llvm.addrspacecast %[[AA]] : !llvm.ptr<5> to !llvm.ptr
+// CHECK: %[[C0_I32:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[TYPE:.*]] = llvm.mlir.constant(9 : i32) : i32
+// CHECK: %[[NULL:.*]] = llvm.mlir.zero : !llvm.ptr
+// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[NULL]][1]
+// CHECK: %[[ELEM_LEN_I64:.*]] = llvm.ptrtoint %[[GEP]] : !llvm.ptr to i64
+// CHECK: %[[BOX0:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, array<1 x array<3 x i64>>)>
+// CHECK: %[[BOX1:.*]] = llvm.insertvalue %[[ELEM_LEN_I64]], %[[BOX0]][1] : !llvm.struct<(ptr, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, array<1 x array<3 x i64>>)>
+// CHECK: %[[VERSION:.*]] = llvm.mlir.constant(20240719 : i32) : i32
+// CHECK: %[[BOX2:.*]] = llvm.insertvalue %[[VERSION]], %[[BOX1]][2] : !llvm.struct<(ptr, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, array<1 x array<3 x i64>>)>
+// CHECK: %[[RANK:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[RANK_I8:.*]] = llvm.trunc %[[RANK]] : i32 to i8
+// CHECK: %[[BOX3:.*]] = llvm.insertvalue %[[RANK_I8]], %[[BOX2]][3] : !llvm.struct<(ptr, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, array<1 x array<3 x i64>>)>
+// CHECK: %[[TYPE_I8:.*]] = llvm.trunc %[[TYPE]] : i32 to i8
+// CHECK: %[[BOX4:.*]] = llvm.insertvalue %[[TYPE_I8]], %[[BOX3]][4] : !llvm.struct<(ptr, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, array<1 x array<3 x i64>>)>
+// CHECK: %[[ATTR:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[ATTR_I8:.*]] = llvm.trunc %[[ATTR]] : i32 to i8
+// CHECK: %[[BOX5:.*]] = llvm.insertvalue %[[ATTR_I8]], %[[BOX4]][5] : !llvm.struct<(ptr, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, array<1 x array<3 x i64>>)>
+// CHECK: %[[F18ADDENDUM:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[F18ADDENDUM_I8:.*]] = llvm.trunc %[[F18ADDENDUM]] : i32 to i8
+// CHECK: %[[BOX6:.*]] = llvm.insertvalue %[[F18ADDENDUM_I8]], %[[BOX5]][6] : !llvm.struct<(ptr, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, array<1 x array<3 x i64>>)>
+// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[C0_1:.*]] = llvm.sext %[[C0_I32]] : i32 to i64
+// CHECK: %[[C0_2:.*]] = llvm.sext %[[C0_I32]] : i32 to i64
+// CHECK: %[[C0_3:.*]] = llvm.sext %[[C0_I32]] : i32 to i64
+// CHECK: %[[ADJUSTED_OFFSET:.*]] = llvm.sub %[[C0_2]], %[[C0_3]] : i64
+// CHECK: %[[DIM_OFFSET:.*]] = llvm.mul %[[ADJUSTED_OFFSET]], %[[ONE]] : i64
+// CHECK: %[[PTR_OFFSET:.*]] = llvm.add %[[DIM_OFFSET]], %[[ZERO]] : i64
+// CHECK: %[[C0_4:.*]] = llvm.sext %[[C0_I32]] : i32 to i64
+// CHECK: %[[C0_5:.*]] = llvm.sext %[[C0_I32]] : i32 to i64
+// CHECK: %[[C0_6:.*]] = llvm.sext %[[C0_I32]] : i32 to i64
+// CHECK: %[[EXTENT0:.*]] = llvm.sub %[[C0_5]], %[[C0_4]] : i64
+// CHECK: %[[EXTENT1:.*]] = llvm.add %[[EXTENT0]], %[[C0_6]] : i64
+// CHECK: %[[EXTENT2:.*]] = llvm.sdiv %[[EXTENT1]], %[[C0_6]] : i64
+// CHECK: %[[EXTENT_CMP:.*]] = llvm.icmp "sgt" %[[EXTENT2]], %[[ZERO]] : i64
+// CHECK: %[[EXTENT:.*]] = llvm.select %[[EXTENT_CMP]], %[[EXTENT2]], %[[ZERO]] : i1, i64
+// CHECK: %[[BOX7:.*]] = llvm.insertvalue %[[ONE]], %[[BOX6]][7, 0, 0] : !llvm.struct<(ptr, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, array<1 x array<3 x i64>>)>
+// CHECK: %[[BOX8:.*]] = llvm.insertvalue %[[EXTENT]], %[[BOX7]][7, 0, 1] : !llvm.struct<(ptr, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, array<1 x array<3 x i64>>)>
+// CHECK: %[[C0_7:.*]] = llvm.sext %[[C0_I32]] : i32 to i64
+// CHECK: %[[STRIDE:.*]] = llvm.mul %[[ELEM_LEN_I64]], %[[C0_7]] : i64
+// CHECK: %[[BOX9:.*]] = llvm.insertvalue %[[STRIDE]], %[[BOX8]][7, 0, 2] : !llvm.struct<(ptr, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, array<1 x array<3 x i64>>)>
+// CHECK: %[[PREV_DIM:.*]] = llvm.mul %[[ELEM_LEN_I64]], %[[C0_1]] : i64
+// CHECK: %[[PREV_PTROFF:.*]] = llvm.mul %[[ONE]], %[[C0_1]] : i64
+// CHECK: %[[BASE_PTR:.*]] = llvm.getelementptr %[[ARG0]][%[[PTR_OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+// CHECK: %[[BOX10:.*]] = llvm.insertvalue %[[BASE_PTR]], %[[BOX9]][0] : !llvm.struct<(ptr, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, array<1 x array<3 x i64>>)>
+// CHECK: llvm.store %[[BOX10]], %[[ALLOCA]] : !llvm.struct<(ptr, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, array<1 x array<3 x i64>>)>, !llvm.ptr
+
// Check adjustment of element scaling factor.
func.func @xembox1(%arg0: !fir.ref<!fir.array<?x!fir.char<1, 10>>>) {
|
clementval
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
jeanPerier
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks
tblah
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A minor nit, otherwise LGTM
| mlir::Location loc, mlir::Value lb, | ||
| mlir::Value ub, mlir::Value step, | ||
| mlir::Value zero, mlir::Type type) const { | ||
| lb = this->integerCast(loc, rewriter, type, lb); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why are the this-> pointers required?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
integerCast is protected within ConvertFIRToLLVMPattern, so either this-> or FIROpConversion:: should be used. I chose the former :)
That is another problem uncovered during hlfir.reshape inlining,
where the shape bits could be any integer type.
This patch adds explicit conversions to
`index` type where needed.