Skip to content

Commit 7462fdf

Browse files
committed
Cache unique value of TensorConstants and deprecate get_unique_constant_value
1 parent a377c22 commit 7462fdf

File tree

6 files changed

+54
-57
lines changed

6 files changed

+54
-57
lines changed

pytensor/scan/rewriting.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
get_slice_elements,
7272
set_subtensor,
7373
)
74-
from pytensor.tensor.variable import TensorConstant, get_unique_constant_value
74+
from pytensor.tensor.variable import TensorConstant
7575

7676

7777
list_opt_slice = [
@@ -136,10 +136,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
136136
all_ins = list(graph_inputs(op_outs))
137137
for idx in range(op_info.n_seqs):
138138
node_inp = node.inputs[idx + 1]
139-
if (
140-
isinstance(node_inp, TensorConstant)
141-
and get_unique_constant_value(node_inp) is not None
142-
):
139+
if isinstance(node_inp, TensorConstant) and node_inp.unique_value is not None:
143140
try:
144141
# This works if input is a constant that has all entries
145142
# equal

pytensor/tensor/basic.py

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import pytensor
2121
import pytensor.scalar.sharedvar
22-
from pytensor import compile, config, printing
22+
from pytensor import config, printing
2323
from pytensor import scalar as ps
2424
from pytensor.compile.builders import OpFromGraph
2525
from pytensor.gradient import DisconnectedType, grad_undefined
@@ -74,7 +74,6 @@
7474
from pytensor.tensor.variable import (
7575
TensorConstant,
7676
TensorVariable,
77-
get_unique_constant_value,
7877
)
7978

8079

@@ -319,6 +318,8 @@ def get_underlying_scalar_constant_value(
319318
but I'm not sure where it is.
320319
321320
"""
321+
from pytensor.compile.ops import DeepCopyOp, OutputGuard
322+
322323
v = orig_v
323324
while True:
324325
if v is None:
@@ -336,34 +337,19 @@ def get_underlying_scalar_constant_value(
336337
raise NotScalarConstantError()
337338

338339
if isinstance(v, Constant):
339-
unique_value = get_unique_constant_value(v)
340-
if unique_value is not None:
341-
data = unique_value
342-
else:
343-
data = v.data
344-
345-
if isinstance(data, np.ndarray):
346-
try:
347-
return np.array(data.item(), dtype=v.dtype)
348-
except ValueError:
349-
raise NotScalarConstantError()
340+
if isinstance(v, TensorConstant) and v.unique_value is not None:
341+
return v.unique_value
350342

351-
from pytensor.sparse.type import SparseTensorType
343+
elif isinstance(v, ScalarConstant):
344+
return v.data
352345

353-
if isinstance(v.type, SparseTensorType):
354-
raise NotScalarConstantError()
355-
356-
return data
346+
raise NotScalarConstantError()
357347

358348
if not only_process_constants and getattr(v, "owner", None) and max_recur > 0:
359349
max_recur -= 1
360350
if isinstance(
361351
v.owner.op,
362-
Alloc
363-
| DimShuffle
364-
| Unbroadcast
365-
| compile.ops.OutputGuard
366-
| compile.DeepCopyOp,
352+
Alloc | DimShuffle | Unbroadcast | OutputGuard | DeepCopyOp,
367353
):
368354
# OutputGuard is only used in debugmode but we
369355
# keep it here to avoid problems with old pickles

pytensor/tensor/rewriting/elemwise.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
register_specialize,
4242
)
4343
from pytensor.tensor.shape import shape_padleft
44-
from pytensor.tensor.variable import TensorConstant, get_unique_constant_value
44+
from pytensor.tensor.variable import TensorConstant
4545

4646

4747
class InplaceElemwiseOptimizer(GraphRewriter):
@@ -513,7 +513,6 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
513513
new_inputs.append(i)
514514
else:
515515
try:
516-
# works only for scalars
517516
cval_i = get_underlying_scalar_constant_value(
518517
i, only_process_constants=True
519518
)
@@ -1216,11 +1215,13 @@ def local_inline_composite_constants(fgraph, node):
12161215
inner_replacements = {}
12171216
for outer_inp, inner_inp in zip(node.inputs, composite_op.fgraph.inputs):
12181217
# Complex variables don't have a `c_literal` that can be inlined
1219-
if "complex" not in outer_inp.type.dtype:
1220-
unique_value = get_unique_constant_value(outer_inp)
1221-
if unique_value is not None:
1218+
if (
1219+
isinstance(outer_inp, TensorConstant)
1220+
and "complex" not in outer_inp.type.dtype
1221+
):
1222+
if outer_inp.unique_value is not None:
12221223
inner_replacements[inner_inp] = ps.constant(
1223-
unique_value, dtype=inner_inp.dtype
1224+
outer_inp.unique_value, dtype=inner_inp.dtype
12241225
)
12251226
continue
12261227
new_outer_inputs.append(outer_inp)

pytensor/tensor/rewriting/math.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,6 @@
105105
from pytensor.tensor.variable import (
106106
TensorConstant,
107107
TensorVariable,
108-
get_unique_constant_value,
109108
)
110109

111110

@@ -137,16 +136,8 @@ def get_constant(v):
137136
numeric constant. If v is a plain Variable, returns None.
138137
139138
"""
140-
if isinstance(v, Constant):
141-
unique_value = get_unique_constant_value(v)
142-
if unique_value is not None:
143-
data = unique_value
144-
else:
145-
data = v.data
146-
if data.ndim == 0:
147-
return data
148-
else:
149-
return None
139+
if isinstance(v, TensorConstant):
140+
return v.unique_value
150141
elif isinstance(v, Variable):
151142
return None
152143
else:

pytensor/tensor/variable.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111
from pytensor.configdefaults import config
1212
from pytensor.graph.basic import Constant, OptionalApplyType, Variable
1313
from pytensor.graph.utils import MetaType
14-
from pytensor.scalar import ComplexError, IntegerDivisionError
14+
from pytensor.scalar import (
15+
ComplexError,
16+
IntegerDivisionError,
17+
)
1518
from pytensor.tensor import _get_vector_length
1619
from pytensor.tensor.exceptions import AdvancedIndexingError
1720
from pytensor.tensor.type import TensorType
@@ -1042,15 +1045,9 @@ def no_nan(self):
10421045

10431046
def get_unique_constant_value(x: TensorVariable) -> Number | None:
10441047
"""Return the unique value of a tensor, if there is one"""
1045-
if isinstance(x, Constant):
1046-
data = x.data
1047-
1048-
if isinstance(data, np.ndarray) and data.ndim > 0:
1049-
flat_data = data.ravel()
1050-
if flat_data.shape[0]:
1051-
if (flat_data == flat_data[0]).all():
1052-
return flat_data[0]
1053-
1048+
warnings.warn("get_unique_constant_value is deprecated.", FutureWarning)
1049+
if isinstance(x, TensorConstant):
1050+
return x.unique_value
10541051
return None
10551052

10561053

@@ -1077,6 +1074,30 @@ def __init__(self, type: _TensorTypeType, data, name=None):
10771074
def signature(self):
10781075
return TensorConstantSignature((self.type, self.data))
10791076

1077+
@property
1078+
def unique_value(self) -> Number | None:
1079+
"""Return the unique value of a tensor, if there is one"""
1080+
try:
1081+
return self._unique_value
1082+
except AttributeError:
1083+
data = self.data
1084+
if np.ndim(data) == 0:
1085+
unique_value = data
1086+
else:
1087+
flat_data = data.ravel()
1088+
if (flat_data == flat_data[0]).all():
1089+
unique_value = flat_data[0]
1090+
else:
1091+
unique_value = None
1092+
1093+
if unique_value is not None:
1094+
# Don't allow the unique value to be changed
1095+
unique_value.setflags(write=False)
1096+
1097+
self._unique_value = unique_value
1098+
1099+
return self._unique_value
1100+
10801101
def equals(self, other):
10811102
# Override Constant.equals to allow to compare with
10821103
# numpy.ndarray, and python type.

tests/tensor/test_basic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3564,11 +3564,12 @@ def test_second(self):
35643564
assert get_underlying_scalar_constant_value(s) == c.data
35653565

35663566
def test_copy(self):
3567-
# Make sure we do not return the internal storage of a constant,
3567+
# Make sure we do not return a writeable internal storage of a constant,
35683568
# so we cannot change the value of a constant by mistake.
35693569
c = constant(3)
35703570
d = extract_constant(c)
3571-
d += 1
3571+
with pytest.raises(ValueError, match="output array is read-only"):
3572+
d += 1
35723573
e = extract_constant(c)
35733574
assert e == 3, (c, d, e)
35743575

0 commit comments

Comments
 (0)