
Commit 08eaf0e

brandonwillard authored and ricardoV94 committed
Use singledispatch to register SharedVariable type constructors
1 parent 461832c commit 08eaf0e


4 files changed: +41 -84 lines changed


pytensor/compile/sharedvalue.py

Lines changed: 15 additions & 49 deletions
@@ -2,6 +2,7 @@

 import copy
 from contextlib import contextmanager
+from functools import singledispatch
 from typing import List, Optional

 from pytensor.graph.basic import Variable
@@ -157,14 +158,6 @@ def default_update(self, value):
         self._default_update = value


-def shared_constructor(ctor, remove=False):
-    if remove:
-        shared.constructors.remove(ctor)
-    else:
-        shared.constructors.append(ctor)
-    return ctor
-
-
 def shared(value, name=None, strict=False, allow_downcast=None, **kwargs):
     r"""Create a `SharedVariable` initialized with a copy or reference of `value`.

@@ -193,53 +186,26 @@ def shared(value, name=None, strict=False, allow_downcast=None, **kwargs):

     """

-    try:
-        if isinstance(value, Variable):
-            raise TypeError(
-                "Shared variable constructor needs numeric "
-                "values and not symbolic variables."
-            )
-
-        for ctor in reversed(shared.constructors):
-            try:
-                var = ctor(
-                    value,
-                    name=name,
-                    strict=strict,
-                    allow_downcast=allow_downcast,
-                    **kwargs,
-                )
-                add_tag_trace(var)
-                return var
-            except TypeError:
-                continue
-            # This may happen when kwargs were supplied
-            # if kwargs were given, the generic_constructor won't be callable.
-            #
-            # This was done on purpose, the rationale being that if kwargs
-            # were supplied, the user didn't want them to be ignored.
+    if isinstance(value, Variable):
+        raise TypeError("Shared variable values can not be symbolic.")

+    try:
+        var = shared_constructor(
+            value,
+            name=name,
+            strict=strict,
+            allow_downcast=allow_downcast,
+            **kwargs,
+        )
+        add_tag_trace(var)
+        return var
     except MemoryError as e:
         e.args = e.args + ("Consider using `pytensor.shared(..., borrow=True)`",)
         raise

-    raise TypeError(
-        "No suitable SharedVariable constructor could be found."
-        " Are you sure all kwargs are supported?"
-        " We do not support the parameter dtype or type."
-        f' value="{value}". parameters="{kwargs}"'
-    )
-
-
-shared.constructors = []

-
-@shared_constructor
-def generic_constructor(value, name=None, strict=False, allow_downcast=None):
-    """
-    SharedVariable Constructor.
-
-    """
+@singledispatch
+def shared_constructor(value, name=None, strict=False, allow_downcast=None, **kwargs):
     return SharedVariable(
         type=generic,
         value=value,
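
The change above replaces the hand-rolled `shared.constructors` list (and the try/except probing loop over it) with a `functools.singledispatch` function: the `@singledispatch` fallback plays the role of the old `generic_constructor`, and type-specific constructors attach themselves with `shared_constructor.register`. A minimal standalone sketch of that pattern, using made-up names rather than PyTensor's actual classes:

from functools import singledispatch

import numpy as np


@singledispatch
def make_shared(value, name=None, **kwargs):
    # Fallback, analogous to the old generic_constructor: used when no more
    # specific implementation is registered for type(value).
    return ("generic", name, value)


@make_shared.register(np.ndarray)
def _make_shared_ndarray(value, name=None, **kwargs):
    # Chosen whenever the first argument is an ndarray (or a subclass).
    return ("ndarray", name, value.shape)


print(make_shared(1.0, name="x"))               # -> ("generic", "x", 1.0)
print(make_shared(np.zeros((2, 3)), name="y"))  # -> ("ndarray", "y", (2, 3))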

pytensor/sparse/sharedvar.py

Lines changed: 4 additions & 7 deletions
@@ -11,21 +11,18 @@ class SparseTensorSharedVariable(_sparse_py_operators, SharedVariable):
     format = property(lambda self: self.type.format)


-@shared_constructor
+@shared_constructor.register(scipy.sparse.spmatrix)
 def sparse_constructor(
     value, name=None, strict=False, allow_downcast=None, borrow=False, format=None
 ):
-    if not isinstance(value, scipy.sparse.spmatrix):
-        raise TypeError(
-            "Expected a sparse matrix in the sparse shared variable constructor. Received: ",
-            value.__class__,
-        )
-
     if format is None:
         format = value.format
+
     type = SparseTensorType(format=format, dtype=value.dtype)
+
     if not borrow:
         value = copy.deepcopy(value)
+
     return SparseTensorSharedVariable(
         type=type, value=value, name=name, strict=strict, allow_downcast=allow_downcast
     )
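
With the `scipy.sparse.spmatrix` registration above, a SciPy sparse matrix passed to `pytensor.shared` should reach `sparse_constructor` through dispatch on the value's type (subclasses such as `csr_matrix` are covered by singledispatch's MRO lookup), so the explicit isinstance check is no longer needed. A rough usage sketch of that expected behaviour:

import scipy.sparse

import pytensor
import pytensor.sparse  # ensure the sparse constructor is registered

x = pytensor.shared(scipy.sparse.csr_matrix((3, 3), dtype="float64"), name="x")
print(type(x).__name__)  # expected: SparseTensorSharedVariable
print(x.format)          # expected: "csr", taken from the value's format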

pytensor/tensor/random/var.py

Lines changed: 2 additions & 3 deletions
@@ -18,7 +18,8 @@ def __str__(self):
         )


-@shared_constructor
+@shared_constructor.register(np.random.RandomState)
+@shared_constructor.register(np.random.Generator)
 def randomgen_constructor(
     value, name=None, strict=False, allow_downcast=None, borrow=False
 ):
@@ -29,8 +30,6 @@ def randomgen_constructor(
     elif isinstance(value, np.random.Generator):
         rng_sv_type = RandomGeneratorSharedVariable
         rng_type = random_generator_type
-    else:
-        raise TypeError()

     if not borrow:
         value = copy.deepcopy(value)
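
Stacking two `.register` calls, as above, attaches the same implementation to both `np.random.RandomState` and `np.random.Generator`; the trailing `else: raise TypeError()` becomes unnecessary because dispatch already guarantees one of the two types. A standalone sketch of that stacking behaviour (names are illustrative, not PyTensor's):

from functools import singledispatch

import numpy as np


@singledispatch
def rng_kind(value):
    # Fallback for unregistered types, analogous to the generic constructor.
    return "generic"


@rng_kind.register(np.random.RandomState)
@rng_kind.register(np.random.Generator)
def _rng_kind_numpy(value):
    # Reached only for the two registered types, so branching on the exact
    # class (as randomgen_constructor does) needs no trailing else clause.
    return type(value).__name__


print(rng_kind(np.random.RandomState(42)))  # -> "RandomState"
print(rng_kind(np.random.default_rng(42)))  # -> "Generator"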

pytensor/tensor/sharedvar.py

Lines changed: 20 additions & 25 deletions
@@ -1,4 +1,3 @@
-import traceback
 import warnings

 import numpy as np
@@ -30,7 +29,7 @@ def _get_vector_length_TensorSharedVariable(var_inst, var):
     return len(var.get_value(borrow=True))


-@shared_constructor
+@shared_constructor.register(np.ndarray)
 def tensor_constructor(
     value,
     name=None,
@@ -60,14 +59,13 @@ def tensor_constructor(
     if target != "cpu":
         raise TypeError("not for cpu")

-    if not isinstance(value, np.ndarray):
-        raise TypeError()
-
     # If no shape is given, then the default is to assume that the value might
     # be resized in any dimension in the future.
     if shape is None:
-        shape = (None,) * len(value.shape)
+        shape = (None,) * value.ndim
+
     type = TensorType(value.dtype, shape=shape)
+
     return TensorSharedVariable(
         type=type,
         value=np.array(value, copy=(not borrow)),
@@ -81,7 +79,10 @@ class ScalarSharedVariable(_tensor_py_operators, SharedVariable):
     pass


-@shared_constructor
+@shared_constructor.register(np.number)
+@shared_constructor.register(float)
+@shared_constructor.register(int)
+@shared_constructor.register(complex)
 def scalar_constructor(
     value, name=None, strict=False, allow_downcast=None, borrow=False, target="cpu"
 ):
@@ -101,28 +102,22 @@ def scalar_constructor(
     if target != "cpu":
         raise TypeError("not for cpu")

-    if not isinstance(value, (np.number, float, int, complex)):
-        raise TypeError()
     try:
         dtype = value.dtype
-    except Exception:
+    except AttributeError:
         dtype = np.asarray(value).dtype

     dtype = str(dtype)
     value = _asarray(value, dtype=dtype)
-    tensor_type = TensorType(dtype=str(value.dtype), shape=[])
+    tensor_type = TensorType(dtype=str(value.dtype), shape=())

-    try:
-        # Do not pass the dtype to asarray because we want this to fail if
-        # strict is True and the types do not match.
-        rval = ScalarSharedVariable(
-            type=tensor_type,
-            value=np.array(value, copy=True),
-            name=name,
-            strict=strict,
-            allow_downcast=allow_downcast,
-        )
-        return rval
-    except Exception:
-        traceback.print_exc()
-        raise
+    # Do not pass the dtype to asarray because we want this to fail if
+    # strict is True and the types do not match.
+    rval = ScalarSharedVariable(
+        type=tensor_type,
+        value=np.array(value, copy=True),
+        name=name,
+        strict=strict,
+        allow_downcast=allow_downcast,
+    )
+    return rval
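
With `np.ndarray` routed to `tensor_constructor` and the Python/NumPy scalar types routed to `scalar_constructor`, the dispatch in `pytensor.shared` now picks a constructor from the value's type alone, so the in-function isinstance checks above can go. A rough usage sketch of that expected behaviour:

import numpy as np

import pytensor

a = pytensor.shared(np.zeros((2, 3)), name="a")  # dispatches on np.ndarray
b = pytensor.shared(2.5, name="b")               # dispatches on float

print(type(a).__name__, a.type)  # expected: TensorSharedVariable with a 2-d type
print(type(b).__name__, b.type)  # expected: ScalarSharedVariable with a 0-d type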
