 from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
 from pytensor.tensor.basic import (
     MakeVector,
-    alloc,
-    cast,
     constant,
-    get_underlying_scalar_constant_value,
 )
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
-from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import add, exp, mul
 from pytensor.tensor.rewriting.basic import (
     alloc_like,
     broadcasted_by,
     register_canonicalize,
     register_specialize,
 )
-from pytensor.tensor.shape import shape_padleft
 from pytensor.tensor.variable import TensorConstant
 
 
@@ -483,66 +478,40 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
 
     """
     if len(node.outputs) > 1:
-        return
-    try:
-        shape_i = fgraph.shape_feature.shape_i
-    except AttributeError:
-        shape_i = None
-    if isinstance(node.op, Elemwise):
-        scalar_op = node.op.scalar_op
-        # print "aa", scalar_op.output_types_preference
-        if getattr(scalar_op, "output_types_preference", None) in (
-            ps.upgrade_to_float,
-            ps.upcast_out,
-        ):
-            # this is the kind of op that we can screw with the input
-            # dtypes by upcasting explicitly
-            output_dtype = node.outputs[0].type.dtype
-            new_inputs = []
-            for i in node.inputs:
-                if i.type.dtype == output_dtype:
-                    new_inputs.append(i)
-                else:
-                    try:
-                        cval_i = get_underlying_scalar_constant_value(
-                            i, only_process_constants=True
-                        )
-                        if all(i.broadcastable):
-                            new_inputs.append(
-                                shape_padleft(cast(cval_i, output_dtype), i.ndim)
-                            )
-                        else:
-                            if shape_i is None:
-                                return
-                            new_inputs.append(
-                                alloc(
-                                    cast(cval_i, output_dtype),
-                                    *[shape_i(d)(i) for d in range(i.ndim)],
-                                )
-                            )
-                            # print >> sys.stderr, "AAA",
-                            # *[Shape_i(d)(i) for d in range(i.ndim)]
-                    except NotScalarConstantError:
-                        # for the case of a non-scalar
-                        if isinstance(i, TensorConstant):
-                            new_inputs.append(cast(i, output_dtype))
-                        else:
-                            new_inputs.append(i)
+        return None
 
-            if new_inputs != node.inputs:
-                rval = [node.op(*new_inputs)]
-                if not node.outputs[0].type.is_super(rval[0].type):
-                    # This can happen for example when floatX=float32
-                    # and we do the true division between and int64
-                    # and a constant that will get typed as int8.
+    if getattr(node.op.scalar_op, "output_types_preference", None) not in (
+        ps.upgrade_to_float,
+        ps.upcast_out,
+    ):
+        return None
 
-                    # As this is just to allow merging more case, if
-                    # the upcast don't work, we can just skip it.
-                    return
+    # this is the kind of op that we can screw with the input
+    # dtypes by upcasting explicitly
+    [old_out] = node.outputs
+    output_dtype = old_out.type.dtype
+    new_inputs = list(node.inputs)
+    changed = False
+    for i, inp in enumerate(node.inputs):
+        if inp.type.dtype != output_dtype and isinstance(inp, TensorConstant):
+            new_inputs[i] = constant(inp.data.astype(output_dtype))
+            changed = True
+
+    if not changed:
+        return None
+
+    rval = node.op(*new_inputs)
+    if not old_out.type.is_super(rval.type):
+        # This can happen for example when floatX=float32
+        # and we do the true division between and int64
+        # and a constant that will get typed as int8.
+        # As this is just to allow merging more case, if
+        # the upcast don't work, we can just skip it.
+        return None
 
-                # Copy over output stacktrace from before upcasting
-                copy_stack_trace(node.outputs[0], rval)
-                return rval
+    # Copy over output stacktrace from before upcasting
+    copy_stack_trace(old_out, rval)
+    return [rval]
 
 
 @node_rewriter([add, mul])
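For context on what the simplified rule does: it now only rewrites explicit `TensorConstant` inputs whose dtype differs from the Elemwise output dtype, replacing them with an upcast constant, instead of the old path that materialized scalar constants via `cast`/`shape_padleft`/`alloc`. The snippet below is a minimal, standalone sketch of that loop, using plain NumPy arrays as stand-ins for PyTensor constants; `upcast_constant_inputs` is an illustrative helper, not part of the PyTensor API.

```python
import numpy as np


def upcast_constant_inputs(inputs, output_dtype):
    """Illustrative sketch of the new rewrite's core loop (not PyTensor API).

    Only inputs that are plain constants (modelled here as NumPy arrays) and
    whose dtype differs from ``output_dtype`` are replaced by an upcast copy.
    Returning None when nothing changed mirrors how a node rewriter signals
    that it did not apply.
    """
    new_inputs = list(inputs)
    changed = False
    for i, inp in enumerate(inputs):
        if isinstance(inp, np.ndarray) and inp.dtype != np.dtype(output_dtype):
            new_inputs[i] = inp.astype(output_dtype)
            changed = True
    return new_inputs if changed else None


# An int8 constant next to a (symbolic) float64 input gets upcast to float64,
# so later constant-merging rewrites see matching dtypes.
print(upcast_constant_inputs(["x_float64_var", np.array(2, dtype="int8")], "float64"))
```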