Skip to content

Commit 2dded82

Browse files
authored
Optimize memcpy to scf (#260)
1 parent 510bc58 commit 2dded82

File tree

2 files changed

+43
-2
lines changed

2 files changed

+43
-2
lines changed

lib/polygeist/Ops.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1218,7 +1218,12 @@ class CopySimplification final : public OpRewritePattern<T> {
12181218
todo.push_back(ext.getIn());
12191219
else if (auto ext = len.getDefiningOp<arith::ExtSIOp>())
12201220
todo.push_back(ext.getIn());
1221-
else if (auto ext = len.getDefiningOp<arith::IndexCastOp>())
1221+
else if (auto ext = len.getDefiningOp<arith::TruncIOp>()) {
1222+
if (APInt(64, width).isPowerOf2() &&
1223+
ext.getType().getIntOrFloatBitWidth() >
1224+
APInt(64, width).nearestLogBase2())
1225+
todo.push_back(ext.getIn());
1226+
} else if (auto ext = len.getDefiningOp<arith::IndexCastOp>())
12221227
todo.push_back(ext.getIn());
12231228
else if (auto mul = len.getDefiningOp<arith::MulIOp>()) {
12241229
todo.push_back(mul.getLhs());
@@ -1314,7 +1319,12 @@ class SetSimplification final : public OpRewritePattern<T> {
13141319
todo.push_back(ext.getIn());
13151320
else if (auto ext = len.getDefiningOp<arith::ExtSIOp>())
13161321
todo.push_back(ext.getIn());
1317-
else if (auto ext = len.getDefiningOp<arith::IndexCastOp>())
1322+
else if (auto ext = len.getDefiningOp<arith::TruncIOp>()) {
1323+
if (APInt(64, width).isPowerOf2() &&
1324+
ext.getType().getIntOrFloatBitWidth() >
1325+
APInt(64, width).nearestLogBase2())
1326+
todo.push_back(ext.getIn());
1327+
} else if (auto ext = len.getDefiningOp<arith::IndexCastOp>())
13181328
todo.push_back(ext.getIn());
13191329
else if (auto mul = len.getDefiningOp<arith::MulIOp>()) {
13201330
todo.push_back(mul.getLhs());

test/polygeist-opt/copyopt.mlir

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// RUN: polygeist-opt --canonicalize --split-input-file %s | FileCheck %s
2+
3+
module {
4+
// Test input: copies between two memref<?xi32> buffers through llvm.intr.memcpy.
// The CHECK lines after this module expect --canonicalize to replace the memcpy
// with an scf.for loop of memref.load / memref.store.
func.func @cpy(%46: i64, %66: memref<?xi32>, %51: memref<?xi32>) {
5+
%c4_i64 = arith.constant 4 : i64
6+
%false = arith.constant false
7+
// Length operand: %46 scaled by the constant 4. The expected output divides by 4
// again (arith.divui ... %c4) to obtain the loop trip count.
%47 = arith.muli %46, %c4_i64 : i64
8+
// The length is narrowed to i32 here and sign-extended back to i64 below, so it
// reaches the memcpy through a trunci/extsi pair rather than directly from the muli.
%48 = arith.trunci %47 : i64 to i32
9+
// Both memrefs are converted to i8 pointers for the memcpy call.
%67 = "polygeist.memref2pointer"(%66) : (memref<?xi32>) -> !llvm.ptr<i8>
10+
%68 = "polygeist.memref2pointer"(%51) : (memref<?xi32>) -> !llvm.ptr<i8>
11+
%69 = arith.extsi %48 : i32 to i64
12+
// Operands: pointer from %66, pointer from %51, length %69, and the i1 constant %false.
// The expected output stores into %arg1 (%66) and loads from %arg2 (%51), i.e. the
// first pointer is the destination and the second is the source.
"llvm.intr.memcpy"(%67, %68, %69, %false) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i64, i1) -> ()
13+
return
14+
}
15+
}
16+
17+
// CHECK: func.func @cpy(%arg0: i64, %arg1: memref<?xi32>, %arg2: memref<?xi32>) {
18+
// CHECK-NEXT: %c0 = arith.constant 0 : index
19+
// CHECK-NEXT: %c1 = arith.constant 1 : index
20+
// CHECK-NEXT: %c4 = arith.constant 4 : index
21+
// CHECK-NEXT: %c4_i64 = arith.constant 4 : i64
22+
// CHECK-NEXT: %0 = arith.muli %arg0, %c4_i64 : i64
23+
// CHECK-NEXT: %1 = arith.trunci %0 : i64 to i32
24+
// CHECK-NEXT: %2 = arith.index_cast %1 : i32 to index
25+
// CHECK-NEXT: %3 = arith.divui %2, %c4 : index
26+
// CHECK-NEXT: scf.for %arg3 = %c0 to %3 step %c1 {
27+
// CHECK-NEXT: %4 = memref.load %arg2[%arg3] : memref<?xi32>
28+
// CHECK-NEXT: memref.store %4, %arg1[%arg3] : memref<?xi32>
29+
// CHECK-NEXT: }
30+
// CHECK-NEXT: return
31+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)