Skip to content

Commit 5cfd9da

Browse files
committed
Deprecate ScalarSharedVariable
1 parent 1091b44 commit 5cfd9da

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

pytensor/tensor/sharedvar.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,18 @@
99
from pytensor.tensor.variable import _tensor_py_operators
1010

1111

12+
def __getattr__(name):
13+
if name == "ScalarSharedVariable":
14+
warnings.warn(
15+
"The class `ScalarSharedVariable` has been deprecated. "
16+
"Use `TensorSharedVariable` instead and check for `ndim==0`.",
17+
FutureWarning,
18+
)
19+
return TensorSharedVariable
20+
21+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
22+
23+
1224
def load_shared_variable(val):
1325
"""
1426
This function is only here to keep some pickles loading
@@ -94,10 +106,6 @@ def tensor_constructor(
94106
)
95107

96108

97-
class ScalarSharedVariable(TensorSharedVariable):
98-
pass
99-
100-
101109
@shared_constructor.register(np.number)
102110
@shared_constructor.register(float)
103111
@shared_constructor.register(int)
@@ -132,7 +140,7 @@ def scalar_constructor(
132140

133141
# Do not pass the dtype to asarray because we want this to fail if
134142
# strict is True and the types do not match.
135-
rval = ScalarSharedVariable(
143+
rval = TensorSharedVariable(
136144
type=tensor_type,
137145
value=np.array(value, copy=True),
138146
name=name,

tests/tensor/test_sharedvar.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pytensor.tensor import get_vector_length
1111
from pytensor.tensor.basic import MakeVector
1212
from pytensor.tensor.shape import Shape_i, specify_shape
13-
from pytensor.tensor.sharedvar import ScalarSharedVariable, TensorSharedVariable
13+
from pytensor.tensor.sharedvar import TensorSharedVariable
1414
from tests import unittest_tools as utt
1515

1616

@@ -679,12 +679,17 @@ def test_tensor_shared_zero():
679679

680680
def test_scalar_shared_options():
681681
res = pytensor.shared(value=np.float32(0.0), name="lk", borrow=True)
682-
assert isinstance(res, ScalarSharedVariable)
682+
assert isinstance(res, TensorSharedVariable) and res.type.ndim == 0
683683
assert res.type.dtype == "float32"
684684
assert res.name == "lk"
685685
assert res.type.shape == ()
686686

687687

688+
def test_scalar_shared_deprecated():
689+
with pytest.warns(FutureWarning, match=".*deprecated.*"):
690+
pytensor.tensor.sharedvar.ScalarSharedVariable
691+
692+
688693
def test_get_vector_length():
689694
x = pytensor.shared(np.array((2, 3, 4, 5)))
690695
assert get_vector_length(x) == 4

0 commit comments

Comments
 (0)