@@ -310,25 +310,29 @@ class EquilibriumDB(RewriteDatabase):
310
310
"""
311
311
312
312
def __init__ (
313
- self , ignore_newtrees : bool = True , tracks_on_change_inputs : bool = False
313
+ self ,
314
+ ignore_newtrees : bool = True ,
315
+ tracks_on_change_inputs : bool = False ,
316
+ eq_rewriter_class = pytensor_rewriting .EquilibriumGraphRewriter ,
314
317
):
315
318
"""
316
319
317
320
Parameters
318
321
----------
319
322
ignore_newtrees
320
- If ``False``, apply rewrites to new nodes introduced during
321
- rewriting.
322
-
323
+ If ``False``, apply rewrites to new nodes introduced during rewritings.
323
324
tracks_on_change_inputs
324
325
If ``True``, re-apply rewrites on nodes with changed inputs.
326
+ eq_rewriter_class: EquilibriumGraphRewriter class, optional
327
+ The class used to create the equilibrium rewriter. Defaults to EquilibriumGraphRewriter.
325
328
326
329
"""
327
330
super ().__init__ ()
328
331
self .ignore_newtrees = ignore_newtrees
329
332
self .tracks_on_change_inputs = tracks_on_change_inputs
330
333
self .__final__ : dict [str , bool ] = {}
331
334
self .__cleanup__ : dict [str , bool ] = {}
335
+ self .eq_rewriter_class = eq_rewriter_class
332
336
333
337
def register (
334
338
self ,
@@ -360,7 +364,7 @@ def query(self, *tags, **kwtags):
360
364
final_rewriters = None
361
365
if len (cleanup_rewriters ) == 0 :
362
366
cleanup_rewriters = None
363
- return pytensor_rewriting . EquilibriumGraphRewriter (
367
+ return self . eq_rewriter_class (
364
368
rewriters ,
365
369
max_use_ratio = config .optdb__max_use_ratio ,
366
370
ignore_newtrees = self .ignore_newtrees ,
0 commit comments