Skip to content

Commit 461832c

Browse files
brandonwillardricardoV94
authored andcommitted
Clean up docstrings and errors relating to SharedVariable
1 parent b132036 commit 461832c

File tree

6 files changed

+53
-90
lines changed

6 files changed

+53
-90
lines changed

doc/library/compile/shared.rst

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
.. class:: SharedVariable
1515

16-
Variable with Storage that is shared between functions that it appears in.
16+
Variable with storage that is shared between the compiled functions that it appears in.
1717
These variables are meant to be created by registered *shared constructors*
1818
(see :func:`shared_constructor`).
1919

@@ -68,18 +68,17 @@
6868

6969
A container to use for this SharedVariable when it is an implicit function parameter.
7070

71-
:type: class:`Container`
7271

7372
.. autofunction:: shared
7473

7574
.. function:: shared_constructor(ctor)
7675

7776
Append `ctor` to the list of shared constructors (see :func:`shared`).
7877

79-
Each registered constructor ``ctor`` will be called like this:
78+
Each registered constructor `ctor` will be called like this:
8079

8180
.. code-block:: python
8281
8382
ctor(value, name=name, strict=strict, **kwargs)
8483
85-
If it do not support given value, it must raise a TypeError.
84+
If it do not support given value, it must raise a `TypeError`.

pytensor/compile/function/pfunc.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,12 @@ def rebuild_collect_shared(
7878
shared_inputs = []
7979

8080
def clone_v_get_shared_updates(v, copy_inputs_over):
81-
"""
82-
Clones a variable and its inputs recursively until all are in clone_d.
83-
Also appends all shared variables met along the way to shared inputs,
84-
and their default_update (if applicable) to update_d and update_expr.
81+
r"""Clones a variable and its inputs recursively until all are in `clone_d`.
82+
83+
Also, it appends all `SharedVariable`\s met along the way to
84+
`shared_inputs` and their corresponding
85+
`SharedVariable.default_update`\s (when applicable) to `update_d` and
86+
`update_expr`.
8587
8688
"""
8789
# this co-recurses with clone_a
@@ -419,22 +421,24 @@ def construct_pfunc_ins_and_outs(
419421
givens = []
420422

421423
if not isinstance(params, (list, tuple)):
422-
raise Exception("in pfunc() the first argument must be a list or " "a tuple")
424+
raise TypeError("The `params` argument must be a list or a tuple")
423425

424426
if not isinstance(no_default_updates, bool) and not isinstance(
425427
no_default_updates, list
426428
):
427-
raise TypeError("no_default_update should be either a boolean or " "a list")
429+
raise TypeError("The `no_default_update` argument must be a boolean or list")
428430

429-
if len(updates) > 0 and any(
430-
isinstance(v, Variable) for v in iter_over_pairs(updates)
431+
if len(updates) > 0 and not all(
432+
isinstance(pair, (tuple, list))
433+
and len(pair) == 2
434+
and isinstance(pair[0], Variable)
435+
for pair in iter_over_pairs(updates)
431436
):
432-
raise ValueError(
433-
"The updates parameter must be an OrderedDict/dict or a list of "
434-
"lists/tuples with 2 elements"
437+
raise TypeError(
438+
"The `updates` parameter must be an ordered mapping or a list of pairs"
435439
)
436440

437-
# transform params into pytensor.compile.In objects.
441+
# Transform params into pytensor.compile.In objects.
438442
inputs = [
439443
_pfunc_param_to_in(p, allow_downcast=allow_input_downcast) for p in params
440444
]

pytensor/compile/sharedvalue.py

Lines changed: 11 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
"""
2-
Provide a simple user friendly API to PyTensor-managed memory.
3-
4-
"""
1+
"""Provide a simple user friendly API to PyTensor-managed memory."""
52

63
import copy
74
from contextlib import contextmanager
@@ -30,52 +27,14 @@ def collect_new_shareds():
3027

3128

3229
class SharedVariable(Variable):
33-
"""
34-
Variable that is (defaults to being) shared between functions that
35-
it appears in.
36-
37-
Parameters
38-
----------
39-
name : str
40-
The name for this variable (see `Variable`).
41-
type : str
42-
The type for this variable (see `Variable`).
43-
value
44-
A value to associate with this variable (a new container will be
45-
created).
46-
strict
47-
True : assignments to .value will not be cast or copied, so they must
48-
have the correct type.
49-
allow_downcast
50-
Only applies if `strict` is False.
51-
True : allow assigned value to lose precision when cast during
52-
assignment.
53-
False : never allow precision loss.
54-
None : only allow downcasting of a Python float to a scalar floatX.
55-
container
56-
The container to use for this variable. Illegal to pass this as well as
57-
a value.
58-
59-
Notes
60-
-----
61-
For more user-friendly constructor, see `shared`.
30+
"""Variable that is shared between compiled functions."""
6231

63-
"""
64-
65-
# Container object
66-
container = None
32+
container: Optional[Container] = None
6733
"""
6834
A container to use for this SharedVariable when it is an implicit
6935
function parameter.
70-
71-
:type: `Container`
7236
"""
7337

74-
# default_update
75-
# If this member is present, its value will be used as the "update" for
76-
# this Variable, unless another update value has been passed to "function",
77-
# or the "no_default_updates" list passed to "function" contains it.
78-
7938
def __init__(self, name, type, value, strict, allow_downcast=None, container=None):
8039
super().__init__(type=type, name=name, owner=None, index=None)
8140

@@ -207,37 +166,30 @@ def shared_constructor(ctor, remove=False):
207166

208167

209168
def shared(value, name=None, strict=False, allow_downcast=None, **kwargs):
210-
"""Return a SharedVariable Variable, initialized with a copy or
211-
reference of `value`.
169+
r"""Create a `SharedVariable` initialized with a copy or reference of `value`.
212170
213171
This function iterates over constructor functions to find a
214-
suitable SharedVariable subclass. The suitable one is the first
172+
suitable `SharedVariable` subclass. The suitable one is the first
215173
constructor that accept the given value. See the documentation of
216174
:func:`shared_constructor` for the definition of a constructor
217175
function.
218176
219177
This function is meant as a convenient default. If you want to use a
220-
specific shared variable constructor, consider calling it directly.
221-
222-
``pytensor.shared`` is a shortcut to this function.
223-
224-
.. attribute:: constructors
178+
specific constructor, consider calling it directly.
225179
226-
A list of shared variable constructors that will be tried in reverse
227-
order.
180+
`pytensor.shared` is a shortcut to this function.
228181
229182
Notes
230183
-----
231184
By passing kwargs, you effectively limit the set of potential constructors
232185
to those that can accept those kwargs.
233186
234-
Some shared variable have ``borrow`` as extra kwargs.
187+
Some shared variable have `borrow` as a kwarg.
235188
236-
Some shared variable have ``broadcastable`` as extra kwargs. As shared
189+
`SharedVariable`\s of `TensorType` have `broadcastable` as a kwarg. As shared
237190
variable shapes can change, all dimensions default to not being
238-
broadcastable, even if ``value`` has a shape of 1 along some dimension.
239-
This parameter allows you to create for example a `row` or `column` 2d
240-
tensor.
191+
broadcastable, even if `value` has a shape of 1 along some dimension.
192+
This parameter allows one to create for example a row or column tensor.
241193
242194
"""
243195

pytensor/tensor/random/var.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __str__(self):
2222
def randomgen_constructor(
2323
value, name=None, strict=False, allow_downcast=None, borrow=False
2424
):
25-
r"""`SharedVariable` Constructor for NumPy's `Generator` and/or `RandomState`."""
25+
r"""`SharedVariable` constructor for NumPy's `Generator` and/or `RandomState`."""
2626
if isinstance(value, np.random.RandomState):
2727
rng_sv_type = RandomStateSharedVariable
2828
rng_type = random_state_type

pytensor/tensor/sharedvar.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,7 @@ def tensor_constructor(
4141
target="cpu",
4242
broadcastable=None,
4343
):
44-
"""
45-
SharedVariable Constructor for TensorType.
44+
r"""`SharedVariable` constructor for `TensorType`\s.
4645
4746
Notes
4847
-----
@@ -64,9 +63,8 @@ def tensor_constructor(
6463
if not isinstance(value, np.ndarray):
6564
raise TypeError()
6665

67-
# if no shape is given, then the default is to assume that
68-
# the value might be resized in any dimension in the future.
69-
#
66+
# If no shape is given, then the default is to assume that the value might
67+
# be resized in any dimension in the future.
7068
if shape is None:
7169
shape = (None,) * len(value.shape)
7270
type = TensorType(value.dtype, shape=shape)
@@ -79,13 +77,6 @@ def tensor_constructor(
7977
)
8078

8179

82-
# TensorSharedVariable brings in the tensor operators, is not ideal, but works
83-
# as long as we don't do purely scalar-scalar operations
84-
# _tensor_py_operators is first to have its version of __{gt,ge,lt,le}__
85-
#
86-
# N.B. THERE IS ANOTHER CLASS CALLED ScalarSharedVariable in the
87-
# pytensor.scalar.sharedvar file. It is not registered as a shared_constructor,
88-
# this one is.
8980
class ScalarSharedVariable(_tensor_py_operators, SharedVariable):
9081
pass
9182

@@ -94,8 +85,9 @@ class ScalarSharedVariable(_tensor_py_operators, SharedVariable):
9485
def scalar_constructor(
9586
value, name=None, strict=False, allow_downcast=None, borrow=False, target="cpu"
9687
):
97-
"""
98-
SharedVariable constructor for scalar values. Default: int64 or float64.
88+
"""`SharedVariable` constructor for scalar values.
89+
90+
Default: int64 or float64.
9991
10092
Notes
10193
-----

tests/compile/function/test_pfunc.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,22 @@ def data_of(s):
3636

3737

3838
class TestPfunc:
39+
def test_errors(self):
40+
a = lscalar()
41+
b = shared(1)
42+
43+
with pytest.raises(TypeError):
44+
pfunc({a}, a + b)
45+
46+
with pytest.raises(TypeError):
47+
pfunc([a], a + b, no_default_updates=1)
48+
49+
with pytest.raises(TypeError):
50+
pfunc([a], a + b, updates=[{b, a}])
51+
52+
with pytest.raises(TypeError):
53+
pfunc([a], a + b, updates=[(1, b)])
54+
3955
def test_doc(self):
4056
# Ensure the code given in pfunc.txt works as expected
4157

0 commit comments

Comments
 (0)