Skip to content

Commit 23924b9

Browse files
committed
[mlir][LLVMIR] Add folder pass for llvm.inttoptr and llvm.ptrtoint
This PR is a follow-up to llvm#141891. It introduces a pass that can fold `inttoptr(ptrtoint(x)) -> x` and `ptrtoint(inttoptr(x)) -> x`. The pass takes in a list of address space bitwidths and makes sure that the folding is applied only when it's safe.
1 parent f532167 commit 23924b9

File tree

5 files changed

+296
-0
lines changed

5 files changed

+296
-0
lines changed
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
//===- IntToPtrPtrToIntFolding.h - IntToPtr/PtrToInt folding ----*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file declares a pass that folds inttoptr/ptrtoint operation sequences.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_INTTOPTRPTRTOINTFOLDING_H
14+
#define MLIR_DIALECT_LLVMIR_TRANSFORMS_INTTOPTRPTRTOINTFOLDING_H
15+
16+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17+
#include "mlir/Pass/Pass.h"
18+
19+
namespace mlir {
20+
class RewritePatternSet;
21+
22+
namespace LLVM {
23+
24+
#define GEN_PASS_DECL_FOLDINTTOPTRPTRTOINTPASS
25+
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc"
26+
27+
/// Populate patterns that fold inttoptr/ptrtoint op sequences such as:
28+
///
29+
/// * `inttoptr(ptrtoint(x))` -> `x`
30+
/// * `ptrtoint(inttoptr(x))` -> `x`
31+
///
32+
/// `addressSpaceBWs` contains the pointer bitwidth for each address space. If
33+
/// the pointer bitwidth information is not available for a specific address
34+
/// space, the folding for that address space is not performed.
35+
///
36+
/// TODO: Support DLTI.
37+
void populateIntToPtrPtrToIntFoldingPatterns(
38+
RewritePatternSet &patterns, ArrayRef<unsigned> addressSpaceBWs);
39+
40+
} // namespace LLVM
41+
} // namespace mlir
42+
43+
#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_INTTOPTRPTRTOINTFOLDING_H

mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,23 @@ def DIScopeForLLVMFuncOpPass : Pass<"ensure-debug-info-scope-on-llvm-func", "::m
7373
];
7474
}
7575

76+
def FoldIntToPtrPtrToIntPass : Pass<"fold-llvm-inttoptr-ptrtoint", "LLVM::LLVMFuncOp"> {
77+
let summary = "Fold inttoptr/ptrtoint operation sequences";
78+
let description = [{
79+
This pass folds sequences of inttoptr and ptrtoint operations that cancel
80+
each other out. Specifically:
81+
* inttoptr(ptrtoint(x)) -> x
82+
* ptrtoint(inttoptr(x)) -> x
83+
84+
The pass takes a sequence of address space bitwidths to make sure folding
85+
is safe. If the bitwidth information is not available for an address space,
86+
the pass will not fold any operations involving that address space.
87+
}];
88+
let dependentDialects = ["LLVM::LLVMDialect"];
89+
let options = [
90+
ListOption<"addrSpaceBWs", "address-space-bitwidths", "unsigned",
91+
"List of address space bitwidths sorted by associated index to each address space.">
92+
];
93+
}
94+
7695
#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES

mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRLLVMIRTransforms
44
DIExpressionRewriter.cpp
55
DIScopeForLLVMFuncOp.cpp
66
InlinerInterfaceImpl.cpp
7+
IntToPtrPtrToIntFolding.cpp
78
LegalizeForExport.cpp
89
OptimizeForNVVM.cpp
910
RequestCWrappers.cpp
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
//===- IntToPtrPtrToIntFolding.cpp - IntToPtr/PtrToInt folding ------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass that folds inttoptr/ptrtoint operation sequences.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.h"
14+
15+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16+
#include "mlir/IR/PatternMatch.h"
17+
#include "mlir/Support/LogicalResult.h"
18+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19+
20+
#define DEBUG_TYPE "fold-llvm-inttoptr-ptrtoint"
21+
22+
namespace mlir {
23+
namespace LLVM {
24+
25+
#define GEN_PASS_DEF_FOLDINTTOPTRPTRTOINTPASS
26+
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc"
27+
28+
} // namespace LLVM
29+
} // namespace mlir
30+
31+
using namespace mlir;
32+
33+
namespace {
34+
35+
/// Return the bitwidth of a pointer or integer type. If the type is a pointer,
36+
/// return the bitwidth of the address space from `addrSpaceBWs`, if available.
37+
/// Return failure if the address space bitwidth is not available.
38+
static FailureOr<unsigned> getIntOrPtrBW(Type type,
39+
ArrayRef<unsigned> addrSpaceBWs) {
40+
if (auto ptrType = dyn_cast<LLVM::LLVMPointerType>(type)) {
41+
unsigned addrSpace = ptrType.getAddressSpace();
42+
if (addrSpace < addrSpaceBWs.size() && addrSpaceBWs[addrSpace] != 0)
43+
return addrSpaceBWs[addrSpace];
44+
return failure();
45+
}
46+
47+
auto integerType = cast<IntegerType>(type);
48+
return integerType.getWidth();
49+
}
50+
51+
/// Check if folding inttoptr/ptrtoint is valid. Check that the original type
52+
/// matches the result type of the end-to-end conversion and that the input
53+
/// value is not truncated along the conversion chain.
54+
static LogicalResult canFoldIntToPtrPtrToInt(Type originalType,
55+
Type intermediateType,
56+
Type resultType,
57+
ArrayRef<unsigned> addrSpaceBWs) {
58+
// Check if the original type matches the result type.
59+
// TODO: Support address space conversions?
60+
// TODO: Support int trunc/ext?
61+
if (originalType != resultType)
62+
return failure();
63+
64+
// Make sure there is no data truncation with respect to the original type at
65+
// any point during the conversion. Truncating the intermediate data is fine
66+
// as long as the original data is not truncated.
67+
auto originalBW = getIntOrPtrBW(originalType, addrSpaceBWs);
68+
if (failed(originalBW))
69+
return failure();
70+
71+
auto intermediateBW = getIntOrPtrBW(intermediateType, addrSpaceBWs);
72+
if (failed(intermediateBW))
73+
return failure();
74+
75+
if (*originalBW > *intermediateBW)
76+
return failure();
77+
return success();
78+
}
79+
80+
/// Folds inttoptr(ptrtoint(x)) -> x or ptrtoint(inttoptr(x)) -> x.
81+
template <typename SrcConvOp, typename DstConvOp>
82+
class FoldIntToPtrPtrToInt : public OpRewritePattern<DstConvOp> {
83+
public:
84+
FoldIntToPtrPtrToInt(MLIRContext *context, ArrayRef<unsigned> addrSpaceBWs)
85+
: OpRewritePattern<DstConvOp>(context), addrSpaceBWs(addrSpaceBWs) {}
86+
87+
LogicalResult matchAndRewrite(DstConvOp dstConvOp,
88+
PatternRewriter &rewriter) const override {
89+
auto srcConvOp = dstConvOp.getArg().template getDefiningOp<SrcConvOp>();
90+
if (!srcConvOp)
91+
return failure();
92+
93+
// Check if folding is valid based on type matching and bitwidth
94+
// information.
95+
if (failed(canFoldIntToPtrPtrToInt(srcConvOp.getArg().getType(),
96+
srcConvOp.getType(), dstConvOp.getType(),
97+
addrSpaceBWs))) {
98+
return failure();
99+
}
100+
101+
rewriter.replaceOp(dstConvOp, srcConvOp.getArg());
102+
return success();
103+
}
104+
105+
private:
106+
SmallVector<unsigned> addrSpaceBWs;
107+
};
108+
109+
/// Pass that folds inttoptr/ptrtoint operation sequences.
110+
struct FoldIntToPtrPtrToIntPass
111+
: public LLVM::impl::FoldIntToPtrPtrToIntPassBase<
112+
FoldIntToPtrPtrToIntPass> {
113+
using Base =
114+
LLVM::impl::FoldIntToPtrPtrToIntPassBase<FoldIntToPtrPtrToIntPass>;
115+
using Base::FoldIntToPtrPtrToIntPassBase;
116+
117+
void runOnOperation() override {
118+
RewritePatternSet patterns(&getContext());
119+
LLVM::populateIntToPtrPtrToIntFoldingPatterns(patterns, addrSpaceBWs);
120+
121+
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
122+
signalPassFailure();
123+
}
124+
};
125+
126+
} // namespace
127+
128+
void mlir::LLVM::populateIntToPtrPtrToIntFoldingPatterns(
129+
RewritePatternSet &patterns, ArrayRef<unsigned> addrSpaceBWs) {
130+
patterns.add<FoldIntToPtrPtrToInt<LLVM::PtrToIntOp, LLVM::IntToPtrOp>,
131+
FoldIntToPtrPtrToInt<LLVM::IntToPtrOp, LLVM::PtrToIntOp>>(
132+
patterns.getContext(), addrSpaceBWs);
133+
}
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
// RUN: mlir-opt -pass-pipeline="builtin.module(llvm.func(fold-llvm-inttoptr-ptrtoint{address-space-bitwidths=64}))" %s | FileCheck %s --check-prefixes=CHECK-64BIT,CHECK-ALL
2+
// RUN: mlir-opt -pass-pipeline="builtin.module(llvm.func(fold-llvm-inttoptr-ptrtoint{address-space-bitwidths=32}))" %s | FileCheck %s --check-prefixes=CHECK-32BIT,CHECK-ALL
3+
// RUN: mlir-opt -pass-pipeline="builtin.module(llvm.func(fold-llvm-inttoptr-ptrtoint{address-space-bitwidths=64,32}))" %s | FileCheck %s --check-prefixes=CHECK-MULTI-ADDRSPACE,CHECK-ALL
4+
// RUN: mlir-opt -pass-pipeline="builtin.module(llvm.func(fold-llvm-inttoptr-ptrtoint))" %s | FileCheck %s --check-prefixes=CHECK-DISABLED,CHECK-ALL
5+
6+
7+
// CHECK-ALL-LABEL: @test_inttoptr_ptrtoint_fold_64bit
8+
// CHECK-ALL-SAME: (%[[ARG:.+]]: !llvm.ptr)
9+
llvm.func @test_inttoptr_ptrtoint_fold_64bit(%arg0: !llvm.ptr) -> !llvm.ptr {
10+
// CHECK-64BIT-NOT: llvm.ptrtoint
11+
// CHECK-64BIT-NOT: llvm.inttoptr
12+
// CHECK-64BIT: llvm.return %[[ARG]]
13+
14+
// CHECK-32BIT-NOT: llvm.ptrtoint
15+
// CHECK-32BIT-NOT: llvm.inttoptr
16+
// CHECK-32BIT: llvm.return %[[ARG]]
17+
18+
// CHECK-MULTI-ADDRSPACE-NOT: llvm.ptrtoint
19+
// CHECK-MULTI-ADDRSPACE-NOT: llvm.inttoptr
20+
// CHECK-MULTI-ADDRSPACE: llvm.return %[[ARG]]
21+
22+
// CHECK-DISABLED: %[[INT:.+]] = llvm.ptrtoint %[[ARG]]
23+
// CHECK-DISABLED: %[[PTR:.+]] = llvm.inttoptr %[[INT]]
24+
// CHECK-DISABLED: llvm.return %[[PTR]]
25+
26+
%0 = llvm.ptrtoint %arg0 : !llvm.ptr to i64
27+
%1 = llvm.inttoptr %0 : i64 to !llvm.ptr
28+
llvm.return %1 : !llvm.ptr
29+
}
30+
31+
// CHECK-ALL-LABEL: @test_ptrtoint_inttoptr_fold_64bit
32+
// CHECK-ALL-SAME: (%[[ARG:.+]]: i64)
33+
llvm.func @test_ptrtoint_inttoptr_fold_64bit(%arg0: i64) -> i64 {
34+
// CHECK-64BIT-NOT: llvm.inttoptr
35+
// CHECK-64BIT-NOT: llvm.ptrtoint
36+
// CHECK-64BIT: llvm.return %[[ARG]]
37+
38+
// CHECK-32BIT: %[[INT:.+]] = llvm.inttoptr %[[ARG]]
39+
// CHECK-32BIT: %[[PTR:.+]] = llvm.ptrtoint %[[INT]]
40+
// CHECK-32BIT: llvm.return %[[PTR]]
41+
42+
// CHECK-MULTI-ADDRSPACE-NOT: llvm.inttoptr
43+
// CHECK-MULTI-ADDRSPACE-NOT: llvm.ptrtoint
44+
// CHECK-MULTI-ADDRSPACE: llvm.return %[[ARG]]
45+
46+
// CHECK-DISABLED: %[[INT:.+]] = llvm.inttoptr %[[ARG]]
47+
// CHECK-DISABLED: %[[PTR:.+]] = llvm.ptrtoint %[[INT]]
48+
// CHECK-DISABLED: llvm.return %[[PTR]]
49+
50+
%0 = llvm.inttoptr %arg0 : i64 to !llvm.ptr
51+
%1 = llvm.ptrtoint %0 : !llvm.ptr to i64
52+
llvm.return %1 : i64
53+
}
54+
55+
// CHECK-ALL-LABEL: @test_inttoptr_ptrtoint_fold_addrspace1_32bit
56+
// CHECK-ALL-SAME: (%[[ARG:.+]]: !llvm.ptr<1>)
57+
llvm.func @test_inttoptr_ptrtoint_fold_addrspace1_32bit(%arg0: !llvm.ptr<1>) -> !llvm.ptr<1> {
58+
// CHECK-64BIT: %[[INT:.+]] = llvm.ptrtoint %[[ARG]]
59+
// CHECK-64BIT: %[[PTR:.+]] = llvm.inttoptr %[[INT]]
60+
// CHECK-64BIT: llvm.return %[[PTR]]
61+
62+
// CHECK-32BIT: %[[INT:.+]] = llvm.ptrtoint %[[ARG]]
63+
// CHECK-32BIT: %[[PTR:.+]] = llvm.inttoptr %[[INT]]
64+
// CHECK-32BIT: llvm.return %[[PTR]]
65+
66+
// CHECK-MULTI-ADDRSPACE-NOT: llvm.ptrtoint
67+
// CHECK-MULTI-ADDRSPACE-NOT: llvm.inttoptr
68+
// CHECK-MULTI-ADDRSPACE: llvm.return %[[ARG]]
69+
70+
// CHECK-DISABLED: %[[INT:.+]] = llvm.ptrtoint %[[ARG]]
71+
// CHECK-DISABLED: %[[PTR:.+]] = llvm.inttoptr %[[INT]]
72+
// CHECK-DISABLED: llvm.return %[[PTR]]
73+
74+
%0 = llvm.ptrtoint %arg0 : !llvm.ptr<1> to i32
75+
%1 = llvm.inttoptr %0 : i32 to !llvm.ptr<1>
76+
llvm.return %1 : !llvm.ptr<1>
77+
}
78+
79+
// CHECK-ALL-LABEL: @test_inttoptr_ptrtoint_type_mismatch
80+
// CHECK-ALL-SAME: (%[[ARG:.+]]: i64)
81+
llvm.func @test_inttoptr_ptrtoint_type_mismatch(%arg0: i64) -> i32 {
82+
// CHECK-ALL: %[[INT:.+]] = llvm.inttoptr %[[ARG]]
83+
// CHECK-ALL: %[[PTR:.+]] = llvm.ptrtoint %[[INT]]
84+
// CHECK-ALL: llvm.return %[[PTR]]
85+
86+
%0 = llvm.inttoptr %arg0 : i64 to !llvm.ptr
87+
%1 = llvm.ptrtoint %0 : !llvm.ptr to i32
88+
llvm.return %1 : i32
89+
}
90+
91+
// CHECK-ALL-LABEL: @test_ptrtoint_inttoptr_type_mismatch
92+
// CHECK-ALL-SAME: (%[[ARG:.+]]: !llvm.ptr<1>)
93+
llvm.func @test_ptrtoint_inttoptr_type_mismatch(%arg0: !llvm.ptr<1>) -> !llvm.ptr<0> {
94+
// CHECK-ALL: %[[INT:.+]] = llvm.ptrtoint %[[ARG]]
95+
// CHECK-ALL: %[[PTR:.+]] = llvm.inttoptr %[[INT]]
96+
// CHECK-ALL: llvm.return %[[PTR]]
97+
%0 = llvm.ptrtoint %arg0 : !llvm.ptr<1> to i64
98+
%1 = llvm.inttoptr %0 : i64 to !llvm.ptr<0>
99+
llvm.return %1 : !llvm.ptr<0>
100+
}

0 commit comments

Comments
 (0)