Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//===- IntToPtrPtrToIntFolding.h - IntToPtr/PtrToInt folding ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares a pass that folds inttoptr/ptrtoint operation sequences.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_INTTOPTRPTRTOINTFOLDING_H
#define MLIR_DIALECT_LLVMIR_TRANSFORMS_INTTOPTRPTRTOINTFOLDING_H

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
class RewritePatternSet;

namespace LLVM {

#define GEN_PASS_DECL_FOLDINTTOPTRPTRTOINTPASS
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc"

/// Populate patterns that fold inttoptr/ptrtoint op sequences such as:
///
/// * `inttoptr(ptrtoint(x))` -> `x`
/// * `ptrtoint(inttoptr(x))` -> `x`
///
/// `addressSpaceBWs` contains the pointer bitwidth for each address space. If
/// the pointer bitwidth information is not available for a specific address
/// space, the folding for that address space is not performed.
///
/// TODO: Support DLTI.
void populateIntToPtrPtrToIntFoldingPatterns(
RewritePatternSet &patterns, ArrayRef<unsigned> addressSpaceBWs);

} // namespace LLVM
} // namespace mlir

#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_INTTOPTRPTRTOINTFOLDING_H
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Missing newline

19 changes: 19 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,23 @@ def DIScopeForLLVMFuncOpPass : Pass<"ensure-debug-info-scope-on-llvm-func", "::m
];
}

def FoldIntToPtrPtrToIntPass : Pass<"fold-llvm-inttoptr-ptrtoint", "LLVM::LLVMFuncOp"> {
let summary = "Fold inttoptr/ptrtoint operation sequences";
let description = [{
This pass folds sequences of inttoptr and ptrtoint operations that cancel
each other out. Specifically:
* inttoptr(ptrtoint(x)) -> x
* ptrtoint(inttoptr(x)) -> x

The pass takes a sequence of address space bitwidths to make sure folding
is safe. If the bitwidth information is not available for an address space,
the pass will not fold any operations involving that address space.
}];
let dependentDialects = ["LLVM::LLVMDialect"];
let options = [
ListOption<"addrSpaceBWs", "address-space-bitwidths", "unsigned",
"List of address space bitwidths sorted by associated index to each address space.">
];
}

#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES
1 change: 1 addition & 0 deletions mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRLLVMIRTransforms
DIExpressionRewriter.cpp
DIScopeForLLVMFuncOp.cpp
InlinerInterfaceImpl.cpp
IntToPtrPtrToIntFolding.cpp
LegalizeForExport.cpp
OptimizeForNVVM.cpp
RequestCWrappers.cpp
Expand Down
133 changes: 133 additions & 0 deletions mlir/lib/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
//===- IntToPtrPtrToIntFolding.cpp - IntToPtr/PtrToInt folding ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass that folds inttoptr/ptrtoint operation sequences.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.h"

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "fold-llvm-inttoptr-ptrtoint"

namespace mlir {
namespace LLVM {

#define GEN_PASS_DEF_FOLDINTTOPTRPTRTOINTPASS
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc"

} // namespace LLVM
} // namespace mlir

using namespace mlir;

namespace {

/// Return the bitwidth of a pointer or integer type. If the type is a pointer,
/// return the bitwidth of the address space from `addrSpaceBWs`, if available.
/// Return failure if the address space bitwidth is not available.
static FailureOr<unsigned> getIntOrPtrBW(Type type,
ArrayRef<unsigned> addrSpaceBWs) {
if (auto ptrType = dyn_cast<LLVM::LLVMPointerType>(type)) {
unsigned addrSpace = ptrType.getAddressSpace();
if (addrSpace < addrSpaceBWs.size() && addrSpaceBWs[addrSpace] != 0)
return addrSpaceBWs[addrSpace];
return failure();
}

auto integerType = cast<IntegerType>(type);
return integerType.getWidth();
}

/// Check if folding inttoptr/ptrtoint is valid. Check that the original type
/// matches the result type of the end-to-end conversion and that the input
/// value is not truncated along the conversion chain.
static LogicalResult canFoldIntToPtrPtrToInt(Type originalType,
Type intermediateType,
Type resultType,
ArrayRef<unsigned> addrSpaceBWs) {
// Check if the original type matches the result type.
// TODO: Support address space conversions?
// TODO: Support int trunc/ext?
Comment on lines +59 to +60
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: Maybe using a slicing utility like walkSlice would be beneficial for this?

if (originalType != resultType)
return failure();

// Make sure there is no data truncation with respect to the original type at
// any point during the conversion. Truncating the intermediate data is fine
// as long as the original data is not truncated.
auto originalBW = getIntOrPtrBW(originalType, addrSpaceBWs);
if (failed(originalBW))
return failure();

auto intermediateBW = getIntOrPtrBW(intermediateType, addrSpaceBWs);
if (failed(intermediateBW))
return failure();

if (*originalBW > *intermediateBW)
return failure();
return success();
}

/// Folds inttoptr(ptrtoint(x)) -> x or ptrtoint(inttoptr(x)) -> x.
template <typename SrcConvOp, typename DstConvOp>
class FoldIntToPtrPtrToInt : public OpRewritePattern<DstConvOp> {
public:
FoldIntToPtrPtrToInt(MLIRContext *context, ArrayRef<unsigned> addrSpaceBWs)
: OpRewritePattern<DstConvOp>(context), addrSpaceBWs(addrSpaceBWs) {}

LogicalResult matchAndRewrite(DstConvOp dstConvOp,
PatternRewriter &rewriter) const override {
auto srcConvOp = dstConvOp.getArg().template getDefiningOp<SrcConvOp>();
if (!srcConvOp)
return failure();

// Check if folding is valid based on type matching and bitwidth
// information.
if (failed(canFoldIntToPtrPtrToInt(srcConvOp.getArg().getType(),
srcConvOp.getType(), dstConvOp.getType(),
addrSpaceBWs))) {
return failure();
}

rewriter.replaceOp(dstConvOp, srcConvOp.getArg());
return success();
}

private:
SmallVector<unsigned> addrSpaceBWs;
};

/// Pass that folds inttoptr/ptrtoint operation sequences.
struct FoldIntToPtrPtrToIntPass
: public LLVM::impl::FoldIntToPtrPtrToIntPassBase<
FoldIntToPtrPtrToIntPass> {
using Base =
LLVM::impl::FoldIntToPtrPtrToIntPassBase<FoldIntToPtrPtrToIntPass>;
using Base::FoldIntToPtrPtrToIntPassBase;

void runOnOperation() override {
RewritePatternSet patterns(&getContext());
LLVM::populateIntToPtrPtrToIntFoldingPatterns(patterns, addrSpaceBWs);

if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}
};

} // namespace

void mlir::LLVM::populateIntToPtrPtrToIntFoldingPatterns(
RewritePatternSet &patterns, ArrayRef<unsigned> addrSpaceBWs) {
patterns.add<FoldIntToPtrPtrToInt<LLVM::PtrToIntOp, LLVM::IntToPtrOp>,
FoldIntToPtrPtrToInt<LLVM::IntToPtrOp, LLVM::PtrToIntOp>>(
patterns.getContext(), addrSpaceBWs);
}
100 changes: 100 additions & 0 deletions mlir/test/Dialect/LLVMIR/inttoptr-ptrtoint-folding.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// RUN: mlir-opt -pass-pipeline="builtin.module(llvm.func(fold-llvm-inttoptr-ptrtoint{address-space-bitwidths=64}))" %s | FileCheck %s --check-prefixes=CHECK-64BIT,CHECK-ALL
// RUN: mlir-opt -pass-pipeline="builtin.module(llvm.func(fold-llvm-inttoptr-ptrtoint{address-space-bitwidths=32}))" %s | FileCheck %s --check-prefixes=CHECK-32BIT,CHECK-ALL
// RUN: mlir-opt -pass-pipeline="builtin.module(llvm.func(fold-llvm-inttoptr-ptrtoint{address-space-bitwidths=64,32}))" %s | FileCheck %s --check-prefixes=CHECK-MULTI-ADDRSPACE,CHECK-ALL
// RUN: mlir-opt -pass-pipeline="builtin.module(llvm.func(fold-llvm-inttoptr-ptrtoint))" %s | FileCheck %s --check-prefixes=CHECK-DISABLED,CHECK-ALL


// CHECK-ALL-LABEL: @test_inttoptr_ptrtoint_fold_64bit
// CHECK-ALL-SAME: (%[[ARG:.+]]: !llvm.ptr)
llvm.func @test_inttoptr_ptrtoint_fold_64bit(%arg0: !llvm.ptr) -> !llvm.ptr {
// CHECK-64BIT-NOT: llvm.ptrtoint
// CHECK-64BIT-NOT: llvm.inttoptr
// CHECK-64BIT: llvm.return %[[ARG]]

// CHECK-32BIT-NOT: llvm.ptrtoint
// CHECK-32BIT-NOT: llvm.inttoptr
// CHECK-32BIT: llvm.return %[[ARG]]

// CHECK-MULTI-ADDRSPACE-NOT: llvm.ptrtoint
// CHECK-MULTI-ADDRSPACE-NOT: llvm.inttoptr
// CHECK-MULTI-ADDRSPACE: llvm.return %[[ARG]]

// CHECK-DISABLED: %[[INT:.+]] = llvm.ptrtoint %[[ARG]]
// CHECK-DISABLED: %[[PTR:.+]] = llvm.inttoptr %[[INT]]
// CHECK-DISABLED: llvm.return %[[PTR]]

%0 = llvm.ptrtoint %arg0 : !llvm.ptr to i64
%1 = llvm.inttoptr %0 : i64 to !llvm.ptr
llvm.return %1 : !llvm.ptr
}

// CHECK-ALL-LABEL: @test_ptrtoint_inttoptr_fold_64bit
// CHECK-ALL-SAME: (%[[ARG:.+]]: i64)
llvm.func @test_ptrtoint_inttoptr_fold_64bit(%arg0: i64) -> i64 {
// CHECK-64BIT-NOT: llvm.inttoptr
// CHECK-64BIT-NOT: llvm.ptrtoint
// CHECK-64BIT: llvm.return %[[ARG]]

// CHECK-32BIT: %[[INT:.+]] = llvm.inttoptr %[[ARG]]
// CHECK-32BIT: %[[PTR:.+]] = llvm.ptrtoint %[[INT]]
// CHECK-32BIT: llvm.return %[[PTR]]

// CHECK-MULTI-ADDRSPACE-NOT: llvm.inttoptr
// CHECK-MULTI-ADDRSPACE-NOT: llvm.ptrtoint
// CHECK-MULTI-ADDRSPACE: llvm.return %[[ARG]]

// CHECK-DISABLED: %[[INT:.+]] = llvm.inttoptr %[[ARG]]
// CHECK-DISABLED: %[[PTR:.+]] = llvm.ptrtoint %[[INT]]
// CHECK-DISABLED: llvm.return %[[PTR]]

%0 = llvm.inttoptr %arg0 : i64 to !llvm.ptr
%1 = llvm.ptrtoint %0 : !llvm.ptr to i64
llvm.return %1 : i64
}

// CHECK-ALL-LABEL: @test_inttoptr_ptrtoint_fold_addrspace1_32bit
// CHECK-ALL-SAME: (%[[ARG:.+]]: !llvm.ptr<1>)
llvm.func @test_inttoptr_ptrtoint_fold_addrspace1_32bit(%arg0: !llvm.ptr<1>) -> !llvm.ptr<1> {
// CHECK-64BIT: %[[INT:.+]] = llvm.ptrtoint %[[ARG]]
// CHECK-64BIT: %[[PTR:.+]] = llvm.inttoptr %[[INT]]
// CHECK-64BIT: llvm.return %[[PTR]]

// CHECK-32BIT: %[[INT:.+]] = llvm.ptrtoint %[[ARG]]
// CHECK-32BIT: %[[PTR:.+]] = llvm.inttoptr %[[INT]]
// CHECK-32BIT: llvm.return %[[PTR]]

// CHECK-MULTI-ADDRSPACE-NOT: llvm.ptrtoint
// CHECK-MULTI-ADDRSPACE-NOT: llvm.inttoptr
// CHECK-MULTI-ADDRSPACE: llvm.return %[[ARG]]

// CHECK-DISABLED: %[[INT:.+]] = llvm.ptrtoint %[[ARG]]
// CHECK-DISABLED: %[[PTR:.+]] = llvm.inttoptr %[[INT]]
// CHECK-DISABLED: llvm.return %[[PTR]]

%0 = llvm.ptrtoint %arg0 : !llvm.ptr<1> to i32
%1 = llvm.inttoptr %0 : i32 to !llvm.ptr<1>
llvm.return %1 : !llvm.ptr<1>
}

// CHECK-ALL-LABEL: @test_inttoptr_ptrtoint_type_mismatch
// CHECK-ALL-SAME: (%[[ARG:.+]]: i64)
llvm.func @test_inttoptr_ptrtoint_type_mismatch(%arg0: i64) -> i32 {
// CHECK-ALL: %[[INT:.+]] = llvm.inttoptr %[[ARG]]
// CHECK-ALL: %[[PTR:.+]] = llvm.ptrtoint %[[INT]]
// CHECK-ALL: llvm.return %[[PTR]]

%0 = llvm.inttoptr %arg0 : i64 to !llvm.ptr
%1 = llvm.ptrtoint %0 : !llvm.ptr to i32
llvm.return %1 : i32
}

// CHECK-ALL-LABEL: @test_ptrtoint_inttoptr_type_mismatch
// CHECK-ALL-SAME: (%[[ARG:.+]]: !llvm.ptr<1>)
llvm.func @test_ptrtoint_inttoptr_type_mismatch(%arg0: !llvm.ptr<1>) -> !llvm.ptr<0> {
// CHECK-ALL: %[[INT:.+]] = llvm.ptrtoint %[[ARG]]
// CHECK-ALL: %[[PTR:.+]] = llvm.inttoptr %[[INT]]
// CHECK-ALL: llvm.return %[[PTR]]
%0 = llvm.ptrtoint %arg0 : !llvm.ptr<1> to i64
%1 = llvm.inttoptr %0 : i64 to !llvm.ptr<0>
llvm.return %1 : !llvm.ptr<0>
}
Loading