Skip to content

Commit 644cccc

Browse files
committed
add unit tests
1 parent 14d1b19 commit 644cccc

File tree

5 files changed

+347
-0
lines changed

5 files changed

+347
-0
lines changed

mlir/unittests/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,6 @@ add_subdirectory(Polynomial)
1616
add_subdirectory(SCF)
1717
add_subdirectory(SparseTensor)
1818
add_subdirectory(SPIRV)
19+
add_subdirectory(SMT)
1920
add_subdirectory(Transform)
2021
add_subdirectory(Utils)
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
//===- AttributeTest.cpp - SMT attribute unit tests -----------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/SMT/IR/SMTAttributes.h"
10+
#include "mlir/Dialect/SMT/IR/SMTDialect.h"
11+
#include "mlir/Dialect/SMT/IR/SMTTypes.h"
12+
#include "gtest/gtest.h"
13+
14+
using namespace mlir;
15+
using namespace smt;
16+
17+
namespace {
18+
19+
TEST(BitVectorAttrTest, MinBitWidth) {
20+
MLIRContext context;
21+
context.loadDialect<SMTDialect>();
22+
Location loc(UnknownLoc::get(&context));
23+
24+
auto attr = BitVectorAttr::getChecked(loc, &context, UINT64_C(0), 0U);
25+
ASSERT_EQ(attr, BitVectorAttr());
26+
context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
27+
ASSERT_EQ(diag.str(), "bit-width must be at least 1, but got 0");
28+
});
29+
}
30+
31+
TEST(BitVectorAttrTest, ParserAndPrinterCorrect) {
32+
MLIRContext context;
33+
context.loadDialect<SMTDialect>();
34+
35+
auto attr = BitVectorAttr::get(&context, "#b1010");
36+
ASSERT_EQ(attr.getValue(), APInt(4, 10));
37+
ASSERT_EQ(attr.getType(), BitVectorType::get(&context, 4));
38+
39+
// A bit-width divisible by 4 is always printed in hex
40+
attr = BitVectorAttr::get(&context, "#b01011010");
41+
ASSERT_EQ(attr.getValueAsString(), "#x5a");
42+
43+
// A bit-width not divisible by 4 is always printed in binary
44+
// Also, make sure leading zeros are printed
45+
attr = BitVectorAttr::get(&context, "#b0101101");
46+
ASSERT_EQ(attr.getValueAsString(), "#b0101101");
47+
48+
attr = BitVectorAttr::get(&context, "#x3c");
49+
ASSERT_EQ(attr.getValueAsString(), "#x3c");
50+
51+
attr = BitVectorAttr::get(&context, "#x03c");
52+
ASSERT_EQ(attr.getValueAsString(), "#x03c");
53+
}
54+
55+
TEST(BitVectorAttrTest, ExpectedOneDigit) {
56+
MLIRContext context;
57+
context.loadDialect<SMTDialect>();
58+
Location loc(UnknownLoc::get(&context));
59+
60+
auto attr =
61+
BitVectorAttr::getChecked(loc, &context, static_cast<StringRef>("#b"));
62+
ASSERT_EQ(attr, BitVectorAttr());
63+
context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
64+
ASSERT_EQ(diag.str(), "expected at least one digit");
65+
});
66+
}
67+
68+
TEST(BitVectorAttrTest, ExpectedBOrX) {
69+
MLIRContext context;
70+
context.loadDialect<SMTDialect>();
71+
Location loc(UnknownLoc::get(&context));
72+
73+
auto attr =
74+
BitVectorAttr::getChecked(loc, &context, static_cast<StringRef>("#c0"));
75+
ASSERT_EQ(attr, BitVectorAttr());
76+
context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
77+
ASSERT_EQ(diag.str(), "expected either 'b' or 'x'");
78+
});
79+
}
80+
81+
TEST(BitVectorAttrTest, ExpectedHashtag) {
82+
MLIRContext context;
83+
context.loadDialect<SMTDialect>();
84+
Location loc(UnknownLoc::get(&context));
85+
86+
auto attr =
87+
BitVectorAttr::getChecked(loc, &context, static_cast<StringRef>("b0"));
88+
ASSERT_EQ(attr, BitVectorAttr());
89+
context.getDiagEngine().registerHandler(
90+
[&](Diagnostic &diag) { ASSERT_EQ(diag.str(), "expected '#'"); });
91+
}
92+
93+
TEST(BitVectorAttrTest, OutOfRange) {
94+
MLIRContext context;
95+
context.loadDialect<SMTDialect>();
96+
Location loc(UnknownLoc::get(&context));
97+
98+
auto attr1 = BitVectorAttr::getChecked(loc, &context, UINT64_C(2), 1U);
99+
auto attr63 =
100+
BitVectorAttr::getChecked(loc, &context, UINT64_C(3) << 62, 63U);
101+
ASSERT_EQ(attr1, BitVectorAttr());
102+
ASSERT_EQ(attr63, BitVectorAttr());
103+
context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
104+
ASSERT_EQ(diag.str(),
105+
"value does not fit in a bit-vector of desired width");
106+
});
107+
}
108+
109+
TEST(BitVectorAttrTest, GetUInt64Max) {
110+
MLIRContext context;
111+
context.loadDialect<SMTDialect>();
112+
auto attr64 = BitVectorAttr::get(&context, UINT64_MAX, 64);
113+
auto attr65 = BitVectorAttr::get(&context, UINT64_MAX, 65);
114+
ASSERT_EQ(attr64.getValue(), APInt::getAllOnes(64));
115+
ASSERT_EQ(attr65.getValue(), APInt::getAllOnes(64).zext(65));
116+
}
117+
118+
} // namespace
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
add_mlir_unittest(MLIRSMTTests
2+
AttributeTest.cpp
3+
QuantifierTest.cpp
4+
TypeTest.cpp
5+
)
6+
7+
mlir_target_link_libraries(MLIRSMTTests
8+
PRIVATE
9+
MLIRSMT
10+
)
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
//===- QuantifierTest.cpp - SMT quantifier operation unit tests -----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/SMT/IR/SMTOps.h"
10+
#include "gtest/gtest.h"
11+
12+
using namespace mlir;
13+
using namespace smt;
14+
15+
namespace {
16+
17+
//===----------------------------------------------------------------------===//
18+
// Test custom builders of ExistsOp
19+
//===----------------------------------------------------------------------===//
20+
21+
TEST(QuantifierTest, ExistsBuilderWithPattern) {
22+
MLIRContext context;
23+
context.loadDialect<SMTDialect>();
24+
Location loc(UnknownLoc::get(&context));
25+
26+
OpBuilder builder(&context);
27+
auto boolTy = BoolType::get(&context);
28+
29+
ExistsOp existsOp = builder.create<ExistsOp>(
30+
loc, TypeRange{boolTy, boolTy},
31+
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
32+
return builder.create<AndOp>(loc, boundVars);
33+
},
34+
std::nullopt,
35+
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
36+
return boundVars;
37+
},
38+
/*weight=*/2);
39+
40+
SmallVector<char, 1024> buffer;
41+
llvm::raw_svector_ostream stream(buffer);
42+
existsOp.print(stream);
43+
44+
ASSERT_STREQ(
45+
stream.str().str().c_str(),
46+
"%0 = smt.exists weight 2 {\n^bb0(%arg0: !smt.bool, "
47+
"%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield %0 : "
48+
"!smt.bool\n} patterns {\n^bb0(%arg0: !smt.bool, %arg1: !smt.bool):\n "
49+
"smt.yield %arg0, %arg1 : !smt.bool, !smt.bool\n}\n");
50+
}
51+
52+
TEST(QuantifierTest, ExistsBuilderNoPattern) {
53+
MLIRContext context;
54+
context.loadDialect<SMTDialect>();
55+
Location loc(UnknownLoc::get(&context));
56+
57+
OpBuilder builder(&context);
58+
auto boolTy = BoolType::get(&context);
59+
60+
ExistsOp existsOp = builder.create<ExistsOp>(
61+
loc, TypeRange{boolTy, boolTy},
62+
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
63+
return builder.create<AndOp>(loc, boundVars);
64+
},
65+
ArrayRef<StringRef>{"a", "b"}, nullptr, /*weight=*/0, /*noPattern=*/true);
66+
67+
SmallVector<char, 1024> buffer;
68+
llvm::raw_svector_ostream stream(buffer);
69+
existsOp.print(stream);
70+
71+
ASSERT_STREQ(stream.str().str().c_str(),
72+
"%0 = smt.exists [\"a\", \"b\"] no_pattern {\n^bb0(%arg0: "
73+
"!smt.bool, %arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n "
74+
"smt.yield %0 : !smt.bool\n}\n");
75+
}
76+
77+
TEST(QuantifierTest, ExistsBuilderDefault) {
78+
MLIRContext context;
79+
context.loadDialect<SMTDialect>();
80+
Location loc(UnknownLoc::get(&context));
81+
82+
OpBuilder builder(&context);
83+
auto boolTy = BoolType::get(&context);
84+
85+
ExistsOp existsOp = builder.create<ExistsOp>(
86+
loc, TypeRange{boolTy, boolTy},
87+
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
88+
return builder.create<AndOp>(loc, boundVars);
89+
},
90+
ArrayRef<StringRef>{"a", "b"});
91+
92+
SmallVector<char, 1024> buffer;
93+
llvm::raw_svector_ostream stream(buffer);
94+
existsOp.print(stream);
95+
96+
ASSERT_STREQ(stream.str().str().c_str(),
97+
"%0 = smt.exists [\"a\", \"b\"] {\n^bb0(%arg0: !smt.bool, "
98+
"%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield "
99+
"%0 : !smt.bool\n}\n");
100+
}
101+
102+
//===----------------------------------------------------------------------===//
103+
// Test custom builders of ForallOp
104+
//===----------------------------------------------------------------------===//
105+
106+
TEST(QuantifierTest, ForallBuilderWithPattern) {
107+
MLIRContext context;
108+
context.loadDialect<SMTDialect>();
109+
Location loc(UnknownLoc::get(&context));
110+
111+
OpBuilder builder(&context);
112+
auto boolTy = BoolType::get(&context);
113+
114+
ForallOp forallOp = builder.create<ForallOp>(
115+
loc, TypeRange{boolTy, boolTy},
116+
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
117+
return builder.create<AndOp>(loc, boundVars);
118+
},
119+
ArrayRef<StringRef>{"a", "b"},
120+
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
121+
return boundVars;
122+
},
123+
/*weight=*/2);
124+
125+
SmallVector<char, 1024> buffer;
126+
llvm::raw_svector_ostream stream(buffer);
127+
forallOp.print(stream);
128+
129+
ASSERT_STREQ(
130+
stream.str().str().c_str(),
131+
"%0 = smt.forall [\"a\", \"b\"] weight 2 {\n^bb0(%arg0: !smt.bool, "
132+
"%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield %0 : "
133+
"!smt.bool\n} patterns {\n^bb0(%arg0: !smt.bool, %arg1: !smt.bool):\n "
134+
"smt.yield %arg0, %arg1 : !smt.bool, !smt.bool\n}\n");
135+
}
136+
137+
TEST(QuantifierTest, ForallBuilderNoPattern) {
138+
MLIRContext context;
139+
context.loadDialect<SMTDialect>();
140+
Location loc(UnknownLoc::get(&context));
141+
142+
OpBuilder builder(&context);
143+
auto boolTy = BoolType::get(&context);
144+
145+
ForallOp forallOp = builder.create<ForallOp>(
146+
loc, TypeRange{boolTy, boolTy},
147+
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
148+
return builder.create<AndOp>(loc, boundVars);
149+
},
150+
ArrayRef<StringRef>{"a", "b"}, nullptr, /*weight=*/0, /*noPattern=*/true);
151+
152+
SmallVector<char, 1024> buffer;
153+
llvm::raw_svector_ostream stream(buffer);
154+
forallOp.print(stream);
155+
156+
ASSERT_STREQ(stream.str().str().c_str(),
157+
"%0 = smt.forall [\"a\", \"b\"] no_pattern {\n^bb0(%arg0: "
158+
"!smt.bool, %arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n "
159+
"smt.yield %0 : !smt.bool\n}\n");
160+
}
161+
162+
TEST(QuantifierTest, ForallBuilderDefault) {
163+
MLIRContext context;
164+
context.loadDialect<SMTDialect>();
165+
Location loc(UnknownLoc::get(&context));
166+
167+
OpBuilder builder(&context);
168+
auto boolTy = BoolType::get(&context);
169+
170+
ForallOp forallOp = builder.create<ForallOp>(
171+
loc, TypeRange{boolTy, boolTy},
172+
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
173+
return builder.create<AndOp>(loc, boundVars);
174+
},
175+
std::nullopt);
176+
177+
SmallVector<char, 1024> buffer;
178+
llvm::raw_svector_ostream stream(buffer);
179+
forallOp.print(stream);
180+
181+
ASSERT_STREQ(stream.str().str().c_str(),
182+
"%0 = smt.forall {\n^bb0(%arg0: !smt.bool, "
183+
"%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield "
184+
"%0 : !smt.bool\n}\n");
185+
}
186+
187+
} // namespace
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//===- TypeTest.cpp - SMT type unit tests ---------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/SMT/IR/SMTDialect.h"
10+
#include "mlir/Dialect/SMT/IR/SMTTypes.h"
11+
#include "gtest/gtest.h"
12+
13+
using namespace mlir;
14+
using namespace smt;
15+
16+
namespace {
17+
18+
TEST(SMTFuncTypeTest, NonEmptyDomain) {
19+
MLIRContext context;
20+
context.loadDialect<SMTDialect>();
21+
Location loc(UnknownLoc::get(&context));
22+
23+
auto boolTy = BoolType::get(&context);
24+
auto funcTy = SMTFuncType::getChecked(loc, ArrayRef<Type>{}, boolTy);
25+
ASSERT_EQ(funcTy, Type());
26+
context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
27+
ASSERT_STREQ(diag.str().c_str(), "domain must not be empty");
28+
});
29+
}
30+
31+
} // namespace

0 commit comments

Comments
 (0)