Skip to content

Commit 32b82be

Browse files
committed
Encapsulate test rewriters and use user-facing API
1 parent aefb963 commit 32b82be

File tree

1 file changed

+31
-34
lines changed

1 file changed

+31
-34
lines changed

tests/graph/rewriting/test_db.py

Lines changed: 31 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -25,23 +25,6 @@ def apply(self, fgraph):
2525
pass
2626

2727

28-
counter1 = 0
29-
30-
counter2 = 0
31-
32-
33-
class TestOverwrite1(GraphRewriter):
34-
def apply(self, fgraph):
35-
global counter1
36-
counter1 += 1
37-
38-
39-
class TestOverwrite2(GraphRewriter):
40-
def apply(self, fgraph):
41-
global counter2
42-
counter2 += 1
43-
44-
4528
class TestDB:
4629
def test_register(self):
4730
db = RewriteDatabase()
@@ -58,8 +41,6 @@ def test_register(self):
5841
with pytest.raises(ValueError, match=r"The tag.*"):
5942
db.register("c", NewTestRewriter()) # name taken
6043

61-
db.register("c", NewTestRewriter(), overwrite_existing=True)
62-
6344
with pytest.raises(ValueError, match=r"The tag.*"):
6445
db.register("z", TestRewriter()) # name collides with tag
6546

@@ -69,23 +50,39 @@ def test_register(self):
6950
with pytest.raises(TypeError, match=r".* is not a valid.*"):
7051
db.register("d", 1)
7152

72-
def test_overwrite(self):
73-
db = RewriteDatabase()
53+
def test_overwrite_existing(self):
54+
class TestOverwrite1(GraphRewriter):
55+
def apply(self, fgraph):
56+
fgraph.counter[0] += 1
57+
58+
class TestOverwrite2(GraphRewriter):
59+
def apply(self, fgraph):
60+
fgraph.counter[1] += 1
61+
62+
db = SequenceDB()
7463
fg = FunctionGraph([], [])
64+
fg.counter = [0, 0]
7565

76-
db.register("a", TestRewriter())
77-
Rewriter = db.__getitem__("a")
78-
Rewriter.rewrite(fg)
79-
80-
db.register("a", TestOverwrite1(), overwrite_existing=True)
81-
Rewriter = db.__getitem__("a")
82-
Rewriter.rewrite(fg)
83-
assert counter1 == 1 and counter2 == 0
84-
85-
db.register("a", TestOverwrite2(), overwrite_existing=True)
86-
Rewriter = db.__getitem__("a")
87-
Rewriter.rewrite(fg)
88-
assert counter1 == 1 and counter2 == 1
66+
db.register("a", TestRewriter(), "basic")
67+
rewriter = db.query("+basic")
68+
rewriter.rewrite(fg)
69+
assert fg.counter == [0, 0]
70+
71+
with pytest.raises(ValueError, match=r"The tag.*"):
72+
db.register("a", TestOverwrite1(), "basic")
73+
rewriter = db.query("+basic")
74+
rewriter.rewrite(fg)
75+
assert fg.counter == [0, 0]
76+
77+
db.register("a", TestOverwrite1(), "basic", overwrite_existing=True)
78+
rewriter = db.query("+basic")
79+
rewriter.rewrite(fg)
80+
assert fg.counter == [1, 0]
81+
82+
db.register("a", TestOverwrite2(), "basic", overwrite_existing=True)
83+
rewriter = db.query("+basic")
84+
rewriter.rewrite(fg)
85+
assert fg.counter == [1, 1]
8986

9087
def test_EquilibriumDB(self):
9188
eq_db = EquilibriumDB()

0 commit comments

Comments
 (0)