Skip to content

Commit 43a1b45

Browse files
committed
Revert "[mlir][SMT] remove custom forall/exists builder because of asan memory leak"
This reverts commit 54e70ac.
1 parent 336b290 commit 43a1b45

File tree

4 files changed

+232
-0
lines changed

4 files changed

+232
-0
lines changed

mlir/include/mlir/Dialect/SMT/IR/SMTOps.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,18 @@ class QuantifierOp<string mnemonic> : SMTOp<mnemonic, [
448448
VariadicRegion<SizedRegion<1>>:$patterns);
449449
let results = (outs BoolType:$result);
450450

451+
let builders = [
452+
OpBuilder<(ins
453+
"TypeRange":$boundVarTypes,
454+
"function_ref<Value(OpBuilder &, Location, ValueRange)>":$bodyBuilder,
455+
CArg<"std::optional<ArrayRef<StringRef>>", "std::nullopt">:$boundVarNames,
456+
CArg<"function_ref<ValueRange(OpBuilder &, Location, ValueRange)>",
457+
"{}">:$patternBuilder,
458+
CArg<"uint32_t", "0">:$weight,
459+
CArg<"bool", "false">:$noPattern)>
460+
];
461+
let skipDefaultBuilders = true;
462+
451463
let assemblyFormat = [{
452464
($boundVarNames^)? (`no_pattern` $noPattern^)? (`weight` $weight^)?
453465
attr-dict-with-keyword $body (`patterns` $patterns^)?

mlir/lib/Dialect/SMT/IR/SMTOps.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,16 @@ LogicalResult ForallOp::verifyRegions() {
432432
return verifyQuantifierRegions(*this);
433433
}
434434

435+
void ForallOp::build(
436+
OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
437+
function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
438+
std::optional<ArrayRef<StringRef>> boundVarNames,
439+
function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
440+
uint32_t weight, bool noPattern) {
441+
buildQuantifier<Properties>(odsBuilder, odsState, boundVarTypes, bodyBuilder,
442+
boundVarNames, patternBuilder, weight, noPattern);
443+
}
444+
435445
//===----------------------------------------------------------------------===//
436446
// ExistsOp
437447
//===----------------------------------------------------------------------===//
@@ -448,5 +458,15 @@ LogicalResult ExistsOp::verifyRegions() {
448458
return verifyQuantifierRegions(*this);
449459
}
450460

461+
void ExistsOp::build(
462+
OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
463+
function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
464+
std::optional<ArrayRef<StringRef>> boundVarNames,
465+
function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
466+
uint32_t weight, bool noPattern) {
467+
buildQuantifier<Properties>(odsBuilder, odsState, boundVarTypes, bodyBuilder,
468+
boundVarNames, patternBuilder, weight, noPattern);
469+
}
470+
451471
#define GET_OP_CLASSES
452472
#include "mlir/Dialect/SMT/IR/SMT.cpp.inc"

mlir/unittests/Dialect/SMT/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_mlir_unittest(MLIRSMTTests
22
AttributeTest.cpp
3+
QuantifierTest.cpp
34
TypeTest.cpp
45
)
56

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
//===- QuantifierTest.cpp - SMT quantifier operation unit tests -----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/SMT/IR/SMTOps.h"
10+
#include "gtest/gtest.h"
11+
12+
using namespace mlir;
13+
using namespace smt;
14+
15+
namespace {
16+
17+
//===----------------------------------------------------------------------===//
18+
// Test custom builders of ExistsOp
19+
//===----------------------------------------------------------------------===//
20+
21+
TEST(QuantifierTest, ExistsBuilderWithPattern) {
22+
MLIRContext context;
23+
context.loadDialect<SMTDialect>();
24+
Location loc(UnknownLoc::get(&context));
25+
26+
OpBuilder builder(&context);
27+
auto boolTy = BoolType::get(&context);
28+
29+
ExistsOp existsOp = builder.create<ExistsOp>(
30+
loc, TypeRange{boolTy, boolTy},
31+
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
32+
return builder.create<AndOp>(loc, boundVars);
33+
},
34+
std::nullopt,
35+
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
36+
return boundVars;
37+
},
38+
/*weight=*/2);
39+
40+
SmallVector<char, 1024> buffer;
41+
llvm::raw_svector_ostream stream(buffer);
42+
existsOp.print(stream);
43+
44+
ASSERT_STREQ(
45+
stream.str().str().c_str(),
46+
"%0 = smt.exists weight 2 {\n^bb0(%arg0: !smt.bool, "
47+
"%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield %0 : "
48+
"!smt.bool\n} patterns {\n^bb0(%arg0: !smt.bool, %arg1: !smt.bool):\n "
49+
"smt.yield %arg0, %arg1 : !smt.bool, !smt.bool\n}\n");
50+
51+
existsOp->destroy();
52+
}
53+
54+
TEST(QuantifierTest, ExistsBuilderNoPattern) {
55+
MLIRContext context;
56+
context.loadDialect<SMTDialect>();
57+
Location loc(UnknownLoc::get(&context));
58+
59+
OpBuilder builder(&context);
60+
auto boolTy = BoolType::get(&context);
61+
62+
ExistsOp existsOp = builder.create<ExistsOp>(
63+
loc, TypeRange{boolTy, boolTy},
64+
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
65+
return builder.create<AndOp>(loc, boundVars);
66+
},
67+
ArrayRef<StringRef>{"a", "b"}, nullptr, /*weight=*/0, /*noPattern=*/true);
68+
69+
SmallVector<char, 1024> buffer;
70+
llvm::raw_svector_ostream stream(buffer);
71+
existsOp.print(stream);
72+
73+
ASSERT_STREQ(stream.str().str().c_str(),
74+
"%0 = smt.exists [\"a\", \"b\"] no_pattern {\n^bb0(%arg0: "
75+
"!smt.bool, %arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n "
76+
"smt.yield %0 : !smt.bool\n}\n");
77+
78+
existsOp->destroy();
79+
}
80+
81+
TEST(QuantifierTest, ExistsBuilderDefault) {
82+
MLIRContext context;
83+
context.loadDialect<SMTDialect>();
84+
Location loc(UnknownLoc::get(&context));
85+
86+
OpBuilder builder(&context);
87+
auto boolTy = BoolType::get(&context);
88+
89+
ExistsOp existsOp = builder.create<ExistsOp>(
90+
loc, TypeRange{boolTy, boolTy},
91+
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
92+
return builder.create<AndOp>(loc, boundVars);
93+
},
94+
ArrayRef<StringRef>{"a", "b"});
95+
96+
SmallVector<char, 1024> buffer;
97+
llvm::raw_svector_ostream stream(buffer);
98+
existsOp.print(stream);
99+
100+
ASSERT_STREQ(stream.str().str().c_str(),
101+
"%0 = smt.exists [\"a\", \"b\"] {\n^bb0(%arg0: !smt.bool, "
102+
"%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield "
103+
"%0 : !smt.bool\n}\n");
104+
105+
existsOp->destroy();
106+
}
107+
108+
//===----------------------------------------------------------------------===//
109+
// Test custom builders of ForallOp
110+
//===----------------------------------------------------------------------===//
111+
112+
TEST(QuantifierTest, ForallBuilderWithPattern) {
113+
MLIRContext context;
114+
context.loadDialect<SMTDialect>();
115+
Location loc(UnknownLoc::get(&context));
116+
117+
OpBuilder builder(&context);
118+
auto boolTy = BoolType::get(&context);
119+
120+
ForallOp forallOp = builder.create<ForallOp>(
121+
loc, TypeRange{boolTy, boolTy},
122+
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
123+
return builder.create<AndOp>(loc, boundVars);
124+
},
125+
ArrayRef<StringRef>{"a", "b"},
126+
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
127+
return boundVars;
128+
},
129+
/*weight=*/2);
130+
131+
SmallVector<char, 1024> buffer;
132+
llvm::raw_svector_ostream stream(buffer);
133+
forallOp.print(stream);
134+
135+
ASSERT_STREQ(
136+
stream.str().str().c_str(),
137+
"%0 = smt.forall [\"a\", \"b\"] weight 2 {\n^bb0(%arg0: !smt.bool, "
138+
"%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield %0 : "
139+
"!smt.bool\n} patterns {\n^bb0(%arg0: !smt.bool, %arg1: !smt.bool):\n "
140+
"smt.yield %arg0, %arg1 : !smt.bool, !smt.bool\n}\n");
141+
142+
forallOp->destroy();
143+
}
144+
145+
TEST(QuantifierTest, ForallBuilderNoPattern) {
146+
MLIRContext context;
147+
context.loadDialect<SMTDialect>();
148+
Location loc(UnknownLoc::get(&context));
149+
150+
OpBuilder builder(&context);
151+
auto boolTy = BoolType::get(&context);
152+
153+
ForallOp forallOp = builder.create<ForallOp>(
154+
loc, TypeRange{boolTy, boolTy},
155+
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
156+
return builder.create<AndOp>(loc, boundVars);
157+
},
158+
ArrayRef<StringRef>{"a", "b"}, nullptr, /*weight=*/0, /*noPattern=*/true);
159+
160+
SmallVector<char, 1024> buffer;
161+
llvm::raw_svector_ostream stream(buffer);
162+
forallOp.print(stream);
163+
164+
ASSERT_STREQ(stream.str().str().c_str(),
165+
"%0 = smt.forall [\"a\", \"b\"] no_pattern {\n^bb0(%arg0: "
166+
"!smt.bool, %arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n "
167+
"smt.yield %0 : !smt.bool\n}\n");
168+
169+
forallOp->destroy();
170+
}
171+
172+
TEST(QuantifierTest, ForallBuilderDefault) {
173+
MLIRContext context;
174+
context.loadDialect<SMTDialect>();
175+
Location loc(UnknownLoc::get(&context));
176+
177+
OpBuilder builder(&context);
178+
auto boolTy = BoolType::get(&context);
179+
180+
ForallOp forallOp = builder.create<ForallOp>(
181+
loc, TypeRange{boolTy, boolTy},
182+
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
183+
return builder.create<AndOp>(loc, boundVars);
184+
},
185+
std::nullopt);
186+
187+
SmallVector<char, 1024> buffer;
188+
llvm::raw_svector_ostream stream(buffer);
189+
forallOp.print(stream);
190+
191+
ASSERT_STREQ(stream.str().str().c_str(),
192+
"%0 = smt.forall {\n^bb0(%arg0: !smt.bool, "
193+
"%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield "
194+
"%0 : !smt.bool\n}\n");
195+
196+
forallOp->destroy();
197+
}
198+
199+
} // namespace

0 commit comments

Comments
 (0)