
Commit 2266bd8

Don't do symbolic upcasting in local_upcast_elemwise_constants
This reduces the number of rewrite passes by avoiding constant folding of cast/expand_dims/alloc
1 parent 5e64f0a commit 2266bd8
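
The situation this targets can be sketched with a small PyTensor snippet (illustrative names, not taken from the commit; it assumes PyTensor and NumPy are installed):

import numpy as np
import pytensor.tensor as pt

x = pt.vector("x", dtype="float32")
c = pt.constant(np.int8(2))   # constant input typed as int8
out = x + c                   # add upcasts, so the output dtype is float32

# Before this commit, the rewrite upcast `c` symbolically by wrapping it in
# cast (plus expand_dims/alloc for non-scalar shapes), and a later
# constant-folding pass had to collapse those extra nodes again.
# After this commit, the rewrite swaps in an already-upcast constant,
# roughly pt.constant(np.asarray(2, dtype="float32")), so no follow-up
# folding is required.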

1 file changed (+31, −62 lines)


pytensor/tensor/rewriting/elemwise.py

Lines changed: 31 additions & 62 deletions
@@ -28,21 +28,16 @@
 from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
 from pytensor.tensor.basic import (
     MakeVector,
-    alloc,
-    cast,
     constant,
-    get_underlying_scalar_constant_value,
 )
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
-from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import add, exp, mul
 from pytensor.tensor.rewriting.basic import (
     alloc_like,
     broadcasted_by,
     register_canonicalize,
     register_specialize,
 )
-from pytensor.tensor.shape import shape_padleft
 from pytensor.tensor.variable import TensorConstant
 
 
@@ -483,66 +478,40 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
     """
     if len(node.outputs) > 1:
-        return
-    try:
-        shape_i = fgraph.shape_feature.shape_i
-    except AttributeError:
-        shape_i = None
-    if isinstance(node.op, Elemwise):
-        scalar_op = node.op.scalar_op
-        # print "aa", scalar_op.output_types_preference
-        if getattr(scalar_op, "output_types_preference", None) in (
-            ps.upgrade_to_float,
-            ps.upcast_out,
-        ):
-            # this is the kind of op that we can screw with the input
-            # dtypes by upcasting explicitly
-            output_dtype = node.outputs[0].type.dtype
-            new_inputs = []
-            for i in node.inputs:
-                if i.type.dtype == output_dtype:
-                    new_inputs.append(i)
-                else:
-                    try:
-                        cval_i = get_underlying_scalar_constant_value(
-                            i, only_process_constants=True
-                        )
-                        if all(i.broadcastable):
-                            new_inputs.append(
-                                shape_padleft(cast(cval_i, output_dtype), i.ndim)
-                            )
-                        else:
-                            if shape_i is None:
-                                return
-                            new_inputs.append(
-                                alloc(
-                                    cast(cval_i, output_dtype),
-                                    *[shape_i(d)(i) for d in range(i.ndim)],
-                                )
-                            )
-                            # print >> sys.stderr, "AAA",
-                            # *[Shape_i(d)(i) for d in range(i.ndim)]
-                    except NotScalarConstantError:
-                        # for the case of a non-scalar
-                        if isinstance(i, TensorConstant):
-                            new_inputs.append(cast(i, output_dtype))
-                        else:
-                            new_inputs.append(i)
+        return None
 
-            if new_inputs != node.inputs:
-                rval = [node.op(*new_inputs)]
-                if not node.outputs[0].type.is_super(rval[0].type):
-                    # This can happen for example when floatX=float32
-                    # and we do the true division between and int64
-                    # and a constant that will get typed as int8.
+    if getattr(node.op.scalar_op, "output_types_preference", None) not in (
+        ps.upgrade_to_float,
+        ps.upcast_out,
+    ):
+        return None
 
-                    # As this is just to allow merging more case, if
-                    # the upcast don't work, we can just skip it.
-                    return
+    # this is the kind of op that we can screw with the input
+    # dtypes by upcasting explicitly
+    [old_out] = node.outputs
+    output_dtype = old_out.type.dtype
+    new_inputs = list(node.inputs)
+    changed = False
+    for i, inp in enumerate(node.inputs):
+        if inp.type.dtype != output_dtype and isinstance(inp, TensorConstant):
+            new_inputs[i] = constant(inp.data.astype(output_dtype))
+            changed = True
+
+    if not changed:
+        return None
+
+    rval = node.op(*new_inputs)
+    if not old_out.type.is_super(rval.type):
+        # This can happen for example when floatX=float32
+        # and we do the true division between and int64
+        # and a constant that will get typed as int8.
+        # As this is just to allow merging more case, if
+        # the upcast don't work, we can just skip it.
+        return None
 
-                    # Copy over output stacktrace from before upcasting
-                    copy_stack_trace(node.outputs[0], rval)
-                    return rval
+    # Copy over output stacktrace from before upcasting
+    copy_stack_trace(old_out, rval)
+    return [rval]
 
 
 @node_rewriter([add, mul])
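
The core of the new code path is an eager dtype conversion of the constant's data rather than a symbolic cast. A minimal standalone sketch of that step (names are illustrative and not part of the commit):

import numpy as np
from pytensor.tensor import constant

output_dtype = "float32"        # dtype of the Elemwise output
inp = constant(np.int8(2))      # a TensorConstant input typed as int8

if inp.type.dtype != output_dtype:
    # Upcast the underlying numpy data directly; no symbolic cast/alloc is
    # created, so nothing is left for constant folding to clean up later.
    inp = constant(inp.data.astype(output_dtype))

print(inp.type.dtype)           # float32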
