24
24
vars_between ,
25
25
)
26
26
from pytensor .graph .utils import MetaObject , MissingInputError , TestValueError
27
- from pytensor .misc .ordered_set import OrderedSet
28
27
29
28
30
29
ClientType = tuple [Apply , int ]
@@ -133,7 +132,6 @@ def __init__(
133
132
features = []
134
133
135
134
self ._features : list [Feature ] = []
136
-
137
135
# All apply nodes in the subgraph defined by inputs and
138
136
# outputs are cached in this field
139
137
self .apply_nodes : set [Apply ] = set ()
@@ -161,7 +159,8 @@ def __init__(
161
159
"input's owner or use graph.clone."
162
160
)
163
161
164
- self .add_input (in_var , check = False )
162
+ self .inputs .append (in_var )
163
+ self .clients .setdefault (in_var , [])
165
164
166
165
for output in outputs :
167
166
self .add_output (output , reason = "init" )
@@ -189,16 +188,6 @@ def add_input(self, var: Variable, check: bool = True) -> None:
189
188
return
190
189
191
190
self .inputs .append (var )
192
- self .setup_var (var )
193
-
194
- def setup_var (self , var : Variable ) -> None :
195
- """Set up a variable so it belongs to this `FunctionGraph`.
196
-
197
- Parameters
198
- ----------
199
- var : pytensor.graph.basic.Variable
200
-
201
- """
202
191
self .clients .setdefault (var , [])
203
192
204
193
def get_clients (self , var : Variable ) -> list [ClientType ]:
@@ -322,10 +311,11 @@ def import_var(
322
311
323
312
"""
324
313
# Imports the owners of the variables
325
- if var .owner and var .owner not in self .apply_nodes :
326
- self .import_node (var .owner , reason = reason , import_missing = import_missing )
314
+ apply = var .owner
315
+ if apply is not None and apply not in self .apply_nodes :
316
+ self .import_node (apply , reason = reason , import_missing = import_missing )
327
317
elif (
328
- var . owner is None
318
+ apply is None
329
319
and not isinstance (var , AtomicVariable )
330
320
and var not in self .inputs
331
321
):
@@ -336,10 +326,11 @@ def import_var(
336
326
f"Computation graph contains a NaN. { var .type .why_null } "
337
327
)
338
328
if import_missing :
339
- self .add_input (var )
329
+ self .inputs .append (var )
330
+ self .clients .setdefault (var , [])
340
331
else :
341
332
raise MissingInputError (f"Undeclared input: { var } " , variable = var )
342
- self .setup_var (var )
333
+ self .clients . setdefault (var , [] )
343
334
self .variables .add (var )
344
335
345
336
def import_node (
@@ -356,29 +347,29 @@ def import_node(
356
347
apply_node : Apply
357
348
The node to be imported.
358
349
check : bool
359
- Check that the inputs for the imported nodes are also present in
360
- the `FunctionGraph`.
350
+ Check that the inputs for the imported nodes are also present in the `FunctionGraph`.
361
351
reason : str
362
352
The name of the optimization or operation in progress.
363
353
import_missing : bool
364
354
Add missing inputs instead of raising an exception.
365
355
"""
366
356
# We import the nodes in topological order. We only are interested in
367
- # new nodes, so we use all variables we know of as if they were the
368
- # input set. (The functions in the graph module only use the input set
369
- # to know where to stop going down.)
370
- new_nodes = tuple ( toposort ( apply_node . outputs , blockers = self .variables ))
371
-
372
- if check :
373
- for node in new_nodes :
357
+ # new nodes, so we use all nodes we know of as inputs to interrupt the toposort
358
+ self_variables = self . variables
359
+ self_clients = self . clients
360
+ self_apply_nodes = self .apply_nodes
361
+ self_inputs = self . inputs
362
+ for node in toposort ( apply_node . outputs , blockers = self_variables ) :
363
+ if check :
374
364
for var in node .inputs :
375
365
if (
376
366
var .owner is None
377
367
and not isinstance (var , AtomicVariable )
378
- and var not in self . inputs
368
+ and var not in self_inputs
379
369
):
380
370
if import_missing :
381
- self .add_input (var )
371
+ self_inputs .append (var )
372
+ self_clients .setdefault (var , [])
382
373
else :
383
374
error_msg = (
384
375
f"Input { node .inputs .index (var )} ({ var } )"
@@ -390,20 +381,20 @@ def import_node(
390
381
)
391
382
raise MissingInputError (error_msg , variable = var )
392
383
393
- for node in new_nodes :
394
- assert node not in self . apply_nodes
395
- self . apply_nodes . add ( node )
396
- if not hasattr ( node . tag , " imported_by" ):
397
- node . tag . imported_by = []
398
- node . tag .imported_by .append (str (reason ))
384
+ self_apply_nodes . add ( node )
385
+ tag = node . tag
386
+ if not hasattr ( tag , "imported_by" ):
387
+ tag . imported_by = [ str ( reason )]
388
+ else :
389
+ tag .imported_by .append (str (reason ))
399
390
for output in node .outputs :
400
- self . setup_var (output )
401
- self . variables .add (output )
402
- for i , input in enumerate (node .inputs ):
403
- if input not in self . variables :
404
- self . setup_var ( input )
405
- self . variables . add (input )
406
- self . add_client ( input , (node , i ))
391
+ self_clients . setdefault (output , [] )
392
+ self_variables .add (output )
393
+ for i , inp in enumerate (node .inputs ):
394
+ if inp not in self_variables :
395
+ self_clients . setdefault ( inp , [] )
396
+ self_variables . add (inp )
397
+ self_clients [ inp ]. append ( (node , i ))
407
398
self .execute_callbacks ("on_import" , node , reason )
408
399
409
400
def change_node_input (
@@ -457,7 +448,7 @@ def change_node_input(
457
448
self .outputs [node .op .idx ] = new_var
458
449
459
450
self .import_var (new_var , reason = reason , import_missing = import_missing )
460
- self .add_client ( new_var , (node , i ))
451
+ self .clients [ new_var ]. append ( (node , i ))
461
452
self .remove_client (r , (node , i ), reason = reason )
462
453
# Precondition: the substitution is semantically valid However it may
463
454
# introduce cycles to the graph, in which case the transaction will be
@@ -756,10 +747,6 @@ def toposort(self) -> list[Apply]:
756
747
:meth:`FunctionGraph.orderings`.
757
748
758
749
"""
759
- if len (self .apply_nodes ) < 2 :
760
- # No sorting is necessary
761
- return list (self .apply_nodes )
762
-
763
750
return list (toposort_with_orderings (self .outputs , orderings = self .orderings ()))
764
751
765
752
def orderings (self ) -> dict [Apply , list [Apply ]]:
@@ -779,29 +766,17 @@ def orderings(self) -> dict[Apply, list[Apply]]:
779
766
take care of computing the dependencies by itself.
780
767
781
768
"""
782
- assert isinstance (self ._features , list )
783
- all_orderings : list [dict ] = []
769
+ all_orderings : list [dict ] = [
770
+ orderings
771
+ for feature in self ._features
772
+ if (
773
+ hasattr (feature , "orderings" ) and (orderings := feature .orderings (self ))
774
+ )
775
+ ]
784
776
785
- for feature in self ._features :
786
- if hasattr (feature , "orderings" ):
787
- orderings = feature .orderings (self )
788
- if not isinstance (orderings , dict ):
789
- raise TypeError (
790
- "Non-deterministic return value from "
791
- + str (feature .orderings )
792
- + ". Nondeterministic object is "
793
- + str (orderings )
794
- )
795
- if len (orderings ) > 0 :
796
- all_orderings .append (orderings )
797
- for node , prereqs in orderings .items ():
798
- if not isinstance (prereqs , list | OrderedSet ):
799
- raise TypeError (
800
- "prereqs must be a type with a "
801
- "deterministic iteration order, or toposort "
802
- " will be non-deterministic."
803
- )
804
- if len (all_orderings ) == 1 :
777
+ if not all_orderings :
778
+ return {}
779
+ elif len (all_orderings ) == 1 :
805
780
# If there is only 1 ordering, we reuse it directly.
806
781
return all_orderings [0 ].copy ()
807
782
else :
0 commit comments