diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.h new file mode 100644 index 0000000000000..62a9c84241a39 --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.h @@ -0,0 +1,43 @@ +//===- IntToPtrPtrToIntFolding.h - IntToPtr/PtrToInt folding ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares a pass that folds inttoptr/ptrtoint operation sequences. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_INTTOPTRPTRTOINTFOLDING_H +#define MLIR_DIALECT_LLVMIR_TRANSFORMS_INTTOPTRPTRTOINTFOLDING_H + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +class RewritePatternSet; + +namespace LLVM { + +#define GEN_PASS_DECL_FOLDINTTOPTRPTRTOINTPASS +#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc" + +/// Populate patterns that fold inttoptr/ptrtoint op sequences such as: +/// +/// * `inttoptr(ptrtoint(x))` -> `x` +/// * `ptrtoint(inttoptr(x))` -> `x` +/// +/// `addressSpaceBWs` contains the pointer bitwidth for each address space. If +/// the pointer bitwidth information is not available for a specific address +/// space, the folding for that address space is not performed. +/// +/// TODO: Support DLTI. +void populateIntToPtrPtrToIntFoldingPatterns( + RewritePatternSet &patterns, ArrayRef addressSpaceBWs); + +} // namespace LLVM +} // namespace mlir + +#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_INTTOPTRPTRTOINTFOLDING_H \ No newline at end of file diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td index 961909d5c8d27..be45213e6b95e 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td @@ -73,4 +73,23 @@ def DIScopeForLLVMFuncOpPass : Pass<"ensure-debug-info-scope-on-llvm-func", "::m ]; } +def FoldIntToPtrPtrToIntPass : Pass<"fold-llvm-inttoptr-ptrtoint", "LLVM::LLVMFuncOp"> { + let summary = "Fold inttoptr/ptrtoint operation sequences"; + let description = [{ + This pass folds sequences of inttoptr and ptrtoint operations that cancel + each other out. Specifically: + * inttoptr(ptrtoint(x)) -> x + * ptrtoint(inttoptr(x)) -> x + + The pass takes a sequence of address space bitwidths to make sure folding + is safe. If the bitwidth information is not available for an address space, + the pass will not fold any operations involving that address space. + }]; + let dependentDialects = ["LLVM::LLVMDialect"]; + let options = [ + ListOption<"addrSpaceBWs", "address-space-bitwidths", "unsigned", + "List of address space bitwidths sorted by associated index to each address space."> + ]; +} + #endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt index d4ff0955c5d0e..b22280718c454 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt @@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRLLVMIRTransforms DIExpressionRewriter.cpp DIScopeForLLVMFuncOp.cpp InlinerInterfaceImpl.cpp + IntToPtrPtrToIntFolding.cpp LegalizeForExport.cpp OptimizeForNVVM.cpp RequestCWrappers.cpp diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.cpp new file mode 100644 index 0000000000000..c87a9cb2afe41 --- /dev/null +++ b/mlir/lib/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.cpp @@ -0,0 +1,133 @@ +//===- IntToPtrPtrToIntFolding.cpp - IntToPtr/PtrToInt folding ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass that folds inttoptr/ptrtoint operation sequences. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/LLVMIR/Transforms/IntToPtrPtrToIntFolding.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "fold-llvm-inttoptr-ptrtoint" + +namespace mlir { +namespace LLVM { + +#define GEN_PASS_DEF_FOLDINTTOPTRPTRTOINTPASS +#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc" + +} // namespace LLVM +} // namespace mlir + +using namespace mlir; + +namespace { + +/// Return the bitwidth of a pointer or integer type. If the type is a pointer, +/// return the bitwidth of the address space from `addrSpaceBWs`, if available. +/// Return failure if the address space bitwidth is not available. +static FailureOr getIntOrPtrBW(Type type, + ArrayRef addrSpaceBWs) { + if (auto ptrType = dyn_cast(type)) { + unsigned addrSpace = ptrType.getAddressSpace(); + if (addrSpace < addrSpaceBWs.size() && addrSpaceBWs[addrSpace] != 0) + return addrSpaceBWs[addrSpace]; + return failure(); + } + + auto integerType = cast(type); + return integerType.getWidth(); +} + +/// Check if folding inttoptr/ptrtoint is valid. Check that the original type +/// matches the result type of the end-to-end conversion and that the input +/// value is not truncated along the conversion chain. +static LogicalResult canFoldIntToPtrPtrToInt(Type originalType, + Type intermediateType, + Type resultType, + ArrayRef addrSpaceBWs) { + // Check if the original type matches the result type. + // TODO: Support address space conversions? + // TODO: Support int trunc/ext? + if (originalType != resultType) + return failure(); + + // Make sure there is no data truncation with respect to the original type at + // any point during the conversion. Truncating the intermediate data is fine + // as long as the original data is not truncated. + auto originalBW = getIntOrPtrBW(originalType, addrSpaceBWs); + if (failed(originalBW)) + return failure(); + + auto intermediateBW = getIntOrPtrBW(intermediateType, addrSpaceBWs); + if (failed(intermediateBW)) + return failure(); + + if (*originalBW > *intermediateBW) + return failure(); + return success(); +} + +/// Folds inttoptr(ptrtoint(x)) -> x or ptrtoint(inttoptr(x)) -> x. +template +class FoldIntToPtrPtrToInt : public OpRewritePattern { +public: + FoldIntToPtrPtrToInt(MLIRContext *context, ArrayRef addrSpaceBWs) + : OpRewritePattern(context), addrSpaceBWs(addrSpaceBWs) {} + + LogicalResult matchAndRewrite(DstConvOp dstConvOp, + PatternRewriter &rewriter) const override { + auto srcConvOp = dstConvOp.getArg().template getDefiningOp(); + if (!srcConvOp) + return failure(); + + // Check if folding is valid based on type matching and bitwidth + // information. + if (failed(canFoldIntToPtrPtrToInt(srcConvOp.getArg().getType(), + srcConvOp.getType(), dstConvOp.getType(), + addrSpaceBWs))) { + return failure(); + } + + rewriter.replaceOp(dstConvOp, srcConvOp.getArg()); + return success(); + } + +private: + SmallVector addrSpaceBWs; +}; + +/// Pass that folds inttoptr/ptrtoint operation sequences. +struct FoldIntToPtrPtrToIntPass + : public LLVM::impl::FoldIntToPtrPtrToIntPassBase< + FoldIntToPtrPtrToIntPass> { + using Base = + LLVM::impl::FoldIntToPtrPtrToIntPassBase; + using Base::FoldIntToPtrPtrToIntPassBase; + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + LLVM::populateIntToPtrPtrToIntFoldingPatterns(patterns, addrSpaceBWs); + + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +void mlir::LLVM::populateIntToPtrPtrToIntFoldingPatterns( + RewritePatternSet &patterns, ArrayRef addrSpaceBWs) { + patterns.add, + FoldIntToPtrPtrToInt>( + patterns.getContext(), addrSpaceBWs); +} diff --git a/mlir/test/Dialect/LLVMIR/inttoptr-ptrtoint-folding.mlir b/mlir/test/Dialect/LLVMIR/inttoptr-ptrtoint-folding.mlir new file mode 100644 index 0000000000000..30193eff93be2 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/inttoptr-ptrtoint-folding.mlir @@ -0,0 +1,100 @@ +// RUN: mlir-opt -pass-pipeline="builtin.module(llvm.func(fold-llvm-inttoptr-ptrtoint{address-space-bitwidths=64}))" %s | FileCheck %s --check-prefixes=CHECK-64BIT,CHECK-ALL +// RUN: mlir-opt -pass-pipeline="builtin.module(llvm.func(fold-llvm-inttoptr-ptrtoint{address-space-bitwidths=32}))" %s | FileCheck %s --check-prefixes=CHECK-32BIT,CHECK-ALL +// RUN: mlir-opt -pass-pipeline="builtin.module(llvm.func(fold-llvm-inttoptr-ptrtoint{address-space-bitwidths=64,32}))" %s | FileCheck %s --check-prefixes=CHECK-MULTI-ADDRSPACE,CHECK-ALL +// RUN: mlir-opt -pass-pipeline="builtin.module(llvm.func(fold-llvm-inttoptr-ptrtoint))" %s | FileCheck %s --check-prefixes=CHECK-DISABLED,CHECK-ALL + + +// CHECK-ALL-LABEL: @test_inttoptr_ptrtoint_fold_64bit +// CHECK-ALL-SAME: (%[[ARG:.+]]: !llvm.ptr) +llvm.func @test_inttoptr_ptrtoint_fold_64bit(%arg0: !llvm.ptr) -> !llvm.ptr { + // CHECK-64BIT-NOT: llvm.ptrtoint + // CHECK-64BIT-NOT: llvm.inttoptr + // CHECK-64BIT: llvm.return %[[ARG]] + + // CHECK-32BIT-NOT: llvm.ptrtoint + // CHECK-32BIT-NOT: llvm.inttoptr + // CHECK-32BIT: llvm.return %[[ARG]] + + // CHECK-MULTI-ADDRSPACE-NOT: llvm.ptrtoint + // CHECK-MULTI-ADDRSPACE-NOT: llvm.inttoptr + // CHECK-MULTI-ADDRSPACE: llvm.return %[[ARG]] + + // CHECK-DISABLED: %[[INT:.+]] = llvm.ptrtoint %[[ARG]] + // CHECK-DISABLED: %[[PTR:.+]] = llvm.inttoptr %[[INT]] + // CHECK-DISABLED: llvm.return %[[PTR]] + + %0 = llvm.ptrtoint %arg0 : !llvm.ptr to i64 + %1 = llvm.inttoptr %0 : i64 to !llvm.ptr + llvm.return %1 : !llvm.ptr +} + +// CHECK-ALL-LABEL: @test_ptrtoint_inttoptr_fold_64bit +// CHECK-ALL-SAME: (%[[ARG:.+]]: i64) +llvm.func @test_ptrtoint_inttoptr_fold_64bit(%arg0: i64) -> i64 { + // CHECK-64BIT-NOT: llvm.inttoptr + // CHECK-64BIT-NOT: llvm.ptrtoint + // CHECK-64BIT: llvm.return %[[ARG]] + + // CHECK-32BIT: %[[INT:.+]] = llvm.inttoptr %[[ARG]] + // CHECK-32BIT: %[[PTR:.+]] = llvm.ptrtoint %[[INT]] + // CHECK-32BIT: llvm.return %[[PTR]] + + // CHECK-MULTI-ADDRSPACE-NOT: llvm.inttoptr + // CHECK-MULTI-ADDRSPACE-NOT: llvm.ptrtoint + // CHECK-MULTI-ADDRSPACE: llvm.return %[[ARG]] + + // CHECK-DISABLED: %[[INT:.+]] = llvm.inttoptr %[[ARG]] + // CHECK-DISABLED: %[[PTR:.+]] = llvm.ptrtoint %[[INT]] + // CHECK-DISABLED: llvm.return %[[PTR]] + + %0 = llvm.inttoptr %arg0 : i64 to !llvm.ptr + %1 = llvm.ptrtoint %0 : !llvm.ptr to i64 + llvm.return %1 : i64 +} + +// CHECK-ALL-LABEL: @test_inttoptr_ptrtoint_fold_addrspace1_32bit +// CHECK-ALL-SAME: (%[[ARG:.+]]: !llvm.ptr<1>) +llvm.func @test_inttoptr_ptrtoint_fold_addrspace1_32bit(%arg0: !llvm.ptr<1>) -> !llvm.ptr<1> { + // CHECK-64BIT: %[[INT:.+]] = llvm.ptrtoint %[[ARG]] + // CHECK-64BIT: %[[PTR:.+]] = llvm.inttoptr %[[INT]] + // CHECK-64BIT: llvm.return %[[PTR]] + + // CHECK-32BIT: %[[INT:.+]] = llvm.ptrtoint %[[ARG]] + // CHECK-32BIT: %[[PTR:.+]] = llvm.inttoptr %[[INT]] + // CHECK-32BIT: llvm.return %[[PTR]] + + // CHECK-MULTI-ADDRSPACE-NOT: llvm.ptrtoint + // CHECK-MULTI-ADDRSPACE-NOT: llvm.inttoptr + // CHECK-MULTI-ADDRSPACE: llvm.return %[[ARG]] + + // CHECK-DISABLED: %[[INT:.+]] = llvm.ptrtoint %[[ARG]] + // CHECK-DISABLED: %[[PTR:.+]] = llvm.inttoptr %[[INT]] + // CHECK-DISABLED: llvm.return %[[PTR]] + + %0 = llvm.ptrtoint %arg0 : !llvm.ptr<1> to i32 + %1 = llvm.inttoptr %0 : i32 to !llvm.ptr<1> + llvm.return %1 : !llvm.ptr<1> +} + +// CHECK-ALL-LABEL: @test_inttoptr_ptrtoint_type_mismatch +// CHECK-ALL-SAME: (%[[ARG:.+]]: i64) +llvm.func @test_inttoptr_ptrtoint_type_mismatch(%arg0: i64) -> i32 { + // CHECK-ALL: %[[INT:.+]] = llvm.inttoptr %[[ARG]] + // CHECK-ALL: %[[PTR:.+]] = llvm.ptrtoint %[[INT]] + // CHECK-ALL: llvm.return %[[PTR]] + + %0 = llvm.inttoptr %arg0 : i64 to !llvm.ptr + %1 = llvm.ptrtoint %0 : !llvm.ptr to i32 + llvm.return %1 : i32 +} + +// CHECK-ALL-LABEL: @test_ptrtoint_inttoptr_type_mismatch +// CHECK-ALL-SAME: (%[[ARG:.+]]: !llvm.ptr<1>) +llvm.func @test_ptrtoint_inttoptr_type_mismatch(%arg0: !llvm.ptr<1>) -> !llvm.ptr<0> { + // CHECK-ALL: %[[INT:.+]] = llvm.ptrtoint %[[ARG]] + // CHECK-ALL: %[[PTR:.+]] = llvm.inttoptr %[[INT]] + // CHECK-ALL: llvm.return %[[PTR]] + %0 = llvm.ptrtoint %arg0 : !llvm.ptr<1> to i64 + %1 = llvm.inttoptr %0 : i64 to !llvm.ptr<0> + llvm.return %1 : !llvm.ptr<0> +}