File tree Expand file tree Collapse file tree 1 file changed +17
-8
lines changed
pytensor/tensor/rewriting Expand file tree Collapse file tree 1 file changed +17
-8
lines changed Original file line number Diff line number Diff line change @@ -1913,20 +1913,29 @@ def local_pow_canonicalize(fgraph, node):
19131913
19141914 In all cases, the shape of the output is the result of broadcasting the shapes of the inputs.
19151915 """
1916-
19171916 cst_base = get_underlying_scalar_constant_value (
19181917 node .inputs [0 ], only_process_constants = True , raise_not_constant = False
19191918 )
1920- if cst_base == 1 :
1921- return [broadcast_arrays (* node .inputs )[0 ].astype (node .outputs [0 ].dtype )]
1922-
19231919 cst_exponent = get_underlying_scalar_constant_value (
19241920 node .inputs [1 ], only_process_constants = True , raise_not_constant = False
19251921 )
1926- if cst_exponent == 0 :
1927- return [alloc_like (1 , node .outputs [0 ], fgraph )]
1928- if cst_exponent == 1 :
1929- return [alloc_like (node .inputs [0 ], node .outputs [0 ], fgraph )]
1922+
1923+ new_out = None
1924+
1925+ if cst_base == 1 :
1926+ new_out = broadcast_arrays (* node .inputs )[0 ]
1927+ elif cst_exponent == 0 :
1928+ new_out = broadcast_arrays (ones_like (node .inputs [0 ]), node .inputs [1 ])[0 ]
1929+ elif cst_exponent == 1 :
1930+ new_out = broadcast_arrays (* node .inputs )[0 ]
1931+
1932+ if not new_out :
1933+ return
1934+
1935+ if new_out .dtype != node .out .dtype :
1936+ new_out = cast (new_out , dtype = node .out .dtype )
1937+
1938+ return [new_out ]
19301939
19311940
19321941@register_specialize
You can’t perform that action at this time.
0 commit comments