11"""Core graph classes."""
2- import abc
32import warnings
43from collections import deque
54from copy import copy
2625 Union ,
2726 cast ,
2827)
29- from weakref import WeakKeyDictionary
28+ from weakref import WeakValueDictionary
3029
3130import numpy as np
3231
3332from aesara .configdefaults import config
3433from aesara .graph .utils import (
35- MetaObject ,
3634 MethodNotDefined ,
3735 Scratchpad ,
3836 TestValueError ,
5351_TypeType = TypeVar ("_TypeType" , bound = "Type" )
5452_IdType = TypeVar ("_IdType" , bound = Hashable )
5553
56- T = TypeVar ("T" , bound = "Node" )
54+ T = TypeVar ("T" , bound = Union [ "Apply" , "Variable" ] )
5755NoParams = object ()
5856NodeAndChildren = Tuple [T , Optional [Iterable [T ]]]
5957
6058
61- class Node (MetaObject ):
62- r"""A `Node` in an Aesara graph.
59+ class UniqueInstanceFactory (type ):
6360
64- Currently, graphs contain two kinds of `Nodes`: `Variable`\s and `Apply`\s.
65- Edges in the graph are not explicitly represented. Instead each `Node`
66- keeps track of its parents via `Variable.owner` / `Apply.inputs`.
61+ __instances__ : WeakValueDictionary = WeakValueDictionary ()
6762
68- """
69- name : Optional [str ]
63+ def __call__ (
64+ cls ,
65+ * args ,
66+ ** kwargs ,
67+ ):
68+ idp = cls .create_key (* args , ** kwargs )
7069
71- def get_parents (self ):
72- """
73- Return a list of the parents of this node.
74- Should return a copy--i.e., modifying the return
75- value should not modify the graph structure.
70+ if idp not in cls .__instances__ :
71+ res = super (UniqueInstanceFactory , cls ).__call__ (* args , ** kwargs )
72+ cls .__instances__ [idp ] = res
73+ return res
7674
77- """
78- raise NotImplementedError ()
75+ return cls .__instances__ [idp ]
7976
8077
81- class Apply (Node , Generic [OpType ]):
78+ class Apply (
79+ Generic [OpType ],
80+ metaclass = UniqueInstanceFactory ,
81+ ):
8282 """A `Node` representing the application of an operation to inputs.
8383
8484 Basically, an `Apply` instance is an object that represents the
@@ -113,12 +113,19 @@ class Apply(Node, Generic[OpType]):
113113
114114 """
115115
116+ __slots__ = ("op" , "inputs" , "outputs" , "__weakref__" , "tag" )
117+
118+ @classmethod
119+ def create_key (cls , op , inputs , outputs ):
120+ return (op ,) + tuple (inputs )
121+
116122 def __init__ (
117123 self ,
118124 op : OpType ,
119125 inputs : Sequence ["Variable" ],
120126 outputs : Sequence ["Variable" ],
121127 ):
128+
122129 if not isinstance (inputs , Sequence ):
123130 raise TypeError ("The inputs of an Apply must be a sequence type" )
124131
@@ -154,6 +161,21 @@ def __init__(
154161 f"The 'outputs' argument to Apply must contain Variable instances with no owner, not { output } "
155162 )
156163
164+ def __eq__ (self , other ):
165+ if isinstance (other , type (self )):
166+ if (
167+ self .op == other .op
168+ and self .inputs == other .inputs
169+ # and self.outputs == other.outputs
170+ ):
171+ return True
172+ return False
173+
174+ return NotImplemented
175+
176+ def __hash__ (self ):
177+ return hash ((type (self ), self .op , tuple (self .inputs ), tuple (self .outputs )))
178+
157179 def run_params (self ):
158180 """
159181 Returns the params for the node, or NoParams if no params is set.
@@ -165,15 +187,19 @@ def run_params(self):
165187 return NoParams
166188
167189 def __getstate__ (self ):
168- d = self .__dict__
169- # ufunc don't pickle/unpickle well
190+ d = {k : getattr (self , k ) for k in self .__slots__ if k not in ("__weakref__" ,)}
170191 if hasattr (self .tag , "ufunc" ):
171192 d = copy (self .__dict__ )
172193 t = d ["tag" ]
173194 del t .ufunc
174195 d ["tag" ] = t
175196 return d
176197
198+ def __setstate__ (self , dct ):
199+ for k in self .__slots__ :
200+ if k in dct :
201+ setattr (self , k , dct [k ])
202+
177203 def default_output (self ):
178204 """
179205 Returns the default output for this node.
@@ -267,6 +293,7 @@ def clone_with_new_inputs(
267293 from aesara .graph .op import HasInnerGraph
268294
269295 assert isinstance (inputs , (list , tuple ))
296+
270297 remake_node = False
271298 new_inputs : List ["Variable" ] = list (inputs )
272299 for i , (curr , new ) in enumerate (zip (self .inputs , new_inputs )):
@@ -280,17 +307,22 @@ def clone_with_new_inputs(
280307 else :
281308 remake_node = True
282309
283- if remake_node :
284- new_op = self .op
310+ new_op = self .op
285311
286- if isinstance (new_op , HasInnerGraph ) and clone_inner_graph : # type: ignore
287- new_op = new_op .clone () # type: ignore
312+ if isinstance (new_op , HasInnerGraph ) and clone_inner_graph : # type: ignore
313+ new_op = new_op .clone () # type: ignore
288314
315+ if remake_node :
289316 new_node = new_op .make_node (* new_inputs )
290317 new_node .tag = copy (self .tag ).__update__ (new_node .tag )
318+ elif new_op == self .op and new_inputs == self .inputs :
319+ new_node = self
291320 else :
292- new_node = self .clone (clone_inner_graph = clone_inner_graph )
293- new_node .inputs = new_inputs
321+ new_node = self .__class__ (
322+ new_op , new_inputs , [output .clone () for output in self .outputs ]
323+ )
324+ new_node .tag = copy (self .tag )
325+
294326 return new_node
295327
296328 def get_parents (self ):
@@ -316,7 +348,7 @@ def params_type(self):
316348 return self .op .params_type
317349
318350
319- class Variable (Node , Generic [_TypeType , OptionalApplyType ]):
351+ class Variable (Generic [_TypeType , OptionalApplyType ]):
320352 r"""
321353 A :term:`Variable` is a node in an expression graph that represents a
322354 variable.
@@ -411,7 +443,7 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]):
411443
412444 """
413445
414- # __slots__ = [' type', 'owner', 'index', 'name']
446+ __slots__ = ( "_owner" , "_index" , "name" , " type" , "__weakref__" , "tag" , "auto_name" )
415447 __count__ = count (0 )
416448
417449 _owner : OptionalApplyType
@@ -487,26 +519,17 @@ def __str__(self):
487519 else :
488520 return f"<{ self .type } >"
489521
490- def __repr_test_value__ (self ):
491- """Return a ``repr`` of the test value.
492-
493- Return a printable representation of the test value. It can be
494- overridden by classes with non printable test_value to provide a
495- suitable representation of the test_value.
496- """
497- return repr (self .get_test_value ())
498-
499522 def __repr__ (self , firstPass = True ):
500523 """Return a ``repr`` of the `Variable`.
501524
502- Return a printable name or description of the Variable. If
503- `` config.print_test_value`` is ``True`` it will also print the test
504- value, if any.
525+ Return a printable name or description of the ` Variable` . If
526+ `aesara. config.print_test_value` is ``True``, it will also print the
527+ test value, if any.
505528 """
506529 to_print = [str (self )]
507530 if config .print_test_value and firstPass :
508531 try :
509- to_print .append (self .__repr_test_value__ ( ))
532+ to_print .append (repr ( self .get_test_value () ))
510533 except TestValueError :
511534 pass
512535 return "\n " .join (to_print )
@@ -528,26 +551,6 @@ def clone(self):
528551 cp .tag = copy (self .tag )
529552 return cp
530553
531- def __lt__ (self , other ):
532- raise NotImplementedError (
533- "Subclasses of Variable must provide __lt__" , self .__class__ .__name__
534- )
535-
536- def __le__ (self , other ):
537- raise NotImplementedError (
538- "Subclasses of Variable must provide __le__" , self .__class__ .__name__
539- )
540-
541- def __gt__ (self , other ):
542- raise NotImplementedError (
543- "Subclasses of Variable must provide __gt__" , self .__class__ .__name__
544- )
545-
546- def __ge__ (self , other ):
547- raise NotImplementedError (
548- "Subclasses of Variable must provide __ge__" , self .__class__ .__name__
549- )
550-
551554 def get_parents (self ):
552555 if self .owner is not None :
553556 return [self .owner ]
@@ -605,7 +608,7 @@ def eval(self, inputs_to_values=None):
605608 return rval
606609
607610 def __getstate__ (self ):
608- d = self . __dict__ . copy ()
611+ d = { k : getattr ( self , k ) for k in self . __slots__ if k not in ( "__weakref__" ,)}
609612 d .pop ("_fn_cache" , None )
610613 if (not config .pickle_test_value ) and (hasattr (self .tag , "test_value" )):
611614 if not type (config ).pickle_test_value .is_default :
@@ -618,26 +621,24 @@ def __getstate__(self):
618621 d ["tag" ] = t
619622 return d
620623
624+ def __setstate__ (self , dct ):
625+ for k in self .__slots__ :
626+ if k in dct :
627+ setattr (self , k , dct [k ])
628+
621629
622630class AtomicVariable (Variable [_TypeType , None ]):
623631 """A node type that has no ancestors and should never be considered an input to a graph."""
624632
625633 def __init__ (self , type : _TypeType , ** kwargs ):
626634 super ().__init__ (type , None , None , ** kwargs )
627635
628- @abc .abstractmethod
629- def signature (self ):
630- ...
631-
632- def merge_signature (self ):
633- return self .signature ()
634-
635636 def equals (self , other ):
636637 """
637638 This does what `__eq__` would normally do, but `Variable` and `Apply`
638639 should always be hashable by `id`.
639640 """
640- return isinstance ( other , type ( self )) and self . signature () == other . signature ()
641+ return self == other
641642
642643 @property
643644 def owner (self ):
@@ -661,12 +662,15 @@ def index(self, value):
661662class NominalVariable (AtomicVariable [_TypeType ]):
662663 """A variable that enables alpha-equivalent comparisons."""
663664
664- __instances__ : WeakKeyDictionary [
665+ __instances__ : WeakValueDictionary [
665666 Tuple ["Type" , Hashable ], "NominalVariable"
666- ] = WeakKeyDictionary ()
667+ ] = WeakValueDictionary ()
667668
668669 def __new__ (cls , id : _IdType , typ : _TypeType , ** kwargs ):
669- if (typ , id ) not in cls .__instances__ :
670+
671+ idp = (typ , id )
672+
673+ if idp not in cls .__instances__ :
670674 var_type = typ .variable_type
671675 type_name = f"Nominal{ var_type .__name__ } "
672676
@@ -681,9 +685,9 @@ def _str(self):
681685 )
682686 res : NominalVariable = super ().__new__ (new_type )
683687
684- cls .__instances__ [( typ , id ) ] = res
688+ cls .__instances__ [idp ] = res
685689
686- return cls .__instances__ [( typ , id ) ]
690+ return cls .__instances__ [idp ]
687691
688692 def __init__ (self , id : _IdType , typ : _TypeType , ** kwargs ):
689693 self .id = id
@@ -708,11 +712,11 @@ def __hash__(self):
708712 def __repr__ (self ):
709713 return f"{ type (self ).__name__ } ({ repr (self .id )} , { repr (self .type )} )"
710714
711- def signature (self ) -> Tuple [_TypeType , _IdType ]:
712- return (self .type , self .id )
713715
714-
715- class Constant (AtomicVariable [_TypeType ]):
716+ class Constant (
717+ AtomicVariable [_TypeType ],
718+ metaclass = UniqueInstanceFactory ,
719+ ):
716720 """A `Variable` with a fixed `data` field.
717721
718722 `Constant` nodes make numerous optimizations possible (e.g. constant
@@ -725,19 +729,22 @@ class Constant(AtomicVariable[_TypeType]):
725729
726730 """
727731
728- # __slots__ = ['data']
732+ __slots__ = ("type" , "data" )
733+
734+ @classmethod
735+ def create_key (cls , type , data , * args , ** kwargs ):
736+ # TODO FIXME: This filters the data twice: once here, and again in
737+ # `cls.__init__`. This might not be a big deal, though.
738+ return (type , type .filter (data ))
729739
730740 def __init__ (self , type : _TypeType , data : Any , name : Optional [str ] = None ):
731- super () .__init__ (type , name = name )
741+ AtomicVariable .__init__ (self , type , name = name )
732742 self .data = type .filter (data )
733743 add_tag_trace (self )
734744
735745 def get_test_value (self ):
736746 return self .data
737747
738- def signature (self ):
739- return (self .type , self .data )
740-
741748 def __str__ (self ):
742749 if self .name is not None :
743750 return self .name
@@ -764,6 +771,15 @@ def owner(self, value) -> None:
764771 def value (self ):
765772 return self .data
766773
774+ def __hash__ (self ):
775+ return hash ((type (self ), self .type , self .data ))
776+
777+ def __eq__ (self , other ):
778+ if isinstance (other , type (self )):
779+ return self .type == other .type and self .data == other .data
780+
781+ return NotImplemented
782+
767783
768784def walk (
769785 nodes : Iterable [T ],
0 commit comments