From aefb9634e974822adb0a72fe2d98fd7b1abf453b Mon Sep 17 00:00:00 2001 From: Adv Date: Thu, 12 Dec 2024 16:52:03 +0530 Subject: [PATCH 1/2] Add 'overwrite_existing' flag to allow graph rewrites and include appropriate testing --- pytensor/graph/rewriting/db.py | 31 +++++++++++++-------- tests/graph/rewriting/test_db.py | 47 +++++++++++++++++++++++++++++++- 2 files changed, 65 insertions(+), 13 deletions(-) diff --git a/pytensor/graph/rewriting/db.py b/pytensor/graph/rewriting/db.py index f6cfac3a76..fb81622458 100644 --- a/pytensor/graph/rewriting/db.py +++ b/pytensor/graph/rewriting/db.py @@ -35,6 +35,7 @@ def register( rewriter: Union["RewriteDatabase", RewritesType], *tags: str, use_db_name_as_tag=True, + overwrite_existing=False, ): """Register a new rewriter to the database. @@ -56,7 +57,8 @@ def register( ``local_remove_all_assert``. Setting `use_db_name_as_tag` to ``False`` removes that behavior. This means that only the rewrite's name and/or its tags will enable it. - + overwrite_existing: + Overwrite the existing rewriter with a new one having the same name """ if not isinstance( rewriter, @@ -66,22 +68,27 @@ def register( ): raise TypeError(f"{rewriter} is not a valid rewrite type.") - if name in self.__db__: - raise ValueError(f"The tag '{name}' is already present in the database.") - if use_db_name_as_tag: if self.name is not None: tags = (*tags, self.name) rewriter.name = name - # This restriction is there because in many place we suppose that - # something in the RewriteDatabase is there only once. - if rewriter.name in self.__db__: - raise ValueError( - f"Tried to register {rewriter.name} again under the new name {name}. " - "The same rewrite cannot be registered multiple times in" - " an `RewriteDatabase`; use `ProxyDB` instead." - ) + + # if tag collides with name + if name in self.__db__ and name not in self._names: + raise ValueError(f"The tag '{name}' is already present in the database.") + + if name in self.__db__ or rewriter.name in self.__db__: + if overwrite_existing: + self.remove_tags(name, *tags) + old_rewriter = self.__db__[name].pop() + self._names.remove(name) + self.__db__[old_rewriter.__class__.__name__].remove(old_rewriter) + else: + raise ValueError( + f"The tag '{name}' is already present in the database." + ) + self.__db__[name] = OrderedSet([rewriter]) self._names.add(name) self.__db__[rewriter.__class__.__name__].add(rewriter) diff --git a/tests/graph/rewriting/test_db.py b/tests/graph/rewriting/test_db.py index ec790dbfe2..38e3c865c4 100644 --- a/tests/graph/rewriting/test_db.py +++ b/tests/graph/rewriting/test_db.py @@ -1,5 +1,6 @@ import pytest +from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import GraphRewriter, SequentialGraphRewriter from pytensor.graph.rewriting.db import ( EquilibriumDB, @@ -17,6 +18,30 @@ def apply(self, fgraph): pass +class NewTestRewriter(GraphRewriter): + name = "bleh" + + def apply(self, fgraph): + pass + + +counter1 = 0 + +counter2 = 0 + + +class TestOverwrite1(GraphRewriter): + def apply(self, fgraph): + global counter1 + counter1 += 1 + + +class TestOverwrite2(GraphRewriter): + def apply(self, fgraph): + global counter2 + counter2 += 1 + + class TestDB: def test_register(self): db = RewriteDatabase() @@ -31,7 +56,9 @@ def test_register(self): assert "c" in db with pytest.raises(ValueError, match=r"The tag.*"): - db.register("c", TestRewriter()) # name taken + db.register("c", NewTestRewriter()) # name taken + + db.register("c", NewTestRewriter(), overwrite_existing=True) with pytest.raises(ValueError, match=r"The tag.*"): db.register("z", TestRewriter()) # name collides with tag @@ -42,6 +69,24 @@ def test_register(self): with pytest.raises(TypeError, match=r".* is not a valid.*"): db.register("d", 1) + def test_overwrite(self): + db = RewriteDatabase() + fg = FunctionGraph([], []) + + db.register("a", TestRewriter()) + Rewriter = db.__getitem__("a") + Rewriter.rewrite(fg) + + db.register("a", TestOverwrite1(), overwrite_existing=True) + Rewriter = db.__getitem__("a") + Rewriter.rewrite(fg) + assert counter1 == 1 and counter2 == 0 + + db.register("a", TestOverwrite2(), overwrite_existing=True) + Rewriter = db.__getitem__("a") + Rewriter.rewrite(fg) + assert counter1 == 1 and counter2 == 1 + def test_EquilibriumDB(self): eq_db = EquilibriumDB() From 32b82be4d7e8ed51563e21158635550eea6deb4e Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 28 Jan 2025 16:41:03 +0100 Subject: [PATCH 2/2] Encapsulate test rewriters and use user-facing API --- tests/graph/rewriting/test_db.py | 65 +++++++++++++++----------------- 1 file changed, 31 insertions(+), 34 deletions(-) diff --git a/tests/graph/rewriting/test_db.py b/tests/graph/rewriting/test_db.py index 38e3c865c4..5d0c98a6b0 100644 --- a/tests/graph/rewriting/test_db.py +++ b/tests/graph/rewriting/test_db.py @@ -25,23 +25,6 @@ def apply(self, fgraph): pass -counter1 = 0 - -counter2 = 0 - - -class TestOverwrite1(GraphRewriter): - def apply(self, fgraph): - global counter1 - counter1 += 1 - - -class TestOverwrite2(GraphRewriter): - def apply(self, fgraph): - global counter2 - counter2 += 1 - - class TestDB: def test_register(self): db = RewriteDatabase() @@ -58,8 +41,6 @@ def test_register(self): with pytest.raises(ValueError, match=r"The tag.*"): db.register("c", NewTestRewriter()) # name taken - db.register("c", NewTestRewriter(), overwrite_existing=True) - with pytest.raises(ValueError, match=r"The tag.*"): db.register("z", TestRewriter()) # name collides with tag @@ -69,23 +50,39 @@ def test_register(self): with pytest.raises(TypeError, match=r".* is not a valid.*"): db.register("d", 1) - def test_overwrite(self): - db = RewriteDatabase() + def test_overwrite_existing(self): + class TestOverwrite1(GraphRewriter): + def apply(self, fgraph): + fgraph.counter[0] += 1 + + class TestOverwrite2(GraphRewriter): + def apply(self, fgraph): + fgraph.counter[1] += 1 + + db = SequenceDB() fg = FunctionGraph([], []) + fg.counter = [0, 0] - db.register("a", TestRewriter()) - Rewriter = db.__getitem__("a") - Rewriter.rewrite(fg) - - db.register("a", TestOverwrite1(), overwrite_existing=True) - Rewriter = db.__getitem__("a") - Rewriter.rewrite(fg) - assert counter1 == 1 and counter2 == 0 - - db.register("a", TestOverwrite2(), overwrite_existing=True) - Rewriter = db.__getitem__("a") - Rewriter.rewrite(fg) - assert counter1 == 1 and counter2 == 1 + db.register("a", TestRewriter(), "basic") + rewriter = db.query("+basic") + rewriter.rewrite(fg) + assert fg.counter == [0, 0] + + with pytest.raises(ValueError, match=r"The tag.*"): + db.register("a", TestOverwrite1(), "basic") + rewriter = db.query("+basic") + rewriter.rewrite(fg) + assert fg.counter == [0, 0] + + db.register("a", TestOverwrite1(), "basic", overwrite_existing=True) + rewriter = db.query("+basic") + rewriter.rewrite(fg) + assert fg.counter == [1, 0] + + db.register("a", TestOverwrite2(), "basic", overwrite_existing=True) + rewriter = db.query("+basic") + rewriter.rewrite(fg) + assert fg.counter == [1, 1] def test_EquilibriumDB(self): eq_db = EquilibriumDB()