Skip to content

Commit b132036

Browse files
brandonwillardricardoV94
authored andcommitted
Make SharedVariable.default_update a persistent property
`Type` checks are also performed when values are assigned to the property.
1 parent 8968a38 commit b132036

File tree

5 files changed

+26
-49
lines changed

5 files changed

+26
-49
lines changed

pytensor/compile/function/pfunc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def clone_v_get_shared_updates(v, copy_inputs_over):
103103
elif isinstance(v, SharedVariable):
104104
if v not in shared_inputs:
105105
shared_inputs.append(v)
106-
if hasattr(v, "default_update"):
106+
if v.default_update is not None:
107107
# Check that v should not be excluded from the default
108108
# updates list
109109
if no_default_updates is False or (

pytensor/compile/sharedvalue.py

Lines changed: 17 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
from contextlib import contextmanager
88
from typing import List, Optional
99

10-
import numpy as np
11-
1210
from pytensor.graph.basic import Variable
1311
from pytensor.graph.utils import add_tag_trace
1412
from pytensor.link.basic import Container
@@ -103,6 +101,8 @@ def __init__(self, name, type, value, strict, allow_downcast=None, container=Non
103101
if isinstance(__SHARED_CONTEXT__, list):
104102
__SHARED_CONTEXT__.append(self)
105103

104+
self._default_update: Optional[Variable] = None
105+
106106
def get_value(self, borrow=False, return_internal_type=False):
107107
"""
108108
Get the non-symbolic value associated with this SharedVariable.
@@ -179,47 +179,23 @@ def clone(self, **kwargs):
179179
cp.tag = copy.copy(self.tag)
180180
return cp
181181

182-
def __getitem__(self, *args):
183-
# __getitem__ is not available for generic SharedVariable objects.
184-
# We raise a TypeError like Python would do if __getitem__ was not
185-
# implemented at all, but with a more explicit error message to help
186-
# PyTensor users figure out the root of the problem more easily.
187-
value = self.get_value(borrow=True)
188-
if isinstance(value, np.ndarray):
189-
# Array probably had an unknown dtype.
190-
msg = (
191-
f"a Numpy array with dtype: '{value.dtype}'. This data type is not "
192-
"currently recognized by PyTensor tensors: please cast "
193-
"your data into a supported numeric type if you need "
194-
"PyTensor tensor functionalities."
195-
)
196-
else:
197-
msg = (
198-
f"an object of type: {type(value)}. Did you forget to cast it into "
199-
"a Numpy array before calling pytensor.shared()?"
200-
)
201-
202-
raise TypeError(
203-
"The generic 'SharedVariable' object is not subscriptable. "
204-
f"This shared variable contains {msg}"
205-
)
206-
207-
def _value_get(self):
208-
raise Exception(
209-
"sharedvar.value does not exist anymore. Use "
210-
"sharedvar.get_value() or sharedvar.set_value()"
211-
" instead."
212-
)
182+
@property
183+
def default_update(self) -> Optional[Variable]:
184+
"""A default update expression for this `Variable`.
213185
214-
def _value_set(self, new_value):
215-
raise Exception(
216-
"sharedvar.value does not exist anymore. Use "
217-
"sharedvar.get_value() or sharedvar.set_value()"
218-
" instead."
219-
)
186+
If this value is non-``None``, its value will be used as the `update`
187+
(see `pytensor.function`) for this `Variable` when no updates are
188+
provided through `pytensor.function` and `no_default_updates` isn't
189+
enabled.
190+
"""
191+
return self._default_update
220192

221-
# We keep this just to raise an error
222-
value = property(_value_get, _value_set)
193+
@default_update.setter
194+
def default_update(self, value):
195+
if value is not None:
196+
self._default_update = self.type.filter_variable(value, allow_convert=True)
197+
else:
198+
self._default_update = value
223199

224200

225201
def shared_constructor(ctor, remove=False):

pytensor/scan/basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -996,8 +996,8 @@ def wrap_into_list(x):
996996
# We also don't want to remove a default update that applies to
997997
# the scope/context containing this `Scan`, so we only remove
998998
# default updates on "local" variables.
999-
if is_local and hasattr(input.variable, "default_update"):
1000-
del input.variable.default_update
999+
if is_local and input.variable.default_update is not None:
1000+
input.variable.default_update = None
10011001

10021002
new_var = safe_new(input.variable)
10031003

tests/compile/function/test_pfunc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,8 @@ def test_default_updates(self):
432432
f()
433433
assert x.get_value() == 1
434434

435-
del x.default_update
435+
x.default_update = None
436+
436437
f()
437438
assert x.get_value() == 2
438439

tests/scan/test_basic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -282,10 +282,10 @@ def inner_fn(x):
282282
n_steps=4,
283283
)
284284

285-
assert not hasattr(inner_rng, "default_update")
286-
assert hasattr(inner_inner_rng, "default_update")
287-
assert hasattr(y, "default_update")
288-
assert hasattr(z_rng, "default_update")
285+
assert inner_rng is None
286+
assert inner_inner_rng.default_update is not None
287+
assert y.default_update is not None
288+
assert z_rng.default_update is not None
289289

290290
out_fn = function([], out, mode=Mode(optimizer=None))
291291
res, z_res = out_fn()

0 commit comments

Comments
 (0)