[mlir][LLVMIR] Add folder pass for llvm.inttoptr and llvm.ptrtoint
#143066
This PR is a follow-up to llvm#141891. It introduces a pass that can fold `inttoptr(ptrtoint(x)) -> x` and `ptrtoint(inttoptr(x)) -> x`. The pass takes in a list of address space bitwidths and makes sure that the folding is applied only when it's safe.
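The address-space bitwidths are passed positionally through the `address-space-bitwidths` option. Entry *i* is taken as the pointer bitwidth of address space *i*, so `address-space-bitwidths=64,32` describes 64-bit pointers in address space 0 and 32-bit pointers in address space 1. In the implementation (`getIntOrPtrBW`), an address space past the end of the list, or one whose entry is 0, counts as unknown and blocks the fold. Below is a minimal standalone C++ sketch of that lookup. `parseAddrSpaceBWs` and `lookupPtrBW` are hypothetical helpers, not MLIR APIs, and the comma splitting only approximates how MLIR's `ListOption` parses its value.

```cpp
#include <cassert>
#include <optional>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper: split an option value such as "64,32" into a table
// where element i is assumed to hold the pointer bitwidth of address space i.
std::vector<unsigned> parseAddrSpaceBWs(const std::string &option) {
  std::vector<unsigned> bws;
  std::stringstream ss(option);
  std::string item;
  while (std::getline(ss, item, ','))
    if (!item.empty())
      bws.push_back(static_cast<unsigned>(std::stoul(item)));
  return bws;
}

// Mirrors the pointer branch of `getIntOrPtrBW`: an address space past the
// end of the table, or one whose entry is 0, has no known bitwidth.
std::optional<unsigned> lookupPtrBW(const std::vector<unsigned> &bws,
                                    unsigned addrSpace) {
  if (addrSpace < bws.size() && bws[addrSpace] != 0)
    return bws[addrSpace];
  return std::nullopt;
}
```

With this table, `lookupPtrBW(parseAddrSpaceBWs("64,32"), 1)` yields 32, while address space 2, or any address space when the option is left empty, yields no value, so no fold happens.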
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm
Author: Diego Caballero (dcaballe)
Full diff: https://github.com/llvm/llvm-project/pull/143066.diff
5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.h
new file mode 100644
index 0000000000000..62a9c84241a39
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.h
@@ -0,0 +1,43 @@
+//===- IntToPtrPtrToIntFolding.h - IntToPtr/PtrToInt folding ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares a pass that folds inttoptr/ptrtoint operation sequences.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_INTTOPTRPTRTOINTFOLDING_H
+#define MLIR_DIALECT_LLVMIR_TRANSFORMS_INTTOPTRPTRTOINTFOLDING_H
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+class RewritePatternSet;
+
+namespace LLVM {
+
+#define GEN_PASS_DECL_FOLDINTTOPTRPTRTOINTPASS
+#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc"
+
+/// Populate patterns that fold inttoptr/ptrtoint op sequences such as:
+///
+/// * `inttoptr(ptrtoint(x))` -> `x`
+/// * `ptrtoint(inttoptr(x))` -> `x`
+///
+/// `addressSpaceBWs` contains the pointer bitwidth for each address space. If
+/// the pointer bitwidth information is not available for a specific address
+/// space, the folding for that address space is not performed.
+///
+/// TODO: Support DLTI.
+void populateIntToPtrPtrToIntFoldingPatterns(
+    RewritePatternSet &patterns, ArrayRef<unsigned> addressSpaceBWs);
+
+} // namespace LLVM
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_INTTOPTRPTRTOINTFOLDING_H
\ No newline at end of file
diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
index 961909d5c8d27..be45213e6b95e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
@@ -73,4 +73,23 @@ def DIScopeForLLVMFuncOpPass : Pass<"ensure-debug-info-scope-on-llvm-func", "::m
];
}
+def FoldIntToPtrPtrToIntPass : Pass<"fold-llvm-inttoptr-ptrtoint", "LLVM::LLVMFuncOp"> {
+  let summary = "Fold inttoptr/ptrtoint operation sequences";
+  let description = [{
+    This pass folds sequences of inttoptr and ptrtoint operations that cancel
+    each other out. Specifically:
+      * inttoptr(ptrtoint(x)) -> x
+      * ptrtoint(inttoptr(x)) -> x
+
+    The pass takes a sequence of address space bitwidths to make sure folding
+    is safe. If the bitwidth information is not available for an address space,
+    the pass will not fold any operations involving that address space.
+  }];
+  let dependentDialects = ["LLVM::LLVMDialect"];
+  let options = [
+    ListOption<"addrSpaceBWs", "address-space-bitwidths", "unsigned",
+               "List of address space bitwidths sorted by associated index to each address space.">
+  ];
+}
+
#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
index d4ff0955c5d0e..b22280718c454 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRLLVMIRTransforms
DIExpressionRewriter.cpp
DIScopeForLLVMFuncOp.cpp
InlinerInterfaceImpl.cpp
+ IntToPtrPtrToIntFolding.cpp
LegalizeForExport.cpp
OptimizeForNVVM.cpp
RequestCWrappers.cpp
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.cpp
new file mode 100644
index 0000000000000..c87a9cb2afe41
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.cpp
@@ -0,0 +1,133 @@
+//===- IntToPtrPtrToIntFolding.cpp - IntToPtr/PtrToInt folding ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that folds inttoptr/ptrtoint operation sequences.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.h"
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "fold-llvm-inttoptr-ptrtoint"
+
+namespace mlir {
+namespace LLVM {
+
+#define GEN_PASS_DEF_FOLDINTTOPTRPTRTOINTPASS
+#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc"
+
+} // namespace LLVM
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Return the bitwidth of a pointer or integer type. If the type is a pointer,
+/// return the bitwidth of the address space from `addrSpaceBWs`, if available.
+/// Return failure if the address space bitwidth is not available.
+static FailureOr<unsigned> getIntOrPtrBW(Type type,
+                                         ArrayRef<unsigned> addrSpaceBWs) {
+  if (auto ptrType = dyn_cast<LLVM::LLVMPointerType>(type)) {
+    unsigned addrSpace = ptrType.getAddressSpace();
+    if (addrSpace < addrSpaceBWs.size() && addrSpaceBWs[addrSpace] != 0)
+      return addrSpaceBWs[addrSpace];
+    return failure();
+  }
+
+  auto integerType = cast<IntegerType>(type);
+  return integerType.getWidth();
+}
+
+/// Check if folding inttoptr/ptrtoint is valid. Check that the original type
+/// matches the result type of the end-to-end conversion and that the input
+/// value is not truncated along the conversion chain.
+static LogicalResult canFoldIntToPtrPtrToInt(Type originalType,
+                                             Type intermediateType,
+                                             Type resultType,
+                                             ArrayRef<unsigned> addrSpaceBWs) {
+  // Check if the original type matches the result type.
+  // TODO: Support address space conversions?
+  // TODO: Support int trunc/ext?
+  if (originalType != resultType)
+    return failure();
+
+  // Make sure there is no data truncation with respect to the original type at
+  // any point during the conversion. Truncating the intermediate data is fine
+  // as long as the original data is not truncated.
+  auto originalBW = getIntOrPtrBW(originalType, addrSpaceBWs);
+  if (failed(originalBW))
+    return failure();
+
+  auto intermediateBW = getIntOrPtrBW(intermediateType, addrSpaceBWs);
+  if (failed(intermediateBW))
+    return failure();
+
+  if (*originalBW > *intermediateBW)
+    return failure();
+  return success();
+}
+
+/// Folds inttoptr(ptrtoint(x)) -> x or ptrtoint(inttoptr(x)) -> x.
+template <typename SrcConvOp, typename DstConvOp>
+class FoldIntToPtrPtrToInt : public OpRewritePattern<DstConvOp> {
+public:
+  FoldIntToPtrPtrToInt(MLIRContext *context, ArrayRef<unsigned> addrSpaceBWs)
+      : OpRewritePattern<DstConvOp>(context), addrSpaceBWs(addrSpaceBWs) {}
+
+  LogicalResult matchAndRewrite(DstConvOp dstConvOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcConvOp = dstConvOp.getArg().template getDefiningOp<SrcConvOp>();
+    if (!srcConvOp)
+      return failure();
+
+    // Check if folding is valid based on type matching and bitwidth
+    // information.
+    if (failed(canFoldIntToPtrPtrToInt(srcConvOp.getArg().getType(),
+                                       srcConvOp.getType(), dstConvOp.getType(),
+                                       addrSpaceBWs))) {
+      return failure();
+    }
+
+    rewriter.replaceOp(dstConvOp, srcConvOp.getArg());
+    return success();
+  }
+
+private:
+  SmallVector<unsigned> addrSpaceBWs;
+};
+
+/// Pass that folds inttoptr/ptrtoint operation sequences.
+struct FoldIntToPtrPtrToIntPass
+    : public LLVM::impl::FoldIntToPtrPtrToIntPassBase<
+          FoldIntToPtrPtrToIntPass> {
+  using Base =
+      LLVM::impl::FoldIntToPtrPtrToIntPassBase<FoldIntToPtrPtrToIntPass>;
+  using Base::FoldIntToPtrPtrToIntPassBase;
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    LLVM::populateIntToPtrPtrToIntFoldingPatterns(patterns, addrSpaceBWs);
+
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+void mlir::LLVM::populateIntToPtrPtrToIntFoldingPatterns(
+    RewritePatternSet &patterns, ArrayRef<unsigned> addrSpaceBWs) {
+  patterns.add<FoldIntToPtrPtrToInt<LLVM::PtrToIntOp, LLVM::IntToPtrOp>,
+               FoldIntToPtrPtrToInt<LLVM::IntToPtrOp, LLVM::PtrToIntOp>>(
+      patterns.getContext(), addrSpaceBWs);
+}
diff --git a/mlir/test/Dialect/LLVMIR/inttoptr-ptrtoint-folding.mlir b/mlir/test/Dialect/LLVMIR/inttoptr-ptrtoint-folding.mlir
new file mode 100644
index 0000000000000..30193eff93be2
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/inttoptr-ptrtoint-folding.mlir
@@ -0,0 +1,100 @@
+// RUN: mlir-opt -pass-pipeline="builtin.module(llvm.func(fold-llvm-inttoptr-ptrtoint{address-space-bitwidths=64}))" %s | FileCheck %s --check-prefixes=CHECK-64BIT,CHECK-ALL
+// RUN: mlir-opt -pass-pipeline="builtin.module(llvm.func(fold-llvm-inttoptr-ptrtoint{address-space-bitwidths=32}))" %s | FileCheck %s --check-prefixes=CHECK-32BIT,CHECK-ALL
+// RUN: mlir-opt -pass-pipeline="builtin.module(llvm.func(fold-llvm-inttoptr-ptrtoint{address-space-bitwidths=64,32}))" %s | FileCheck %s --check-prefixes=CHECK-MULTI-ADDRSPACE,CHECK-ALL
+// RUN: mlir-opt -pass-pipeline="builtin.module(llvm.func(fold-llvm-inttoptr-ptrtoint))" %s | FileCheck %s --check-prefixes=CHECK-DISABLED,CHECK-ALL
+
+
+// CHECK-ALL-LABEL: @test_inttoptr_ptrtoint_fold_64bit
+// CHECK-ALL-SAME: (%[[ARG:.+]]: !llvm.ptr)
+llvm.func @test_inttoptr_ptrtoint_fold_64bit(%arg0: !llvm.ptr) -> !llvm.ptr {
+ // CHECK-64BIT-NOT: llvm.ptrtoint
+ // CHECK-64BIT-NOT: llvm.inttoptr
+ // CHECK-64BIT: llvm.return %[[ARG]]
+
+ // CHECK-32BIT-NOT: llvm.ptrtoint
+ // CHECK-32BIT-NOT: llvm.inttoptr
+ // CHECK-32BIT: llvm.return %[[ARG]]
+
+ // CHECK-MULTI-ADDRSPACE-NOT: llvm.ptrtoint
+ // CHECK-MULTI-ADDRSPACE-NOT: llvm.inttoptr
+ // CHECK-MULTI-ADDRSPACE: llvm.return %[[ARG]]
+
+ // CHECK-DISABLED: %[[INT:.+]] = llvm.ptrtoint %[[ARG]]
+ // CHECK-DISABLED: %[[PTR:.+]] = llvm.inttoptr %[[INT]]
+ // CHECK-DISABLED: llvm.return %[[PTR]]
+
+ %0 = llvm.ptrtoint %arg0 : !llvm.ptr to i64
+ %1 = llvm.inttoptr %0 : i64 to !llvm.ptr
+ llvm.return %1 : !llvm.ptr
+}
+
+// CHECK-ALL-LABEL: @test_ptrtoint_inttoptr_fold_64bit
+// CHECK-ALL-SAME: (%[[ARG:.+]]: i64)
+llvm.func @test_ptrtoint_inttoptr_fold_64bit(%arg0: i64) -> i64 {
+ // CHECK-64BIT-NOT: llvm.inttoptr
+ // CHECK-64BIT-NOT: llvm.ptrtoint
+ // CHECK-64BIT: llvm.return %[[ARG]]
+
+ // CHECK-32BIT: %[[INT:.+]] = llvm.inttoptr %[[ARG]]
+ // CHECK-32BIT: %[[PTR:.+]] = llvm.ptrtoint %[[INT]]
+ // CHECK-32BIT: llvm.return %[[PTR]]
+
+ // CHECK-MULTI-ADDRSPACE-NOT: llvm.inttoptr
+ // CHECK-MULTI-ADDRSPACE-NOT: llvm.ptrtoint
+ // CHECK-MULTI-ADDRSPACE: llvm.return %[[ARG]]
+
+ // CHECK-DISABLED: %[[INT:.+]] = llvm.inttoptr %[[ARG]]
+ // CHECK-DISABLED: %[[PTR:.+]] = llvm.ptrtoint %[[INT]]
+ // CHECK-DISABLED: llvm.return %[[PTR]]
+
+ %0 = llvm.inttoptr %arg0 : i64 to !llvm.ptr
+ %1 = llvm.ptrtoint %0 : !llvm.ptr to i64
+ llvm.return %1 : i64
+}
+
+// CHECK-ALL-LABEL: @test_inttoptr_ptrtoint_fold_addrspace1_32bit
+// CHECK-ALL-SAME: (%[[ARG:.+]]: !llvm.ptr<1>)
+llvm.func @test_inttoptr_ptrtoint_fold_addrspace1_32bit(%arg0: !llvm.ptr<1>) -> !llvm.ptr<1> {
+ // CHECK-64BIT: %[[INT:.+]] = llvm.ptrtoint %[[ARG]]
+ // CHECK-64BIT: %[[PTR:.+]] = llvm.inttoptr %[[INT]]
+ // CHECK-64BIT: llvm.return %[[PTR]]
+
+ // CHECK-32BIT: %[[INT:.+]] = llvm.ptrtoint %[[ARG]]
+ // CHECK-32BIT: %[[PTR:.+]] = llvm.inttoptr %[[INT]]
+ // CHECK-32BIT: llvm.return %[[PTR]]
+
+ // CHECK-MULTI-ADDRSPACE-NOT: llvm.ptrtoint
+ // CHECK-MULTI-ADDRSPACE-NOT: llvm.inttoptr
+ // CHECK-MULTI-ADDRSPACE: llvm.return %[[ARG]]
+
+ // CHECK-DISABLED: %[[INT:.+]] = llvm.ptrtoint %[[ARG]]
+ // CHECK-DISABLED: %[[PTR:.+]] = llvm.inttoptr %[[INT]]
+ // CHECK-DISABLED: llvm.return %[[PTR]]
+
+ %0 = llvm.ptrtoint %arg0 : !llvm.ptr<1> to i32
+ %1 = llvm.inttoptr %0 : i32 to !llvm.ptr<1>
+ llvm.return %1 : !llvm.ptr<1>
+}
+
+// CHECK-ALL-LABEL: @test_inttoptr_ptrtoint_type_mismatch
+// CHECK-ALL-SAME: (%[[ARG:.+]]: i64)
+llvm.func @test_inttoptr_ptrtoint_type_mismatch(%arg0: i64) -> i32 {
+ // CHECK-ALL: %[[INT:.+]] = llvm.inttoptr %[[ARG]]
+ // CHECK-ALL: %[[PTR:.+]] = llvm.ptrtoint %[[INT]]
+ // CHECK-ALL: llvm.return %[[PTR]]
+
+ %0 = llvm.inttoptr %arg0 : i64 to !llvm.ptr
+ %1 = llvm.ptrtoint %0 : !llvm.ptr to i32
+ llvm.return %1 : i32
+}
+
+// CHECK-ALL-LABEL: @test_ptrtoint_inttoptr_type_mismatch
+// CHECK-ALL-SAME: (%[[ARG:.+]]: !llvm.ptr<1>)
+llvm.func @test_ptrtoint_inttoptr_type_mismatch(%arg0: !llvm.ptr<1>) -> !llvm.ptr<0> {
+ // CHECK-ALL: %[[INT:.+]] = llvm.ptrtoint %[[ARG]]
+ // CHECK-ALL: %[[PTR:.+]] = llvm.inttoptr %[[INT]]
+ // CHECK-ALL: llvm.return %[[PTR]]
+ %0 = llvm.ptrtoint %arg0 : !llvm.ptr<1> to i64
+ %1 = llvm.inttoptr %0 : i64 to !llvm.ptr<0>
+ llvm.return %1 : !llvm.ptr<0>
+}
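The expectations in the four RUN lines can be cross-checked against the legality rule in `canFoldIntToPtrPtrToInt`. The following is a standalone C++ model of that rule, not MLIR code: `Ty`, `intTy`, `ptrTy`, `bitwidth`, and `canFold` are hypothetical stand-ins for the MLIR types and helpers. It only models the bitwidth and type-matching checks; as the review discussion below explains, it says nothing about pointer provenance.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Toy stand-in for the types involved (hypothetical, not MLIR API): either an
// integer of a given width or a pointer in a given address space.
struct Ty {
  bool isPtr;
  unsigned value; // integer width, or address space for pointers
  bool operator==(const Ty &o) const {
    return isPtr == o.isPtr && value == o.value;
  }
};
Ty intTy(unsigned width) { return {false, width}; }
Ty ptrTy(unsigned addrSpace) { return {true, addrSpace}; }

// Same rule as `getIntOrPtrBW`: integer widths are intrinsic, pointer widths
// come from the per-address-space table (0 or missing means unknown).
std::optional<unsigned> bitwidth(Ty t, const std::vector<unsigned> &bws) {
  if (!t.isPtr)
    return t.value;
  if (t.value < bws.size() && bws[t.value] != 0)
    return bws[t.value];
  return std::nullopt;
}

// Same rule as `canFoldIntToPtrPtrToInt`: the end types must be identical and
// the intermediate type must not be narrower than the original one.
bool canFold(Ty original, Ty intermediate, Ty result,
             const std::vector<unsigned> &bws) {
  if (!(original == result))
    return false;
  std::optional<unsigned> origBW = bitwidth(original, bws);
  std::optional<unsigned> midBW = bitwidth(intermediate, bws);
  return origBW && midBW && *origBW <= *midBW;
}
```

For example, `canFold(intTy(64), ptrTy(0), intTy(64), {32})` is false because a 32-bit pointer would truncate the i64, which matches the `CHECK-32BIT` expectations for `@test_ptrtoint_inttoptr_fold_64bit`.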
Dinistro left a comment:
Thanks for changing this to be a pass. Conceptually, this looks sensible, but we will require the data layout as this is otherwise very hard to use from pipelines that do not necessarily know the bit-widths during construction.
On `mlir/include/mlir/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.h`:
> } // namespace LLVM
> } // namespace mlir
>
> #endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_INTTOPTRPTRTOINTFOLDING_H
> \ No newline at end of file
Nit: missing newline at end of file.
On `mlir/lib/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.cpp`:
> // TODO: Support address space conversions?
> // TODO: Support int trunc/ext?
Note: maybe a slicing utility like `walkSlice` would be beneficial here?
I don't see any mention of non-integral pointers?
nikic left a comment:
This fold is incorrect, because the resulting pointer may have different provenance. We probably should not copy known-incorrect folds from LLVM to MLIR...
Yeah, this transformation creates issues with provenance: for example, integer optimizations may then decide to merge pointers with different provenance, which can easily cause miscompiles. Would it be possible to do what you want without this transformation? What is the general thing you are trying to achieve?
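The mechanism described in this comment can be made concrete with a toy model. The sketch below is an assumption-laden illustration, not LLVM's actual semantics: pointers carry an allocation id as provenance, `ptrtoint` drops it, and `inttoptr` is assumed to return a "wildcard" pointer that may access any allocation containing its address (roughly the angelic choice mentioned later in the thread). `Alloc`, `Ptr`, `kWildcard`, and `accessIsDefined` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Toy memory model (an assumption for illustration, not LLVM's semantics).
// A pointer is an address plus the id of the allocation it was derived from,
// or kWildcard, which this model assumes every inttoptr result receives.
constexpr int kWildcard = -1;

struct Alloc {
  int id;
  std::uint64_t base, size;
};
struct Ptr {
  std::uint64_t addr;
  int prov;
};

std::uint64_t ptrtoint(Ptr p) { return p.addr; }        // drops provenance
Ptr inttoptr(std::uint64_t i) { return {i, kWildcard}; } // wildcard provenance

// A 1-byte access is defined if some allocation contains the address and the
// pointer's provenance permits it (wildcard may use any such allocation).
bool accessIsDefined(Ptr p, const std::vector<Alloc> &allocs) {
  for (const Alloc &a : allocs) {
    bool inRange = p.addr >= a.base && p.addr < a.base + a.size;
    if (inRange && (p.prov == kWildcard || p.prov == a.id))
      return true;
  }
  return false;
}
```

Take `p`, one past the end of a 16-byte allocation at 0x1000, and `q`, the start of an adjacent allocation at 0x1010. Their integer values are equal, so integer GVN may rewrite `inttoptr(ptrtoint(q))` into `inttoptr(ptrtoint(p))`; in this model the access stays defined. Folding that result to `p` then yields a pointer whose provenance does not cover 0x1010, so the access becomes undefined. Each step looks locally reasonable, but the combination is a miscompile, and this model attributes the error to the fold, in line with the review comment above.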
Ouch! I was afraid of a provenance issue, and trusted InstCombine as a reference when assuming this one was OK :(
Thanks for the feedback. I'm not well versed in pointer provenance, and I'm lost if LLVM is not a valid reference implementation. Is there any way forward here? Would further checks (e.g., …
The canonical issue for this is #33896 -- though it's not particularly good reading due to its age, and is likely to leave a person more confused than before... It's generally hard to carve out valid subsets for this transform, in part because the precise way in which inttoptr recovers provenance is not settled. I think under some models (like angelic non-determinism over all exposed pointers) the transform is essentially always invalid. Ideally you avoid the issue in the first place by not using ptrtoint and inttoptr. Where are they coming from in your case?
Thanks, that post was very helpful! OK, I'm going to give up on this, then.
There are different factors that might be addressable at higher levels of abstraction, but I thought we would benefit from a simple cleanup at this level. I learnt something, though, so thanks a lot!