|
9 | 9 |
|
10 | 10 | import pytensor.scalar.basic as ps |
11 | 11 | import pytensor.scalar.math as ps_math |
12 | | -from pytensor.graph import FunctionGraph |
13 | | -from pytensor.graph.basic import Apply, Constant, Variable |
| 12 | +from pytensor.graph.basic import Constant, Variable |
14 | 13 | from pytensor.graph.rewriting.basic import ( |
15 | 14 | NodeRewriter, |
16 | 15 | PatternNodeRewriter, |
@@ -1906,39 +1905,27 @@ def local_reciprocal_canon(fgraph, node): |
1906 | 1905 | @register_canonicalize |
1907 | 1906 | @node_rewriter([pt_pow]) |
1908 | 1907 | def local_pow_canonicalize(fgraph, node): |
1909 | | - cst = get_underlying_scalar_constant_value( |
1910 | | - node.inputs[1], only_process_constants=True, raise_not_constant=False |
1911 | | - ) |
1912 | | - if cst == 0: |
1913 | | - return [alloc_like(1, node.outputs[0], fgraph)] |
1914 | | - if cst == 1: |
1915 | | - return [alloc_like(node.inputs[0], node.outputs[0], fgraph)] |
1916 | | - |
1917 | | - |
1918 | | -@register_canonicalize |
1919 | | -@node_rewriter([pt_pow]) |
1920 | | -def local_pow_canonicalize_base_1( |
1921 | | - fgraph: FunctionGraph, node: Apply |
1922 | | -) -> list[TensorVariable] | None: |
1923 | 1908 | """ |
1924 | | - Replace `1 ** x` with 1, broadcast to the shape of the output. |
| 1909 | + Rewrites for exponential functions with straight-forward simplifications: |
| 1910 | + 1. x ** 0 -> 1 |
| 1911 | + 2. x ** 1 -> x |
| 1912 | + 3. 1 ** x -> 1 |
1925 | 1913 |
|
1926 | | - Parameters |
1927 | | - ---------- |
1928 | | - fgraph: FunctionGraph |
1929 | | - Full function graph being rewritten |
1930 | | - node: Apply |
1931 | | - Specific node being rewritten |
1932 | | -
|
1933 | | - Returns |
1934 | | - ------- |
1935 | | - rewritten_output: list[TensorVariable] | None |
1936 | | - Rewritten output of node, or None if no rewrite is possible |
| 1914 | + In all cases, the shape of the output is the result of broadcasting the shapes of the inputs. |
1937 | 1915 | """ |
1938 | | - cst = get_underlying_scalar_constant_value( |
| 1916 | + |
| 1917 | + cst_base = get_underlying_scalar_constant_value( |
1939 | 1918 | node.inputs[0], only_process_constants=True, raise_not_constant=False |
1940 | 1919 | ) |
1941 | | - if cst == 1: |
| 1920 | + if cst_base == 1: |
| 1921 | + return [alloc_like(1, node.outputs[0], fgraph)] |
| 1922 | + |
| 1923 | + cst_exponent = get_underlying_scalar_constant_value( |
| 1924 | + node.inputs[1], only_process_constants=True, raise_not_constant=False |
| 1925 | + ) |
| 1926 | + if cst_exponent == 0: |
| 1927 | + return [alloc_like(1, node.outputs[0], fgraph)] |
| 1928 | + if cst_exponent == 1: |
1942 | 1929 | return [alloc_like(node.inputs[0], node.outputs[0], fgraph)] |
1943 | 1930 |
|
1944 | 1931 |
|
|
0 commit comments