diff --git a/pytensor/graph/rewriting/db.py b/pytensor/graph/rewriting/db.py index f6cfac3a76..fb81622458 100644 --- a/pytensor/graph/rewriting/db.py +++ b/pytensor/graph/rewriting/db.py @@ -35,6 +35,7 @@ def register( rewriter: Union["RewriteDatabase", RewritesType], *tags: str, use_db_name_as_tag=True, + overwrite_existing=False, ): """Register a new rewriter to the database. @@ -56,7 +57,8 @@ def register( ``local_remove_all_assert``. Setting `use_db_name_as_tag` to ``False`` removes that behavior. This means that only the rewrite's name and/or its tags will enable it. - + overwrite_existing: + Overwrite the existing rewriter with a new one having the same name """ if not isinstance( rewriter, @@ -66,22 +68,27 @@ def register( ): raise TypeError(f"{rewriter} is not a valid rewrite type.") - if name in self.__db__: - raise ValueError(f"The tag '{name}' is already present in the database.") - if use_db_name_as_tag: if self.name is not None: tags = (*tags, self.name) rewriter.name = name - # This restriction is there because in many place we suppose that - # something in the RewriteDatabase is there only once. - if rewriter.name in self.__db__: - raise ValueError( - f"Tried to register {rewriter.name} again under the new name {name}. " - "The same rewrite cannot be registered multiple times in" - " an `RewriteDatabase`; use `ProxyDB` instead." - ) + + # if tag collides with name + if name in self.__db__ and name not in self._names: + raise ValueError(f"The tag '{name}' is already present in the database.") + + if name in self.__db__ or rewriter.name in self.__db__: + if overwrite_existing: + self.remove_tags(name, *tags) + old_rewriter = self.__db__[name].pop() + self._names.remove(name) + self.__db__[old_rewriter.__class__.__name__].remove(old_rewriter) + else: + raise ValueError( + f"The tag '{name}' is already present in the database." + ) + self.__db__[name] = OrderedSet([rewriter]) self._names.add(name) self.__db__[rewriter.__class__.__name__].add(rewriter) diff --git a/tests/graph/rewriting/test_db.py b/tests/graph/rewriting/test_db.py index ec790dbfe2..5d0c98a6b0 100644 --- a/tests/graph/rewriting/test_db.py +++ b/tests/graph/rewriting/test_db.py @@ -1,5 +1,6 @@ import pytest +from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import GraphRewriter, SequentialGraphRewriter from pytensor.graph.rewriting.db import ( EquilibriumDB, @@ -17,6 +18,13 @@ def apply(self, fgraph): pass +class NewTestRewriter(GraphRewriter): + name = "bleh" + + def apply(self, fgraph): + pass + + class TestDB: def test_register(self): db = RewriteDatabase() @@ -31,7 +39,7 @@ def test_register(self): assert "c" in db with pytest.raises(ValueError, match=r"The tag.*"): - db.register("c", TestRewriter()) # name taken + db.register("c", NewTestRewriter()) # name taken with pytest.raises(ValueError, match=r"The tag.*"): db.register("z", TestRewriter()) # name collides with tag @@ -42,6 +50,40 @@ def test_register(self): with pytest.raises(TypeError, match=r".* is not a valid.*"): db.register("d", 1) + def test_overwrite_existing(self): + class TestOverwrite1(GraphRewriter): + def apply(self, fgraph): + fgraph.counter[0] += 1 + + class TestOverwrite2(GraphRewriter): + def apply(self, fgraph): + fgraph.counter[1] += 1 + + db = SequenceDB() + fg = FunctionGraph([], []) + fg.counter = [0, 0] + + db.register("a", TestRewriter(), "basic") + rewriter = db.query("+basic") + rewriter.rewrite(fg) + assert fg.counter == [0, 0] + + with pytest.raises(ValueError, match=r"The tag.*"): + db.register("a", TestOverwrite1(), "basic") + rewriter = db.query("+basic") + rewriter.rewrite(fg) + assert fg.counter == [0, 0] + + db.register("a", TestOverwrite1(), "basic", overwrite_existing=True) + rewriter = db.query("+basic") + rewriter.rewrite(fg) + assert fg.counter == [1, 0] + + db.register("a", TestOverwrite2(), "basic", overwrite_existing=True) + rewriter = db.query("+basic") + rewriter.rewrite(fg) + assert fg.counter == [1, 1] + def test_EquilibriumDB(self): eq_db = EquilibriumDB()