@@ -30,21 +30,16 @@
 from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
 from pytensor.tensor.basic import (
     MakeVector,
-    alloc,
-    cast,
     constant,
-    get_underlying_scalar_constant_value,
 )
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
-from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import add, exp, mul
 from pytensor.tensor.rewriting.basic import (
     alloc_like,
     broadcasted_by,
     register_canonicalize,
     register_specialize,
 )
-from pytensor.tensor.shape import shape_padleft
 from pytensor.tensor.variable import TensorConstant, TensorVariable
 
 
@@ -434,66 +429,40 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
 
     """
     if len(node.outputs) > 1:
-        return
-    try:
-        shape_i = fgraph.shape_feature.shape_i
-    except AttributeError:
-        shape_i = None
-    if isinstance(node.op, Elemwise):
-        scalar_op = node.op.scalar_op
-        # print "aa", scalar_op.output_types_preference
-        if getattr(scalar_op, "output_types_preference", None) in (
-            ps.upgrade_to_float,
-            ps.upcast_out,
-        ):
-            # this is the kind of op that we can screw with the input
-            # dtypes by upcasting explicitly
-            output_dtype = node.outputs[0].type.dtype
-            new_inputs = []
-            for i in node.inputs:
-                if i.type.dtype == output_dtype:
-                    new_inputs.append(i)
-                else:
-                    try:
-                        cval_i = get_underlying_scalar_constant_value(
-                            i, only_process_constants=True
-                        )
-                        if all(i.broadcastable):
-                            new_inputs.append(
-                                shape_padleft(cast(cval_i, output_dtype), i.ndim)
-                            )
-                        else:
-                            if shape_i is None:
-                                return
-                            new_inputs.append(
-                                alloc(
-                                    cast(cval_i, output_dtype),
-                                    *[shape_i(d)(i) for d in range(i.ndim)],
-                                )
-                            )
-                            # print >> sys.stderr, "AAA",
-                            # *[Shape_i(d)(i) for d in range(i.ndim)]
-                    except NotScalarConstantError:
-                        # for the case of a non-scalar
-                        if isinstance(i, TensorConstant):
-                            new_inputs.append(cast(i, output_dtype))
-                        else:
-                            new_inputs.append(i)
+        return None
+
+    if getattr(node.op.scalar_op, "output_types_preference", None) not in (
+        ps.upgrade_to_float,
+        ps.upcast_out,
+    ):
+        return None
 
-            if new_inputs != node.inputs:
-                rval = [node.op(*new_inputs)]
-                if not node.outputs[0].type.is_super(rval[0].type):
-                    # This can happen for example when floatX=float32
-                    # and we do the true division between and int64
-                    # and a constant that will get typed as int8.
+    # this is the kind of op that we can screw with the input
+    # dtypes by upcasting explicitly
+    [old_out] = node.outputs
+    output_dtype = old_out.type.dtype
+    new_inputs = list(node.inputs)
+    changed = False
+    for i, inp in enumerate(node.inputs):
+        if inp.type.dtype != output_dtype and isinstance(inp, TensorConstant):
+            new_inputs[i] = constant(inp.data.astype(output_dtype))
+            changed = True
+
+    if not changed:
+        return None
 
-                    # As this is just to allow merging more case, if
-                    # the upcast don't work, we can just skip it.
-                    return
+    rval = node.op(*new_inputs)
+    if not old_out.type.is_super(rval.type):
+        # This can happen for example when floatX=float32
+        # and we do the true division between and int64
+        # and a constant that will get typed as int8.
+        # As this is just to allow merging more case, if
+        # the upcast don't work, we can just skip it.
+        return None
 
-            # Copy over output stacktrace from before upcasting
-            copy_stack_trace(node.outputs[0], rval)
-            return rval
+    # Copy over output stacktrace from before upcasting
+    copy_stack_trace(old_out, rval)
+    return [rval]
 
 
 @node_rewriter([add, mul])
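For context on what the simplified rewrite targets, here is a minimal sketch of the kind of graph it rewrites. The variable names and dtypes below are illustrative only; the constructors used (`pt.vector`, `pt.constant`) are ordinary pytensor tensor helpers, and the rewrite referenced is the `local_upcast_elemwise_constant_inputs` shown in the diff above.

```python
# Illustrative sketch -- names and dtypes are made up for this example.
import numpy as np
import pytensor.tensor as pt

x = pt.vector("x", dtype="float64")
c = pt.constant(np.int8(2), name="c")  # constant input narrower than the output dtype
y = x * c                              # Elemwise mul; the output dtype is float64

# local_upcast_elemwise_constant_inputs rewrites this graph so that the int8
# TensorConstant is replaced by an equivalent float64 constant (built via
# constant(c.data.astype("float64"))), while non-constant inputs such as `x`
# are left untouched. Later rewrites (constant folding, merging) then see
# inputs whose dtype already matches the output.
```

Per the diff, the new version only swaps `TensorConstant` inputs for upcast constants and returns `None` otherwise; it no longer reaches for `shape_padleft`, `alloc`, or the graph's shape feature as the old implementation did.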