Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit 679da16

Browse files
Hash cons Apply nodes and Constants
1 parent 0fb8055 commit 679da16

File tree

12 files changed

+339
-379
lines changed

12 files changed

+339
-379
lines changed

aesara/graph/basic.py

Lines changed: 101 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
"""Core graph classes."""
2-
import abc
32
import warnings
43
from collections import deque
54
from copy import copy
@@ -26,13 +25,12 @@
2625
Union,
2726
cast,
2827
)
29-
from weakref import WeakKeyDictionary
28+
from weakref import WeakValueDictionary
3029

3130
import numpy as np
3231

3332
from aesara.configdefaults import config
3433
from aesara.graph.utils import (
35-
MetaObject,
3634
MethodNotDefined,
3735
Scratchpad,
3836
TestValueError,
@@ -53,32 +51,34 @@
5351
_TypeType = TypeVar("_TypeType", bound="Type")
5452
_IdType = TypeVar("_IdType", bound=Hashable)
5553

56-
T = TypeVar("T", bound="Node")
54+
T = TypeVar("T", bound=Union["Apply", "Variable"])
5755
NoParams = object()
5856
NodeAndChildren = Tuple[T, Optional[Iterable[T]]]
5957

6058

61-
class Node(MetaObject):
62-
r"""A `Node` in an Aesara graph.
59+
class UniqueInstanceFactory(type):
6360

64-
Currently, graphs contain two kinds of `Nodes`: `Variable`\s and `Apply`\s.
65-
Edges in the graph are not explicitly represented. Instead each `Node`
66-
keeps track of its parents via `Variable.owner` / `Apply.inputs`.
61+
__instances__: WeakValueDictionary = WeakValueDictionary()
6762

68-
"""
69-
name: Optional[str]
63+
def __call__(
64+
cls,
65+
*args,
66+
**kwargs,
67+
):
68+
idp = cls.create_key(*args, **kwargs)
7069

71-
def get_parents(self):
72-
"""
73-
Return a list of the parents of this node.
74-
Should return a copy--i.e., modifying the return
75-
value should not modify the graph structure.
70+
if idp not in cls.__instances__:
71+
res = super(UniqueInstanceFactory, cls).__call__(*args, **kwargs)
72+
cls.__instances__[idp] = res
73+
return res
7674

77-
"""
78-
raise NotImplementedError()
75+
return cls.__instances__[idp]
7976

8077

81-
class Apply(Node, Generic[OpType]):
78+
class Apply(
79+
Generic[OpType],
80+
metaclass=UniqueInstanceFactory,
81+
):
8282
"""A `Node` representing the application of an operation to inputs.
8383
8484
Basically, an `Apply` instance is an object that represents the
@@ -113,12 +113,19 @@ class Apply(Node, Generic[OpType]):
113113
114114
"""
115115

116+
__slots__ = ("op", "inputs", "outputs", "__weakref__", "tag")
117+
118+
@classmethod
119+
def create_key(cls, op, inputs, outputs):
120+
return (op,) + tuple(inputs)
121+
116122
def __init__(
117123
self,
118124
op: OpType,
119125
inputs: Sequence["Variable"],
120126
outputs: Sequence["Variable"],
121127
):
128+
122129
if not isinstance(inputs, Sequence):
123130
raise TypeError("The inputs of an Apply must be a sequence type")
124131

@@ -154,6 +161,21 @@ def __init__(
154161
f"The 'outputs' argument to Apply must contain Variable instances with no owner, not {output}"
155162
)
156163

164+
def __eq__(self, other):
165+
if isinstance(other, type(self)):
166+
if (
167+
self.op == other.op
168+
and self.inputs == other.inputs
169+
# and self.outputs == other.outputs
170+
):
171+
return True
172+
return False
173+
174+
return NotImplemented
175+
176+
def __hash__(self):
177+
return hash((type(self), self.op, tuple(self.inputs), tuple(self.outputs)))
178+
157179
def run_params(self):
158180
"""
159181
Returns the params for the node, or NoParams if no params is set.
@@ -165,15 +187,19 @@ def run_params(self):
165187
return NoParams
166188

167189
def __getstate__(self):
168-
d = self.__dict__
169-
# ufunc don't pickle/unpickle well
190+
d = {k: getattr(self, k) for k in self.__slots__ if k not in ("__weakref__",)}
170191
if hasattr(self.tag, "ufunc"):
171192
d = copy(self.__dict__)
172193
t = d["tag"]
173194
del t.ufunc
174195
d["tag"] = t
175196
return d
176197

198+
def __setstate__(self, dct):
199+
for k in self.__slots__:
200+
if k in dct:
201+
setattr(self, k, dct[k])
202+
177203
def default_output(self):
178204
"""
179205
Returns the default output for this node.
@@ -267,6 +293,7 @@ def clone_with_new_inputs(
267293
from aesara.graph.op import HasInnerGraph
268294

269295
assert isinstance(inputs, (list, tuple))
296+
270297
remake_node = False
271298
new_inputs: List["Variable"] = list(inputs)
272299
for i, (curr, new) in enumerate(zip(self.inputs, new_inputs)):
@@ -280,17 +307,22 @@ def clone_with_new_inputs(
280307
else:
281308
remake_node = True
282309

283-
if remake_node:
284-
new_op = self.op
310+
new_op = self.op
285311

286-
if isinstance(new_op, HasInnerGraph) and clone_inner_graph: # type: ignore
287-
new_op = new_op.clone() # type: ignore
312+
if isinstance(new_op, HasInnerGraph) and clone_inner_graph: # type: ignore
313+
new_op = new_op.clone() # type: ignore
288314

315+
if remake_node:
289316
new_node = new_op.make_node(*new_inputs)
290317
new_node.tag = copy(self.tag).__update__(new_node.tag)
318+
elif new_op == self.op and new_inputs == self.inputs:
319+
new_node = self
291320
else:
292-
new_node = self.clone(clone_inner_graph=clone_inner_graph)
293-
new_node.inputs = new_inputs
321+
new_node = self.__class__(
322+
new_op, new_inputs, [output.clone() for output in self.outputs]
323+
)
324+
new_node.tag = copy(self.tag)
325+
294326
return new_node
295327

296328
def get_parents(self):
@@ -316,7 +348,7 @@ def params_type(self):
316348
return self.op.params_type
317349

318350

319-
class Variable(Node, Generic[_TypeType, OptionalApplyType]):
351+
class Variable(Generic[_TypeType, OptionalApplyType]):
320352
r"""
321353
A :term:`Variable` is a node in an expression graph that represents a
322354
variable.
@@ -411,7 +443,7 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]):
411443
412444
"""
413445

414-
# __slots__ = ['type', 'owner', 'index', 'name']
446+
__slots__ = ("_owner", "_index", "name", "type", "__weakref__", "tag", "auto_name")
415447
__count__ = count(0)
416448

417449
_owner: OptionalApplyType
@@ -487,26 +519,17 @@ def __str__(self):
487519
else:
488520
return f"<{self.type}>"
489521

490-
def __repr_test_value__(self):
491-
"""Return a ``repr`` of the test value.
492-
493-
Return a printable representation of the test value. It can be
494-
overridden by classes with non printable test_value to provide a
495-
suitable representation of the test_value.
496-
"""
497-
return repr(self.get_test_value())
498-
499522
def __repr__(self, firstPass=True):
500523
"""Return a ``repr`` of the `Variable`.
501524
502-
Return a printable name or description of the Variable. If
503-
``config.print_test_value`` is ``True`` it will also print the test
504-
value, if any.
525+
Return a printable name or description of the `Variable`. If
526+
`aesara.config.print_test_value` is ``True``, it will also print the
527+
test value, if any.
505528
"""
506529
to_print = [str(self)]
507530
if config.print_test_value and firstPass:
508531
try:
509-
to_print.append(self.__repr_test_value__())
532+
to_print.append(repr(self.get_test_value()))
510533
except TestValueError:
511534
pass
512535
return "\n".join(to_print)
@@ -528,26 +551,6 @@ def clone(self):
528551
cp.tag = copy(self.tag)
529552
return cp
530553

531-
def __lt__(self, other):
532-
raise NotImplementedError(
533-
"Subclasses of Variable must provide __lt__", self.__class__.__name__
534-
)
535-
536-
def __le__(self, other):
537-
raise NotImplementedError(
538-
"Subclasses of Variable must provide __le__", self.__class__.__name__
539-
)
540-
541-
def __gt__(self, other):
542-
raise NotImplementedError(
543-
"Subclasses of Variable must provide __gt__", self.__class__.__name__
544-
)
545-
546-
def __ge__(self, other):
547-
raise NotImplementedError(
548-
"Subclasses of Variable must provide __ge__", self.__class__.__name__
549-
)
550-
551554
def get_parents(self):
552555
if self.owner is not None:
553556
return [self.owner]
@@ -605,7 +608,7 @@ def eval(self, inputs_to_values=None):
605608
return rval
606609

607610
def __getstate__(self):
608-
d = self.__dict__.copy()
611+
d = {k: getattr(self, k) for k in self.__slots__ if k not in ("__weakref__",)}
609612
d.pop("_fn_cache", None)
610613
if (not config.pickle_test_value) and (hasattr(self.tag, "test_value")):
611614
if not type(config).pickle_test_value.is_default:
@@ -618,26 +621,24 @@ def __getstate__(self):
618621
d["tag"] = t
619622
return d
620623

624+
def __setstate__(self, dct):
625+
for k in self.__slots__:
626+
if k in dct:
627+
setattr(self, k, dct[k])
628+
621629

622630
class AtomicVariable(Variable[_TypeType, None]):
623631
"""A node type that has no ancestors and should never be considered an input to a graph."""
624632

625633
def __init__(self, type: _TypeType, **kwargs):
626634
super().__init__(type, None, None, **kwargs)
627635

628-
@abc.abstractmethod
629-
def signature(self):
630-
...
631-
632-
def merge_signature(self):
633-
return self.signature()
634-
635636
def equals(self, other):
636637
"""
637638
This does what `__eq__` would normally do, but `Variable` and `Apply`
638639
should always be hashable by `id`.
639640
"""
640-
return isinstance(other, type(self)) and self.signature() == other.signature()
641+
return self == other
641642

642643
@property
643644
def owner(self):
@@ -661,12 +662,15 @@ def index(self, value):
661662
class NominalVariable(AtomicVariable[_TypeType]):
662663
"""A variable that enables alpha-equivalent comparisons."""
663664

664-
__instances__: WeakKeyDictionary[
665+
__instances__: WeakValueDictionary[
665666
Tuple["Type", Hashable], "NominalVariable"
666-
] = WeakKeyDictionary()
667+
] = WeakValueDictionary()
667668

668669
def __new__(cls, id: _IdType, typ: _TypeType, **kwargs):
669-
if (typ, id) not in cls.__instances__:
670+
671+
idp = (typ, id)
672+
673+
if idp not in cls.__instances__:
670674
var_type = typ.variable_type
671675
type_name = f"Nominal{var_type.__name__}"
672676

@@ -681,9 +685,9 @@ def _str(self):
681685
)
682686
res: NominalVariable = super().__new__(new_type)
683687

684-
cls.__instances__[(typ, id)] = res
688+
cls.__instances__[idp] = res
685689

686-
return cls.__instances__[(typ, id)]
690+
return cls.__instances__[idp]
687691

688692
def __init__(self, id: _IdType, typ: _TypeType, **kwargs):
689693
self.id = id
@@ -708,11 +712,11 @@ def __hash__(self):
708712
def __repr__(self):
709713
return f"{type(self).__name__}({repr(self.id)}, {repr(self.type)})"
710714

711-
def signature(self) -> Tuple[_TypeType, _IdType]:
712-
return (self.type, self.id)
713715

714-
715-
class Constant(AtomicVariable[_TypeType]):
716+
class Constant(
717+
AtomicVariable[_TypeType],
718+
metaclass=UniqueInstanceFactory,
719+
):
716720
"""A `Variable` with a fixed `data` field.
717721
718722
`Constant` nodes make numerous optimizations possible (e.g. constant
@@ -725,19 +729,22 @@ class Constant(AtomicVariable[_TypeType]):
725729
726730
"""
727731

728-
# __slots__ = ['data']
732+
__slots__ = ("type", "data")
733+
734+
@classmethod
735+
def create_key(cls, type, data, *args, **kwargs):
736+
# TODO FIXME: This filters the data twice: once here, and again in
737+
# `cls.__init__`. This might not be a big deal, though.
738+
return (type, type.filter(data))
729739

730740
def __init__(self, type: _TypeType, data: Any, name: Optional[str] = None):
731-
super().__init__(type, name=name)
741+
AtomicVariable.__init__(self, type, name=name)
732742
self.data = type.filter(data)
733743
add_tag_trace(self)
734744

735745
def get_test_value(self):
736746
return self.data
737747

738-
def signature(self):
739-
return (self.type, self.data)
740-
741748
def __str__(self):
742749
if self.name is not None:
743750
return self.name
@@ -764,6 +771,15 @@ def owner(self, value) -> None:
764771
def value(self):
765772
return self.data
766773

774+
def __hash__(self):
775+
return hash((type(self), self.type, self.data))
776+
777+
def __eq__(self, other):
778+
if isinstance(other, type(self)):
779+
return self.type == other.type and self.data == other.data
780+
781+
return NotImplemented
782+
767783

768784
def walk(
769785
nodes: Iterable[T],

0 commit comments

Comments
 (0)