Commit 41564b8

committed
Don't do symbolic upcasting in local_upcast_elemwise_constants
This reduces the number of rewrite passes, by avoiding constant fold of cast/expand_dims/alloc
1 parent b9fc4f8 commit 41564b8
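
A minimal sketch (not part of the commit; the variable names are illustrative only) of the kind of graph this rewrite targets: an Elemwise whose output dtype is wider than the dtype of one of its constant inputs. Before this change the rewrite replaced such a constant with a symbolic cast/expand_dims/alloc subgraph that a later constant-folding pass had to collapse; after it, the constant's data is upcast eagerly and wrapped in a new TensorConstant.

import numpy as np
import pytensor.tensor as pt

x = pt.vector("x", dtype="float64")
c = pt.constant(np.int8(2))   # int8 constant input
y = x * c                     # Elemwise mul; output dtype is float64

# What the rewrite now substitutes for `c` is equivalent to upcasting the
# constant's data eagerly, rather than building a symbolic cast graph:
c_upcast = pt.constant(np.asarray(c.data).astype(y.dtype))
assert c_upcast.dtype == "float64"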

File tree

1 file changed: +31 -62 lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 31 additions & 62 deletions
@@ -30,21 +30,16 @@
 from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
 from pytensor.tensor.basic import (
     MakeVector,
-    alloc,
-    cast,
     constant,
-    get_underlying_scalar_constant_value,
 )
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
-from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import add, exp, mul
 from pytensor.tensor.rewriting.basic import (
     alloc_like,
     broadcasted_by,
     register_canonicalize,
     register_specialize,
 )
-from pytensor.tensor.shape import shape_padleft
 from pytensor.tensor.variable import TensorConstant, TensorVariable
 
 
@@ -434,66 +429,40 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
 
     """
     if len(node.outputs) > 1:
-        return
-    try:
-        shape_i = fgraph.shape_feature.shape_i
-    except AttributeError:
-        shape_i = None
-    if isinstance(node.op, Elemwise):
-        scalar_op = node.op.scalar_op
-        # print "aa", scalar_op.output_types_preference
-        if getattr(scalar_op, "output_types_preference", None) in (
-            ps.upgrade_to_float,
-            ps.upcast_out,
-        ):
-            # this is the kind of op that we can screw with the input
-            # dtypes by upcasting explicitly
-            output_dtype = node.outputs[0].type.dtype
-            new_inputs = []
-            for i in node.inputs:
-                if i.type.dtype == output_dtype:
-                    new_inputs.append(i)
-                else:
-                    try:
-                        cval_i = get_underlying_scalar_constant_value(
-                            i, only_process_constants=True
-                        )
-                        if all(i.broadcastable):
-                            new_inputs.append(
-                                shape_padleft(cast(cval_i, output_dtype), i.ndim)
-                            )
-                        else:
-                            if shape_i is None:
-                                return
-                            new_inputs.append(
-                                alloc(
-                                    cast(cval_i, output_dtype),
-                                    *[shape_i(d)(i) for d in range(i.ndim)],
-                                )
-                            )
-                            # print >> sys.stderr, "AAA",
-                            # *[Shape_i(d)(i) for d in range(i.ndim)]
-                    except NotScalarConstantError:
-                        # for the case of a non-scalar
-                        if isinstance(i, TensorConstant):
-                            new_inputs.append(cast(i, output_dtype))
-                        else:
-                            new_inputs.append(i)
+        return None
+
+    if getattr(node.op.scalar_op, "output_types_preference", None) not in (
+        ps.upgrade_to_float,
+        ps.upcast_out,
+    ):
+        return None
 
-            if new_inputs != node.inputs:
-                rval = [node.op(*new_inputs)]
-                if not node.outputs[0].type.is_super(rval[0].type):
-                    # This can happen for example when floatX=float32
-                    # and we do the true division between and int64
-                    # and a constant that will get typed as int8.
+    # this is the kind of op that we can screw with the input
+    # dtypes by upcasting explicitly
+    [old_out] = node.outputs
+    output_dtype = old_out.type.dtype
+    new_inputs = list(node.inputs)
+    changed = False
+    for i, inp in enumerate(node.inputs):
+        if inp.type.dtype != output_dtype and isinstance(inp, TensorConstant):
+            new_inputs[i] = constant(inp.data.astype(output_dtype))
+            changed = True
+
+    if not changed:
+        return None
 
-                    # As this is just to allow merging more case, if
-                    # the upcast don't work, we can just skip it.
-                    return
+    rval = node.op(*new_inputs)
+    if not old_out.type.is_super(rval.type):
+        # This can happen for example when floatX=float32
+        # and we do the true division between and int64
+        # and a constant that will get typed as int8.
+        # As this is just to allow merging more case, if
+        # the upcast don't work, we can just skip it.
+        return None
 
-            # Copy over output stacktrace from before upcasting
-            copy_stack_trace(node.outputs[0], rval)
-            return rval
+    # Copy over output stacktrace from before upcasting
+    copy_stack_trace(old_out, rval)
+    return [rval]
 
 
 @node_rewriter([add, mul])
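
As a standalone illustration (plain numpy only; not code from the commit), the new approach boils down to upcasting the constant's data once at rewrite time instead of emitting a symbolic cast/expand_dims/alloc graph that later passes would have to constant-fold:

import numpy as np

data = np.array(2, dtype=np.int8)    # like the .data of an int8 TensorConstant input
output_dtype = "float64"             # dtype of the Elemwise output

upcast = data.astype(output_dtype)   # eager upcast, done once at rewrite time
assert upcast.dtype == np.dtype(output_dtype)
# Wrapping `upcast` in `constant(...)` gives a replacement input that already
# matches the output dtype, so no extra rewrite pass is needed to fold a
# cast/expand_dims/alloc subgraph down to a constant.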
