@@ -25,23 +25,6 @@ def apply(self, fgraph):
2525 pass
2626
2727
28- counter1 = 0
29-
30- counter2 = 0
31-
32-
33- class TestOverwrite1 (GraphRewriter ):
34- def apply (self , fgraph ):
35- global counter1
36- counter1 += 1
37-
38-
39- class TestOverwrite2 (GraphRewriter ):
40- def apply (self , fgraph ):
41- global counter2
42- counter2 += 1
43-
44-
4528class TestDB :
4629 def test_register (self ):
4730 db = RewriteDatabase ()
@@ -58,8 +41,6 @@ def test_register(self):
5841 with pytest .raises (ValueError , match = r"The tag.*" ):
5942 db .register ("c" , NewTestRewriter ()) # name taken
6043
61- db .register ("c" , NewTestRewriter (), overwrite_existing = True )
62-
6344 with pytest .raises (ValueError , match = r"The tag.*" ):
6445 db .register ("z" , TestRewriter ()) # name collides with tag
6546
@@ -69,23 +50,39 @@ def test_register(self):
6950 with pytest .raises (TypeError , match = r".* is not a valid.*" ):
7051 db .register ("d" , 1 )
7152
72- def test_overwrite (self ):
73- db = RewriteDatabase ()
53+ def test_overwrite_existing (self ):
54+ class TestOverwrite1 (GraphRewriter ):
55+ def apply (self , fgraph ):
56+ fgraph .counter [0 ] += 1
57+
58+ class TestOverwrite2 (GraphRewriter ):
59+ def apply (self , fgraph ):
60+ fgraph .counter [1 ] += 1
61+
62+ db = SequenceDB ()
7463 fg = FunctionGraph ([], [])
64+ fg .counter = [0 , 0 ]
7565
76- db .register ("a" , TestRewriter ())
77- Rewriter = db .__getitem__ ("a" )
78- Rewriter .rewrite (fg )
79-
80- db .register ("a" , TestOverwrite1 (), overwrite_existing = True )
81- Rewriter = db .__getitem__ ("a" )
82- Rewriter .rewrite (fg )
83- assert counter1 == 1 and counter2 == 0
84-
85- db .register ("a" , TestOverwrite2 (), overwrite_existing = True )
86- Rewriter = db .__getitem__ ("a" )
87- Rewriter .rewrite (fg )
88- assert counter1 == 1 and counter2 == 1
66+ db .register ("a" , TestRewriter (), "basic" )
67+ rewriter = db .query ("+basic" )
68+ rewriter .rewrite (fg )
69+ assert fg .counter == [0 , 0 ]
70+
71+ with pytest .raises (ValueError , match = r"The tag.*" ):
72+ db .register ("a" , TestOverwrite1 (), "basic" )
73+ rewriter = db .query ("+basic" )
74+ rewriter .rewrite (fg )
75+ assert fg .counter == [0 , 0 ]
76+
77+ db .register ("a" , TestOverwrite1 (), "basic" , overwrite_existing = True )
78+ rewriter = db .query ("+basic" )
79+ rewriter .rewrite (fg )
80+ assert fg .counter == [1 , 0 ]
81+
82+ db .register ("a" , TestOverwrite2 (), "basic" , overwrite_existing = True )
83+ rewriter = db .query ("+basic" )
84+ rewriter .rewrite (fg )
85+ assert fg .counter == [1 , 1 ]
8986
9087 def test_EquilibriumDB (self ):
9188 eq_db = EquilibriumDB ()
0 commit comments