File tree Expand file tree Collapse file tree 2 files changed +20
-7
lines changed Expand file tree Collapse file tree 2 files changed +20
-7
lines changed Original file line number Diff line number Diff line change 9
9
from pytensor .tensor .variable import _tensor_py_operators
10
10
11
11
12
+ def __getattr__ (name ):
13
+ if name == "ScalarSharedVariable" :
14
+ warnings .warn (
15
+ "The class `ScalarSharedVariable` has been deprecated. "
16
+ "Use `TensorSharedVariable` instead and check for `ndim==0`." ,
17
+ FutureWarning ,
18
+ )
19
+ return TensorSharedVariable
20
+
21
+ raise AttributeError (f"module { __name__ !r} has no attribute { name !r} " )
22
+
23
+
12
24
def load_shared_variable (val ):
13
25
"""
14
26
This function is only here to keep some pickles loading
@@ -94,10 +106,6 @@ def tensor_constructor(
94
106
)
95
107
96
108
97
- class ScalarSharedVariable (TensorSharedVariable ):
98
- pass
99
-
100
-
101
109
@shared_constructor .register (np .number )
102
110
@shared_constructor .register (float )
103
111
@shared_constructor .register (int )
@@ -132,7 +140,7 @@ def scalar_constructor(
132
140
133
141
# Do not pass the dtype to asarray because we want this to fail if
134
142
# strict is True and the types do not match.
135
- rval = ScalarSharedVariable (
143
+ rval = TensorSharedVariable (
136
144
type = tensor_type ,
137
145
value = np .array (value , copy = True ),
138
146
name = name ,
Original file line number Diff line number Diff line change 10
10
from pytensor .tensor import get_vector_length
11
11
from pytensor .tensor .basic import MakeVector
12
12
from pytensor .tensor .shape import Shape_i , specify_shape
13
- from pytensor .tensor .sharedvar import ScalarSharedVariable , TensorSharedVariable
13
+ from pytensor .tensor .sharedvar import TensorSharedVariable
14
14
from tests import unittest_tools as utt
15
15
16
16
@@ -679,12 +679,17 @@ def test_tensor_shared_zero():
679
679
680
680
def test_scalar_shared_options ():
681
681
res = pytensor .shared (value = np .float32 (0.0 ), name = "lk" , borrow = True )
682
- assert isinstance (res , ScalarSharedVariable )
682
+ assert isinstance (res , TensorSharedVariable ) and res . type . ndim == 0
683
683
assert res .type .dtype == "float32"
684
684
assert res .name == "lk"
685
685
assert res .type .shape == ()
686
686
687
687
688
+ def test_scalar_shared_deprecated ():
689
+ with pytest .warns (FutureWarning , match = ".*deprecated.*" ):
690
+ pytensor .tensor .sharedvar .ScalarSharedVariable
691
+
692
+
688
693
def test_get_vector_length ():
689
694
x = pytensor .shared (np .array ((2 , 3 , 4 , 5 )))
690
695
assert get_vector_length (x ) == 4
You can’t perform that action at this time.
0 commit comments