Skip to content

Commit a120dc2

Browse files
committed
Cache unique value of TensorConstants and deprecate get_unique_constant_value
1 parent 2b57f74 commit a120dc2

File tree

8 files changed

+88
-84
lines changed

8 files changed

+88
-84
lines changed

pytensor/scan/rewriting.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
get_slice_elements,
7272
set_subtensor,
7373
)
74-
from pytensor.tensor.variable import TensorConstant, get_unique_constant_value
74+
from pytensor.tensor.variable import TensorConstant
7575

7676

7777
list_opt_slice = [
@@ -136,10 +136,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
136136
all_ins = list(graph_inputs(op_outs))
137137
for idx in range(op_info.n_seqs):
138138
node_inp = node.inputs[idx + 1]
139-
if (
140-
isinstance(node_inp, TensorConstant)
141-
and get_unique_constant_value(node_inp) is not None
142-
):
139+
if isinstance(node_inp, TensorConstant) and node_inp.unique_value is not None:
143140
try:
144141
# This works if input is a constant that has all entries
145142
# equal

pytensor/sparse/basic.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,10 @@ def __str__(self):
491491
def __repr__(self):
492492
return str(self)
493493

494+
@property
495+
def unique_value(self):
496+
return None
497+
494498

495499
SparseTensorType.variable_type = SparseVariable
496500
SparseTensorType.constant_type = SparseConstant

pytensor/tensor/basic.py

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import pytensor
2121
import pytensor.scalar.sharedvar
22-
from pytensor import compile, config, printing
22+
from pytensor import config, printing
2323
from pytensor import scalar as ps
2424
from pytensor.compile.builders import OpFromGraph
2525
from pytensor.gradient import DisconnectedType, grad_undefined
@@ -35,7 +35,7 @@
3535
from pytensor.printing import Printer, min_informative_str, pprint, set_precedence
3636
from pytensor.raise_op import CheckAndRaise, assert_op
3737
from pytensor.scalar import int32
38-
from pytensor.scalar.basic import ScalarConstant, ScalarVariable
38+
from pytensor.scalar.basic import ScalarConstant, ScalarType, ScalarVariable
3939
from pytensor.tensor import (
4040
_as_tensor_variable,
4141
_get_vector_length,
@@ -71,10 +71,10 @@
7171
uint_dtypes,
7272
values_eq_approx_always_true,
7373
)
74+
from pytensor.tensor.type_other import NoneTypeT
7475
from pytensor.tensor.variable import (
7576
TensorConstant,
7677
TensorVariable,
77-
get_unique_constant_value,
7878
)
7979

8080

@@ -319,6 +319,8 @@ def get_underlying_scalar_constant_value(
319319
but I'm not sure where it is.
320320
321321
"""
322+
from pytensor.compile.ops import DeepCopyOp, OutputGuard
323+
322324
v = orig_v
323325
while True:
324326
if v is None:
@@ -336,34 +338,22 @@ def get_underlying_scalar_constant_value(
336338
raise NotScalarConstantError()
337339

338340
if isinstance(v, Constant):
339-
unique_value = get_unique_constant_value(v)
340-
if unique_value is not None:
341-
data = unique_value
342-
else:
343-
data = v.data
344-
345-
if isinstance(data, np.ndarray):
346-
try:
347-
return np.array(data.item(), dtype=v.dtype)
348-
except ValueError:
349-
raise NotScalarConstantError()
341+
if isinstance(v.type, TensorType) and v.unique_value is not None:
342+
return v.unique_value
350343

351-
from pytensor.sparse.type import SparseTensorType
344+
elif isinstance(v.type, ScalarType):
345+
return v.data
352346

353-
if isinstance(v.type, SparseTensorType):
354-
raise NotScalarConstantError()
347+
elif isinstance(v.type, NoneTypeT):
348+
return None
355349

356-
return data
350+
raise NotScalarConstantError()
357351

358352
if not only_process_constants and getattr(v, "owner", None) and max_recur > 0:
359353
max_recur -= 1
360354
if isinstance(
361355
v.owner.op,
362-
Alloc
363-
| DimShuffle
364-
| Unbroadcast
365-
| compile.ops.OutputGuard
366-
| compile.DeepCopyOp,
356+
Alloc | DimShuffle | Unbroadcast | OutputGuard | DeepCopyOp,
367357
):
368358
# OutputGuard is only used in debugmode but we
369359
# keep it here to avoid problems with old pickles

pytensor/tensor/rewriting/elemwise.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
register_specialize,
4242
)
4343
from pytensor.tensor.shape import shape_padleft
44-
from pytensor.tensor.variable import TensorConstant, get_unique_constant_value
44+
from pytensor.tensor.variable import TensorConstant
4545

4646

4747
class InplaceElemwiseOptimizer(GraphRewriter):
@@ -513,7 +513,6 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
513513
new_inputs.append(i)
514514
else:
515515
try:
516-
# works only for scalars
517516
cval_i = get_underlying_scalar_constant_value(
518517
i, only_process_constants=True
519518
)
@@ -1218,11 +1217,13 @@ def local_inline_composite_constants(fgraph, node):
12181217
node.inputs, composite_op.fgraph.inputs, strict=True
12191218
):
12201219
# Complex variables don't have a `c_literal` that can be inlined
1221-
if "complex" not in outer_inp.type.dtype:
1222-
unique_value = get_unique_constant_value(outer_inp)
1223-
if unique_value is not None:
1220+
if (
1221+
isinstance(outer_inp, TensorConstant)
1222+
and "complex" not in outer_inp.type.dtype
1223+
):
1224+
if outer_inp.unique_value is not None:
12241225
inner_replacements[inner_inp] = ps.constant(
1225-
unique_value, dtype=inner_inp.dtype
1226+
outer_inp.unique_value, dtype=inner_inp.dtype
12261227
)
12271228
continue
12281229
new_outer_inputs.append(outer_inp)

pytensor/tensor/rewriting/math.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,6 @@
106106
from pytensor.tensor.variable import (
107107
TensorConstant,
108108
TensorVariable,
109-
get_unique_constant_value,
110109
)
111110

112111

@@ -138,16 +137,8 @@ def get_constant(v):
138137
numeric constant. If v is a plain Variable, returns None.
139138
140139
"""
141-
if isinstance(v, Constant):
142-
unique_value = get_unique_constant_value(v)
143-
if unique_value is not None:
144-
data = unique_value
145-
else:
146-
data = v.data
147-
if data.ndim == 0:
148-
return data
149-
else:
150-
return None
140+
if isinstance(v, TensorConstant):
141+
return v.unique_value
151142
elif isinstance(v, Variable):
152143
return None
153144
else:
@@ -628,7 +619,14 @@ def local_mul_switch_sink(fgraph, node):
628619
# Look for a zero as the first or second branch of the switch
629620
for branch in range(2):
630621
zero_switch_input = switch_node.inputs[1 + branch]
631-
if not get_unique_constant_value(zero_switch_input) == 0.0:
622+
if (
623+
not get_underlying_scalar_constant_value(
624+
zero_switch_input,
625+
only_process_constants=True,
626+
raise_not_constant=False,
627+
)
628+
== 0.0
629+
):
632630
continue
633631

634632
switch_cond = switch_node.inputs[0]
@@ -685,7 +683,14 @@ def local_div_switch_sink(fgraph, node):
685683
# Look for a zero as the first or second branch of the switch
686684
for branch in range(2):
687685
zero_switch_input = switch_node.inputs[1 + branch]
688-
if not get_unique_constant_value(zero_switch_input) == 0.0:
686+
if (
687+
not get_underlying_scalar_constant_value(
688+
zero_switch_input,
689+
only_process_constants=True,
690+
raise_not_constant=False,
691+
)
692+
== 0.0
693+
):
689694
continue
690695

691696
switch_cond = switch_node.inputs[0]

pytensor/tensor/shape.py

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from pytensor.tensor.elemwise import get_normalized_batch_axes
2121
from pytensor.tensor.exceptions import NotScalarConstantError
2222
from pytensor.tensor.type import DenseTensorType, TensorType, int_dtypes, tensor
23-
from pytensor.tensor.type_other import NoneConst
23+
from pytensor.tensor.type_other import NoneConst, NoneTypeT
2424
from pytensor.tensor.variable import TensorConstant, TensorVariable
2525

2626

@@ -401,8 +401,6 @@ class SpecifyShape(COp):
401401
_output_type_depends_on_input_value = True
402402

403403
def make_node(self, x, *shape):
404-
from pytensor.tensor.basic import get_underlying_scalar_constant_value
405-
406404
x = ptb.as_tensor_variable(x)
407405

408406
shape = tuple(
@@ -428,11 +426,9 @@ def make_node(self, x, *shape):
428426
for i, (xts, s) in enumerate(zip(x.type.shape, shape, strict=True)):
429427
if xts is not None:
430428
type_shape[i] = xts
431-
else:
429+
elif not isinstance(s.type, NoneTypeT):
432430
try:
433-
type_s = get_underlying_scalar_constant_value(s)
434-
if type_s is not None:
435-
type_shape[i] = int(type_s)
431+
type_shape[i] = int(ptb.get_underlying_scalar_constant_value(s))
436432
except NotScalarConstantError:
437433
pass
438434

@@ -460,22 +456,13 @@ def perform(self, node, inp, out_):
460456
def infer_shape(self, fgraph, node, shapes):
461457
xshape, *_ = shapes
462458
shape = node.inputs[1:]
463-
new_shape = []
464-
for dim in range(node.inputs[0].type.ndim):
465-
s = shape[dim]
466-
try:
467-
s = ptb.get_underlying_scalar_constant_value(s)
468-
# We assume that `None` shapes are always retrieved by
469-
# `get_underlying_scalar_constant_value`, and only in that case do we default to
470-
# the shape of the input variable
471-
if s is None:
472-
s = xshape[dim]
473-
except NotScalarConstantError:
474-
pass
475-
new_shape.append(ptb.as_tensor_variable(s))
476-
477-
assert len(new_shape) == len(xshape)
478-
return [new_shape]
459+
# Use x shape if specified dim is None, otherwise the specified shape
460+
return [
461+
[
462+
xshape[i] if isinstance(dim.type, NoneTypeT) else dim
463+
for i, dim in enumerate(shape)
464+
]
465+
]
479466

480467
def connection_pattern(self, node):
481468
return [[True], *[[False]] * len(node.inputs[1:])]

pytensor/tensor/variable.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111
from pytensor.configdefaults import config
1212
from pytensor.graph.basic import Constant, OptionalApplyType, Variable
1313
from pytensor.graph.utils import MetaType
14-
from pytensor.scalar import ComplexError, IntegerDivisionError
14+
from pytensor.scalar import (
15+
ComplexError,
16+
IntegerDivisionError,
17+
)
1518
from pytensor.tensor import _get_vector_length
1619
from pytensor.tensor.exceptions import AdvancedIndexingError
1720
from pytensor.tensor.type import TensorType
@@ -1042,17 +1045,9 @@ def no_nan(self):
10421045

10431046
def get_unique_constant_value(x: TensorVariable) -> Number | None:
10441047
"""Return the unique value of a tensor, if there is one"""
1045-
if isinstance(x, Constant):
1046-
data = x.data
1047-
1048-
if isinstance(data, np.ndarray) and data.size > 0:
1049-
if data.size == 1:
1050-
return data.squeeze()
1051-
1052-
flat_data = data.ravel()
1053-
if (flat_data == flat_data[0]).all():
1054-
return flat_data[0]
1055-
1048+
warnings.warn("get_unique_constant_value is deprecated.", FutureWarning)
1049+
if isinstance(x, TensorConstant):
1050+
return x.unique_value
10561051
return None
10571052

10581053

@@ -1081,6 +1076,30 @@ def __init__(self, type: _TensorTypeType, data, name=None):
10811076
def signature(self):
10821077
return TensorConstantSignature((self.type, self.data))
10831078

1079+
@property
1080+
def unique_value(self) -> Number | None:
1081+
"""Return the unique value of a tensor, if there is one"""
1082+
try:
1083+
return self._unique_value
1084+
except AttributeError:
1085+
data = self.data
1086+
unique_value = None
1087+
if data.size > 0:
1088+
if data.size == 1:
1089+
unique_value = data.squeeze()
1090+
else:
1091+
flat_data = data.ravel()
1092+
if (flat_data == flat_data[0]).all():
1093+
unique_value = flat_data[0]
1094+
1095+
if unique_value is not None:
1096+
# Don't allow the unique value to be changed
1097+
unique_value.setflags(write=False)
1098+
1099+
self._unique_value = unique_value
1100+
1101+
return self._unique_value
1102+
10841103
def equals(self, other):
10851104
# Override Constant.equals to allow to compare with
10861105
# numpy.ndarray, and python type.

tests/tensor/test_basic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3571,11 +3571,12 @@ def test_second(self):
35713571
assert get_underlying_scalar_constant_value(s) == c.data
35723572

35733573
def test_copy(self):
3574-
# Make sure we do not return the internal storage of a constant,
3574+
# Make sure we do not return a writeable internal storage of a constant,
35753575
# so we cannot change the value of a constant by mistake.
35763576
c = constant(3)
35773577
d = extract_constant(c)
3578-
d += 1
3578+
with pytest.raises(ValueError, match="output array is read-only"):
3579+
d += 1
35793580
e = extract_constant(c)
35803581
assert e == 3, (c, d, e)
35813582

0 commit comments

Comments
 (0)