Skip to content

Commit 836eb45

Browse files
Refactor rewrite
1 parent 6df788c commit 836eb45

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1913,20 +1913,29 @@ def local_pow_canonicalize(fgraph, node):
19131913
19141914
In all cases, the shape of the output is the result of broadcasting the shapes of the inputs.
19151915
"""
1916-
19171916
cst_base = get_underlying_scalar_constant_value(
19181917
node.inputs[0], only_process_constants=True, raise_not_constant=False
19191918
)
1920-
if cst_base == 1:
1921-
return [broadcast_arrays(*node.inputs)[0].astype(node.outputs[0].dtype)]
1922-
19231919
cst_exponent = get_underlying_scalar_constant_value(
19241920
node.inputs[1], only_process_constants=True, raise_not_constant=False
19251921
)
1926-
if cst_exponent == 0:
1927-
return [alloc_like(1, node.outputs[0], fgraph)]
1928-
if cst_exponent == 1:
1929-
return [alloc_like(node.inputs[0], node.outputs[0], fgraph)]
1922+
1923+
new_out = None
1924+
1925+
if cst_base == 1:
1926+
new_out = broadcast_arrays(*node.inputs)[0]
1927+
elif cst_exponent == 0:
1928+
new_out = broadcast_arrays(ones_like(node.inputs[0]), node.inputs[1])[0]
1929+
elif cst_exponent == 1:
1930+
new_out = broadcast_arrays(*node.inputs)[0]
1931+
1932+
if not new_out:
1933+
return
1934+
1935+
if new_out.dtype != node.out.dtype:
1936+
new_out = cast(new_out, dtype=node.out.dtype)
1937+
1938+
return [new_out]
19301939

19311940

19321941
@register_specialize

0 commit comments

Comments
 (0)