1414from itertools import chain
1515from itertools import product as itertools_product
1616from logging import Logger
17- from typing import Optional
17+ from typing import TYPE_CHECKING , Optional , Union
1818from warnings import warn
1919
2020import numpy as np
21+ from typing_extensions import Literal
2122
2223import aesara
2324from aesara .compile .function .types import (
4243from aesara .utils import NoDuplicateOptWarningFilter , difference , get_unbound_function
4344
4445
45- __docformat__ = "restructuredtext en"
46+ if TYPE_CHECKING :
47+ from aesara .graph .basic import Apply
48+
4649_logger : Logger = logging .getLogger ("aesara.compile.debugmode" )
4750_logger .addFilter (NoDuplicateOptWarningFilter ())
4851
@@ -1109,43 +1112,32 @@ class _FunctionGraphEvent:
11091112
11101113 """
11111114
1112- kind = ""
1113- """
1114- One of 'import', 'change', 'prune'.
1115-
1116- """
1117-
1118- node = None
1119- """
1120- Either 'output' or an Apply instance.
1121-
1122- """
1123-
1124- op = None
1125- """Either 'output' or an Op instance"""
1115+ kind : Literal ["import" , "change" , "prune" ]
1116+ old_node : Optional [Union [Literal ["output" ], "Apply" ]]
1117+ new_node : Optional [Union [Literal ["output" ], "Apply" ]]
1118+ op : Optional [Union [Literal ["output" ], Op ]]
1119+ idx : Optional [int ]
1120+ reason : Optional [str ]
11261121
1127- idx = None
1128- """
1129- Change events involve an position index of the input variable.
1130-
1131- """
1132-
1133- reason = None
1134- """
1135- Change events sometimes have a reason.
1136-
1137- """
1138-
1139- def __init__ (self , kind , node , idx = None , reason = None ):
1122+ def __init__ (
1123+ self ,
1124+ kind : Literal ["import" , "change" , "prune" ],
1125+ old_node : Union [Literal ["output" ], "Apply" ],
1126+ new_node : Union [Literal ["output" ], "Apply" ] = None ,
1127+ idx : Optional [int ] = None ,
1128+ reason : Optional [str ] = None ,
1129+ ):
11401130 self .kind = kind
1141- if node == "output" :
1142- self .node = "output"
1131+ if old_node == "output" :
1132+ self .old_node = "output"
1133+ self .new_node = "output"
11431134 self .op = "output"
11441135 else :
1145- self .node = node
1146- self .op = node .op
1136+ self .old_node = old_node
1137+ self .new_node = new_node
1138+ self .op = old_node .op
11471139 self .idx = idx
1148- self .reason = str (reason )
1140+ self .reason = str (reason ) if reason else None
11491141
11501142 def __str__ (self ):
11511143 if self .kind == "change" :
@@ -1219,21 +1211,21 @@ def on_attach(self, fgraph):
12191211 self .replaced_by = {}
12201212 self .event_list = []
12211213 for node in fgraph .toposort ():
1222- self .on_import (fgraph , node , "on_attach" )
1214+ self .on_import (fgraph , node , reason = "on_attach" )
12231215
12241216 def on_detach (self , fgraph ):
12251217 assert fgraph is self .fgraph
12261218 self .fgraph = None
12271219
12281220 def on_prune (self , fgraph , node , reason ):
1229- self .event_list .append (_FunctionGraphEvent ("prune" , node , reason = str ( reason ) ))
1221+ self .event_list .append (_FunctionGraphEvent ("prune" , node , reason = reason ))
12301222 assert node in self .active_nodes
12311223 assert node not in self .inactive_nodes
12321224 self .active_nodes .remove (node )
12331225 self .inactive_nodes .add (node )
12341226
12351227 def on_import (self , fgraph , node , reason ):
1236- self .event_list .append (_FunctionGraphEvent ("import" , node , reason = str ( reason ) ))
1228+ self .event_list .append (_FunctionGraphEvent ("import" , node , reason = reason ))
12371229
12381230 assert node not in self .active_nodes
12391231 self .active_nodes .add (node )
@@ -1253,31 +1245,36 @@ def on_import(self, fgraph, node, reason):
12531245 self .reasons .setdefault (r , [])
12541246 self .replaced_by .setdefault (r , [])
12551247
1256- def on_change_input (self , fgraph , node , i , r , new_r , reason = None ):
1248+ def on_change_input (
1249+ self , fgraph , old_node , new_node , i , old_var , new_var , reason = None
1250+ ):
12571251 reason = str (reason )
12581252 self .event_list .append (
1259- _FunctionGraphEvent ("change" , node , reason = reason , idx = i )
1253+ _FunctionGraphEvent ("change" , old_node , new_node , idx = i , reason = reason )
12601254 )
12611255
1262- self .reasons .setdefault (new_r , [])
1263- self .replaced_by .setdefault (new_r , [])
1256+ self .on_import (fgraph , new_node , reason = reason )
1257+ self .on_prune (fgraph , old_node , reason = reason )
1258+
1259+ self .reasons .setdefault (new_var , [])
1260+ self .replaced_by .setdefault (new_var , [])
12641261
12651262 append_reason = True
1266- for tup in self .reasons [new_r ]:
1267- if tup [0 ] == reason and tup [1 ] is r :
1263+ for tup in self .reasons [new_var ]:
1264+ if tup [0 ] == reason and tup [1 ] is old_var :
12681265 append_reason = False
12691266
12701267 if append_reason :
12711268 # N.B. compute the debugprint now, because future
12721269 # optimizations will change the graph
12731270 done = dict ()
12741271 used_ids = dict ()
1275- self .reasons [new_r ].append (
1272+ self .reasons [new_var ].append (
12761273 (
12771274 reason ,
1278- r ,
1275+ old_var ,
12791276 _debugprint (
1280- r ,
1277+ old_var ,
12811278 prefix = " " ,
12821279 depth = 6 ,
12831280 file = StringIO (),
@@ -1286,7 +1283,7 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
12861283 used_ids = used_ids ,
12871284 ).getvalue (),
12881285 _debugprint (
1289- new_r ,
1286+ new_var ,
12901287 prefix = " " ,
12911288 depth = 6 ,
12921289 file = StringIO (),
@@ -1296,22 +1293,22 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
12961293 ).getvalue (),
12971294 )
12981295 )
1299- self .replaced_by [r ].append ((reason , new_r ))
1296+ self .replaced_by [old_var ].append ((reason , new_var ))
13001297
1301- if r in self .equiv :
1302- r_set = self .equiv [r ]
1298+ if old_var in self .equiv :
1299+ r_set = self .equiv [old_var ]
13031300 else :
1304- r_set = self .equiv .setdefault (r , {r })
1305- self .all_variables_ever .append (r )
1301+ r_set = self .equiv .setdefault (old_var , {old_var })
1302+ self .all_variables_ever .append (old_var )
13061303
1307- if new_r in self .equiv :
1308- new_r_set = self .equiv [new_r ]
1304+ if new_var in self .equiv :
1305+ new_r_set = self .equiv [new_var ]
13091306 else :
1310- new_r_set = self .equiv .setdefault (new_r , {new_r })
1311- self .all_variables_ever .append (new_r )
1307+ new_r_set = self .equiv .setdefault (new_var , {new_var })
1308+ self .all_variables_ever .append (new_var )
13121309
1313- assert new_r in new_r_set
1314- assert r in r_set
1310+ assert new_var in new_r_set
1311+ assert old_var in r_set
13151312
13161313 # update one equivalence set to contain the other
13171314 # transfer all the elements of the old one to the new one
@@ -1320,8 +1317,8 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
13201317 self .equiv [like_new_r ] = r_set
13211318 assert like_new_r in r_set
13221319
1323- assert self .equiv [r ] is r_set
1324- assert self .equiv [new_r ] is r_set
1320+ assert self .equiv [old_var ] is r_set
1321+ assert self .equiv [new_var ] is r_set
13251322
13261323 def printstuff (self ):
13271324 for key in self .equiv :
0 commit comments