Skip to content

Commit a8cbb3e

Browse files
authored
Fix scf raising on scanf test (#286)
1 parent dceabaa commit a8cbb3e

File tree

3 files changed

+91
-71
lines changed

3 files changed

+91
-71
lines changed

lib/polygeist/Ops.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,52 @@ bool collectEffects(Operation *op,
8484
return true;
8585
}
8686

87+
if (auto cop = dyn_cast<LLVM::CallOp>(op)) {
88+
if (auto callee = cop.getCallee()) {
89+
if (*callee == "scanf" || *callee == "__isoc99_scanf") {
90+
// Global read
91+
effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Read>());
92+
93+
bool first = true;
94+
for (auto arg : cop.getArgOperands()) {
95+
if (first)
96+
effects.emplace_back(::mlir::MemoryEffects::Read::get(), arg,
97+
::mlir::SideEffects::DefaultResource::get());
98+
else
99+
effects.emplace_back(::mlir::MemoryEffects::Write::get(), arg,
100+
::mlir::SideEffects::DefaultResource::get());
101+
first = false;
102+
}
103+
104+
return true;
105+
}
106+
if (*callee == "printf") {
107+
// Global read
108+
effects.emplace_back(
109+
MemoryEffects::Effect::get<MemoryEffects::Write>());
110+
for (auto arg : cop.getArgOperands()) {
111+
effects.emplace_back(::mlir::MemoryEffects::Read::get(), arg,
112+
::mlir::SideEffects::DefaultResource::get());
113+
}
114+
return true;
115+
}
116+
if (*callee == "free") {
117+
for (auto arg : cop.getArgOperands()) {
118+
effects.emplace_back(::mlir::MemoryEffects::Free::get(), arg,
119+
::mlir::SideEffects::DefaultResource::get());
120+
}
121+
return true;
122+
}
123+
if (*callee == "strlen") {
124+
for (auto arg : cop.getArgOperands()) {
125+
effects.emplace_back(::mlir::MemoryEffects::Read::get(), arg,
126+
::mlir::SideEffects::DefaultResource::get());
127+
}
128+
return true;
129+
}
130+
}
131+
}
132+
87133
// We need to be conservative here in case the op doesn't have the interface
88134
// and assume it can have any possible effect.
89135
effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Read>());
@@ -324,6 +370,8 @@ class BarrierHoist final : public OpRewritePattern<BarrierOp> {
324370
}
325371
};
326372

373+
extern std::set<std::string> NonCapturingFunctions;
374+
327375
bool isCaptured(Value v, Operation *potentialUser = nullptr,
328376
bool *seenuse = nullptr) {
329377
SmallVector<Value> todo = {v};
@@ -383,6 +431,16 @@ bool isCaptured(Value v, Operation *potentialUser = nullptr,
383431
if (auto sub = dyn_cast<polygeist::Pointer2MemrefOp>(u)) {
384432
todo.push_back(sub);
385433
}
434+
if (auto cop = dyn_cast<LLVM::CallOp>(u)) {
435+
if (auto callee = cop.getCallee()) {
436+
if (NonCapturingFunctions.count(callee->str()))
437+
continue;
438+
}
439+
}
440+
if (auto cop = dyn_cast<func::CallOp>(u)) {
441+
if (NonCapturingFunctions.count(cop.getCallee().str()))
442+
continue;
443+
}
386444
return true;
387445
}
388446
}
@@ -493,6 +551,8 @@ bool mayAlias(MemoryEffects::EffectInstance a,
493551
MemoryEffects::EffectInstance b) {
494552
if (Value v2 = b.getValue()) {
495553
return mayAlias(a, v2);
554+
} else if (Value v = a.getValue()) {
555+
return mayAlias(b, v);
496556
}
497557
return true;
498558
}

lib/polygeist/Passes/Mem2Reg.cpp

Lines changed: 16 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,10 +1067,10 @@ void removeRedundantBlockArgs(
10671067
}
10681068

10691069
std::set<std::string> NonCapturingFunctions = {
1070-
"free", "printf", "fprintf", "scanf", "fscanf",
1071-
"gettimeofday", "clock_gettime", "getenv", "strrchr", "strlen",
1072-
"sprintf", "sscanf", "mkdir", "fwrite", "fread",
1073-
"memcpy", "cudaMemcpy", "memset", "cudaMemset"};
1070+
"free", "printf", "fprintf", "scanf", "fscanf",
1071+
"gettimeofday", "clock_gettime", "getenv", "strrchr", "strlen",
1072+
"sprintf", "sscanf", "mkdir", "fwrite", "fread",
1073+
"memcpy", "cudaMemcpy", "memset", "cudaMemset", "__isoc99_scanf"};
10741074
// fopen, fclose
10751075
std::set<std::string> NoWriteFunctions = {"exit", "__errno_location"};
10761076
// This is a straightforward implementation not optimized for speed. Optimize
@@ -1286,62 +1286,22 @@ bool Mem2Reg::forwardStoreToLoad(
12861286
if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
12871287
return;
12881288
if (auto callOp = dyn_cast<mlir::LLVM::CallOp>(op)) {
1289-
if (callOp.getCallee() && (*callOp.getCallee() == "printf" ||
1290-
*callOp.getCallee() == "free" ||
1291-
*callOp.getCallee() == "strlen")) {
1289+
if (callOp.getCallee() && (*callOp.getCallee() == "printf")) {
12921290
return;
12931291
}
12941292
}
1295-
MemoryEffectOpInterface interface =
1296-
dyn_cast<MemoryEffectOpInterface>(op);
1297-
if (!interface)
1298-
opMayHaveEffect = true;
1299-
if (interface) {
1300-
SmallVector<MemoryEffects::EffectInstance, 1> effects;
1301-
interface.getEffects(effects);
1302-
1303-
for (auto effect : effects) {
1304-
// If op causes EffectType on a potentially aliasing location for
1305-
// memOp, mark as having the effect.
1306-
if (isa<MemoryEffects::Write>(effect.getEffect())) {
1307-
if (Value val = effect.getValue()) {
1308-
while (true) {
1309-
if (auto co = val.getDefiningOp<memref::CastOp>())
1310-
val = co.getSource();
1311-
else if (auto co = val.getDefiningOp<polygeist::SubIndexOp>())
1312-
val = co.getSource();
1313-
else if (auto co =
1314-
val.getDefiningOp<polygeist::Memref2PointerOp>())
1315-
val = co.getSource();
1316-
else if (auto co =
1317-
val.getDefiningOp<polygeist::Pointer2MemrefOp>())
1318-
val = co.getSource();
1319-
else if (auto co = val.getDefiningOp<LLVM::BitcastOp>())
1320-
val = co.getArg();
1321-
else if (auto co = val.getDefiningOp<LLVM::AddrSpaceCastOp>())
1322-
val = co.getArg();
1323-
else if (auto co = val.getDefiningOp<LLVM::GEPOp>())
1324-
val = co.getBase();
1325-
else
1326-
break;
1327-
}
1328-
if (val.getDefiningOp<memref::AllocaOp>() ||
1329-
val.getDefiningOp<memref::AllocOp>() ||
1330-
val.getDefiningOp<LLVM::AllocaOp>()) {
1331-
if (val != AI)
1332-
continue;
1333-
}
1334-
if (auto glob = val.getDefiningOp<memref::GetGlobalOp>()) {
1335-
if (auto Aglob = AI.getDefiningOp<memref::GetGlobalOp>()) {
1336-
if (glob.getName() != Aglob.getName())
1337-
continue;
1338-
} else
1339-
continue;
1340-
}
1341-
}
1342-
opMayHaveEffect = true;
1343-
break;
1293+
SmallVector<MemoryEffects::EffectInstance, 1> effects;
1294+
collectEffects(op, effects, /*considerBarrier*/ false);
1295+
1296+
for (auto effect : effects) {
1297+
// If op causes EffectType on a potentially aliasing location for
1298+
// memOp, mark as having the effect.
1299+
if (isa<MemoryEffects::Write>(effect.getEffect())) {
1300+
if (!mayAlias(effect.getEffect(), AI)) {
1301+
continue;
13441302
}
1303+
opMayHaveEffect = true;
1304+
break;
13451305
}
13461306
}
13471307
if (opMayHaveEffect) {
Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: cgeist %s %[[stdinclude:.+]] --function=alloc -S | FileCheck %s
1+
// RUN: cgeist %s %stdinclude --function=alloc -S | FileCheck %s
22

33
#include <stdio.h>
44
#include <stdlib.h>
@@ -19,35 +19,35 @@ int* alloc() {
1919
return h_graph_nodes;
2020
}
2121

22-
// XFAIL: *
23-
// TODO INVESTIGATE WHY SCF.FOR NO LONGER CREATED / NO LICM
24-
2522
// CHECK: llvm.mlir.global internal constant @str1("%d\0A\00")
2623
// CHECK-NEXT: llvm.func @__isoc99_scanf(!llvm.ptr<i8>, ...) -> i32
2724
// CHECK-NEXT: llvm.mlir.global internal constant @str0("%d\00")
2825
// CHECK-NEXT: func @alloc() -> memref<?xi32>
26+
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
27+
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
2928
// CHECK-DAG: %[[c4:.+]] = arith.constant 4 : index
3029
// CHECK-DAG: %[[c4_i64:.+]] = arith.constant 4 : i6
31-
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
32-
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
33-
// CHECK-DAG: %[[c1_i64:.+]] = arith.constant 1 : i64
34-
// CHECK-NEXT: %[[V0:.+]] = llvm.alloca %[[c1_i64]] x i32 : (i64) -> !llvm.ptr<i32>
30+
// CHECK-DAG: %[[ud:.+]] = llvm.mlir.undef : i32
31+
// CHECK-NEXT: %[[alloca:.+]] = memref.alloca() : memref<1xi32>
32+
// CHECK-NEXT: affine.store %[[ud]], %[[alloca]][0] : memref<1xi32>
3533
// CHECK-NEXT: %[[V1:.+]] = llvm.mlir.addressof @str0 : !llvm.ptr<array<3 x i8>>
3634
// CHECK-NEXT: %[[V2:.+]] = llvm.getelementptr %[[V1]][0, 0] : (!llvm.ptr<array<3 x i8>>) -> !llvm.ptr<i8>
37-
// CHECK-NEXT: %[[V3:.+]] = llvm.call @__isoc99_scanf(%[[V2]], %[[V0]]) : (!llvm.ptr<i8>, !llvm.ptr<i32>) -> i32
38-
// CHECK-NEXT: %[[V4:.+]] = llvm.load %[[V0]] : !llvm.ptr<i32>
35+
// CHECK-NEXT: %[[S0:.+]] = "polygeist.memref2pointer"(%[[alloca]]) : (memref<1xi32>) -> !llvm.ptr<i32>
36+
// CHECK-NEXT: %[[V3:.+]] = llvm.call @__isoc99_scanf(%[[V2]], %[[S0]]) : (!llvm.ptr<i8>, !llvm.ptr<i32>) -> i32
37+
// CHECK-NEXT: %[[V4:.+]] = affine.load %[[alloca]][0] : memref<1xi32>
3938
// CHECK-NEXT: %[[V5:.+]] = arith.extsi %[[V4]] : i32 to i64
4039
// CHECK-NEXT: %[[V6:.+]] = arith.muli %[[V5]], %[[c4_i64]] : i64
4140
// CHECK-NEXT: %[[V7:.+]] = arith.index_cast %[[V6]] : i64 to index
4241
// CHECK-NEXT: %[[V8:.+]] = arith.divui %[[V7]], %[[c4]] : index
4342
// CHECK-NEXT: %[[i8:.+]] = memref.alloc(%[[V8]]) : memref<?xi32>
43+
// CHECK-NEXT: %[[i9:.+]] = llvm.mlir.addressof @str1 : !llvm.ptr<array<4 x i8>>
44+
// CHECK-NEXT: %[[i10:.+]] = llvm.getelementptr %[[i9]][0, 0] : (!llvm.ptr<array<4 x i8>>) -> !llvm.ptr<i8>
45+
// CHECK-NEXT: %[[V12:.+]] = "polygeist.memref2pointer"(%[[i8]]) : (memref<?xi32>) -> !llvm.ptr<i32>
4446
// CHECK-NEXT: %[[n:.+]] = arith.index_cast %[[V4]] : i32 to index
45-
// CHECK-NEXT: %[[i9:.+]] = llvm.mlir.addressof @str1 : !llvm.ptr<array<4 x i8>>
46-
// CHECK-NEXT: %[[i10:.+]] = llvm.getelementptr %[[i9]][0, 0] : (!llvm.ptr<array<4 x i8>>) -> !llvm.ptr<i8>
4747
// CHECK-NEXT: scf.for %[[arg0:.+]] = %[[c0]] to %[[n]] step %[[c1]] {
48-
// CHECK-NEXT: %[[i13:.+]] = llvm.call @__isoc99_scanf(%[[i10]], %[[V0]]) : (!llvm.ptr<i8>, !llvm.ptr<i32>) -> i32
49-
// CHECK-NEXT: %[[i12:.+]] = llvm.load %[[V0]] : !llvm.ptr<i32>
50-
// CHECK-NEXT: memref.store %[[i12]], %[[i8]][%[[arg0]]] : memref<?xi32>
48+
// CHECK-NEXT: %[[i14:.+]] = arith.index_cast %[[arg0]] : index to i64
49+
// CHECK-NEXT: %[[i15:.+]] = llvm.getelementptr %[[V12]][%[[i14]]] : (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
50+
// CHECK-NEXT: %[[i13:.+]] = llvm.call @__isoc99_scanf(%[[i10]], %[[i15]]) : (!llvm.ptr<i8>, !llvm.ptr<i32>) -> i32
5151
// CHECK-NEXT: }
5252
// CHECK-NEXT: return %[[i8]] : memref<?xi32>
5353
// CHECK-NEXT: }

0 commit comments

Comments
 (0)