Skip to content

Commit 0301bc3

Browse files
[Transforms] lower memref.copy (#507)
* [Transforms] lower memref.copy * address review issues
1 parent 5369bc8 commit 0301bc3

File tree

6 files changed

+149
-0
lines changed

6 files changed

+149
-0
lines changed

include/imex/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ std::unique_ptr<mlir::Pass> createInsertGPUAllocsPass();
2525
std::unique_ptr<mlir::Pass> createSetSPIRVCapabilitiesPass();
2626
std::unique_ptr<mlir::Pass> createSetSPIRVAbiAttributePass();
2727
std::unique_ptr<mlir::Pass> createAddOuterParallelLoopPass();
28+
std::unique_ptr<mlir::Pass> createLowerMemRefCopyPass();
2829

2930
#define GEN_PASS_DECL
3031
#include "imex/Transforms/Passes.h.inc"

include/imex/Transforms/Passes.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,4 +75,19 @@ def AddOuterParallelLoop : Pass<"imex-add-outer-parallel-loop", "::mlir::func::F
7575
];
7676
}
7777

78+
def LowerMemRefCopy : Pass<"imex-lower-memref-copy", "::mlir::func::FuncOp"> {
79+
let summary = "lower memref.copy to linalg.generic";
80+
let description = [{
81+
This pass transforms memref.copy to linalg.generic with identity indexing maps and
82+
parallel iterators. When the source has no use after the copy, the copy is instead removed and all uses of the target are replaced with the source.
83+
84+
This pass is supposed to work after bufferization and before linalg-lowering.
85+
}];
86+
let constructor = "imex::createLowerMemRefCopyPass()";
87+
let dependentDialects = [
88+
"::mlir::linalg::LinalgDialect",
89+
"::mlir::memref::MemRefDialect"
90+
];
91+
}
92+
7893
#endif // _IMEX_TRANSFORMS_PASSES_TD_INCLUDED_

lib/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_mlir_library(IMEXTransforms
44
SetSPIRVCapabilities.cpp
55
SetSPIRVAbiAttribute.cpp
66
AddOuterParallelLoop.cpp
7+
LowerMemRefCopy.cpp
78

89
ADDITIONAL_HEADER_DIRS
910
${PROJECT_SOURCE_DIR}/imex/Transforms

lib/Transforms/LowerMemRefCopy.cpp

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
//===- LowerMemRefCopy.cpp - lower memref.copy pass --------*- C++ -*-===//
2+
//
3+
// Copyright 2022 Intel Corporation
4+
// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
///
10+
/// \file
11+
/// This pass lowers memref copyOp to linalg generic operations and enables
12+
/// simple memref copyOp canonicalization
13+
///
14+
//===----------------------------------------------------------------------===//
15+
16+
#include "imex/Transforms/Passes.h"
17+
18+
#include "mlir/Dialect/Func/IR/FuncOps.h"
19+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
20+
#include "mlir/Dialect/Linalg/Utils/Utils.h"
21+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
22+
#include "mlir/IR/Dominance.h"
23+
#include "mlir/Pass/Pass.h"
24+
#include "mlir/Support/LogicalResult.h"
25+
26+
namespace imex {
27+
#define GEN_PASS_DEF_LOWERMEMREFCOPY
28+
#include "imex/Transforms/Passes.h.inc"
29+
} // namespace imex
30+
31+
using namespace mlir;
32+
using namespace imex;
33+
34+
namespace {
35+
struct LowerMemRefCopy
36+
: public imex::impl::LowerMemRefCopyBase<LowerMemRefCopy> {
37+
void runOnOperation() override {
38+
auto &domInfo = getAnalysis<DominanceInfo>();
39+
auto func = getOperation();
40+
// walk through memref.copy ops in the funcion body
41+
WalkResult result =
42+
func.walk<WalkOrder::PreOrder>([&](memref::CopyOp op) -> WalkResult {
43+
if (op->getParentOp() != func)
44+
return WalkResult::skip();
45+
auto src = op.getSource();
46+
auto dst = op.getTarget();
47+
// supposed to work on same memref type
48+
auto srcType = src.getType().cast<MemRefType>();
49+
auto dstType = dst.getType().cast<MemRefType>();
50+
if (srcType != dstType)
51+
return WalkResult::skip();
52+
// supposed to work on memref.alloc
53+
auto srcOp = src.getDefiningOp<memref::AllocOp>();
54+
auto dstOp = dst.getDefiningOp<memref::AllocOp>();
55+
if (!srcOp || !dstOp)
56+
return WalkResult::skip();
57+
// check use of src after this copyOp, being conservative
58+
// FIXME: handle dealloc of src and dst
59+
bool hasSubsequentUse = false;
60+
for (auto user : src.getUsers()) {
61+
if (isa<memref::DeallocOp>(user)) {
62+
continue;
63+
}
64+
if (domInfo.properlyDominates(op, user)) {
65+
hasSubsequentUse = true;
66+
break;
67+
}
68+
}
69+
70+
// replace copy with linalg.generic
71+
if (hasSubsequentUse) {
72+
OpBuilder builder(op);
73+
linalg::makeMemRefCopyOp(builder, op.getLoc(), src, dst);
74+
} else {
75+
// coalesce buffer
76+
dst.replaceAllUsesWith(src);
77+
}
78+
op.erase();
79+
return WalkResult::advance();
80+
});
81+
}
82+
};
83+
} // namespace
84+
85+
namespace imex {
86+
/// Creates a new instance of the LowerMemRefCopy pass.
std::unique_ptr<mlir::Pass> createLowerMemRefCopyPass() {
87+
return std::make_unique<LowerMemRefCopy>();
88+
}
89+
} // namespace imex

lib/Transforms/PassDetail.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ class SPIRVDialect;
3030
namespace scf {
3131
class SCFDialect;
3232
}
33+
34+
namespace linalg {
35+
class LinalgDialect;
36+
}
3337
} // end namespace mlir
3438

3539
#include "mlir/Dialect/Affine/IR/AffineOps.h"
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// RUN: imex-opt -imex-lower-memref-copy -allow-unregistered-dialect %s | FileCheck %s
2+
#map = affine_map<(d0, d1) -> (d0, d1)>
3+
module {
4+
// %alloc_0 is passed to "some_use" after the copy. The checks below expect two
// linalg.generic ops in the output: one in place of the memref.copy and the
// one already present in this function.
func.func @copy_with_later_use(%arg0: memref<10x20xf32>) -> memref<10x20xf32> {
5+
%cst = arith.constant 0.000000e+00 : f32
6+
%alloc = memref.alloc() {alignment = 128 : i64} : memref<10x20xf32>
7+
%alloc_0 = memref.alloc() {alignment = 128 : i64} : memref<10x20xf32>
8+
linalg.fill ins(%cst : f32) outs(%alloc_0 : memref<10x20xf32>)
9+
%alloc_1 = memref.alloc() {alignment = 128 : i64} : memref<10x20xf32>
10+
memref.copy %alloc_0, %alloc_1 : memref<10x20xf32> to memref<10x20xf32>
11+
"some_use" (%alloc_0) {} : (memref<10x20xf32>) -> ()
12+
// CHECK-LABEL: func @copy_with_later_use
13+
// CHECK: linalg.generic
14+
// CHECK: linalg.generic
15+
linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : memref<10x20xf32>) outs(%alloc_1 : memref<10x20xf32>) attrs = {iterator_ranges = [10, 20]} {
16+
^bb0(%in: f32, %out: f32):
17+
%0 = arith.addf %out, %in : f32
18+
linalg.yield %0 : f32
19+
}
20+
return %alloc_1 : memref<10x20xf32>
21+
}
22+
// Nothing uses %alloc_0 after the copy. The checks below expect exactly one
// linalg.generic in the output (the one already present in this function) and
// no further one after it.
func.func @copy_without_later_use(%arg0: memref<10x20xf32>) -> memref<10x20xf32> {
23+
%cst = arith.constant 0.000000e+00 : f32
24+
%alloc = memref.alloc() {alignment = 128 : i64} : memref<10x20xf32>
25+
%alloc_0 = memref.alloc() {alignment = 128 : i64} : memref<10x20xf32>
26+
linalg.fill ins(%cst : f32) outs(%alloc_0 : memref<10x20xf32>)
27+
%alloc_1 = memref.alloc() {alignment = 128 : i64} : memref<10x20xf32>
28+
memref.copy %alloc_0, %alloc_1 : memref<10x20xf32> to memref<10x20xf32>
29+
// CHECK-LABEL: func @copy_without_later_use
30+
// CHECK: linalg.generic
31+
// CHECK-NOT: linalg.generic
32+
linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : memref<10x20xf32>) outs(%alloc_1 : memref<10x20xf32>) attrs = {iterator_ranges = [10, 20]} {
33+
^bb0(%in: f32, %out: f32):
34+
%0 = arith.addf %out, %in : f32
35+
linalg.yield %0 : f32
36+
}
37+
return %alloc_1 : memref<10x20xf32>
38+
}
39+
}

0 commit comments

Comments
 (0)