@@ -126,24 +126,6 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
126126 return consts , origconsts , nonconsts
127127
128128
129- def get_constant (v ):
130- """
131-
132- Returns
133- -------
134- object
135- A numeric constant if v is a Constant or, well, a
136- numeric constant. If v is a plain Variable, returns None.
137-
138- """
139- if isinstance (v , TensorConstant ):
140- return v .unique_value
141- elif isinstance (v , Variable ):
142- return None
143- else :
144- return v
145-
146-
147129@register_canonicalize
148130@register_stabilize
149131@node_rewriter ([Dot ])
@@ -994,8 +976,8 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
994976 """
995977 Find all constants and put them together into a single constant.
996978
997- Finds all constants in orig_num and orig_denum (using
998- get_constant) and puts them together into a single
979+ Finds all constants in orig_num and orig_denum
980+ and puts them together into a single
999981 constant. The constant is inserted as the first element of the
1000982 numerator. If the constant is the neutral element, it is
1001983 removed from the numerator.
@@ -1016,17 +998,15 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
1016998 numct , denumct = [], []
1017999
10181000 for v in orig_num :
1019- ct = get_constant (v )
1020- if ct is not None :
1001+ if isinstance (v , TensorConstant ) and v .unique_value is not None :
10211002 # We found a constant in the numerator!
10221003 # We add it to numct
1023- numct .append (ct )
1004+ numct .append (v . unique_value )
10241005 else :
10251006 num .append (v )
10261007 for v in orig_denum :
1027- ct = get_constant (v )
1028- if ct is not None :
1029- denumct .append (ct )
1008+ if isinstance (v , TensorConstant ) and v .unique_value is not None :
1009+ denumct .append (v .unique_value )
10301010 else :
10311011 denum .append (v )
10321012
@@ -1050,10 +1030,15 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
10501030
10511031 if orig_num and len (numct ) == 1 and len (denumct ) == 0 and ct :
10521032 # In that case we should only have one constant in `ct`.
1053- assert len (ct ) == 1
1054- first_num_ct = get_constant (orig_num [0 ])
1055- if first_num_ct is not None and ct [0 ].type .values_eq (
1056- ct [0 ].data , first_num_ct
1033+ [var_ct ] = ct
1034+ first_num_var = orig_num [0 ]
1035+ first_num_ct = (
1036+ first_num_var .unique_value
1037+ if isinstance (first_num_var , TensorConstant )
1038+ else None
1039+ )
1040+ if first_num_ct is not None and var_ct .type .values_eq (
1041+ var_ct .data , first_num_ct
10571042 ):
10581043 # This is an important trick :( if it so happens that:
10591044 # * there's exactly one constant on the numerator and none on
@@ -1840,9 +1825,12 @@ def local_add_neg_to_sub(fgraph, node):
18401825 return [new_out ]
18411826
18421827 # Check if it is a negative constant
1843- const = get_constant (second )
1844- if const is not None and const < 0 :
1845- new_out = sub (first , np .abs (const ))
1828+ if (
1829+ isinstance (second , TensorConstant )
1830+ and second .unique_value is not None
1831+ and second .unique_value < 0
1832+ ):
1833+ new_out = sub (first , np .abs (second .data ))
18461834 return [new_out ]
18471835
18481836
@@ -1871,7 +1859,12 @@ def local_mul_zero(fgraph, node):
18711859@register_specialize
18721860@node_rewriter ([true_div ])
18731861def local_div_to_reciprocal (fgraph , node ):
1874- if np .all (get_constant (node .inputs [0 ]) == 1.0 ):
1862+ if (
1863+ get_underlying_scalar_constant_value (
1864+ node .inputs [0 ], only_process_constants = True , raise_not_constant = False
1865+ )
1866+ == 1.0
1867+ ):
18751868 out = node .outputs [0 ]
18761869 new_out = reciprocal (local_mul_canonizer .merge_num_denum (node .inputs [1 :], []))
18771870 # The ones could have forced upcasting
@@ -1892,7 +1885,9 @@ def local_reciprocal_canon(fgraph, node):
18921885@register_canonicalize
18931886@node_rewriter ([pt_pow ])
18941887def local_pow_canonicalize (fgraph , node ):
1895- cst = get_constant (node .inputs [1 ])
1888+ cst = get_underlying_scalar_constant_value (
1889+ node .inputs [1 ], only_process_constants = True , raise_not_constant = False
1890+ )
18961891 if cst == 0 :
18971892 return [alloc_like (1 , node .outputs [0 ], fgraph )]
18981893 if cst == 1 :
@@ -1923,7 +1918,12 @@ def local_intdiv_by_one(fgraph, node):
19231918@node_rewriter ([int_div , true_div ])
19241919def local_zero_div (fgraph , node ):
19251920 """0 / x -> 0"""
1926- if get_constant (node .inputs [0 ]) == 0 :
1921+ if (
1922+ get_underlying_scalar_constant_value (
1923+ node .inputs [0 ], only_process_constants = True , raise_not_constant = False
1924+ )
1925+ == 0
1926+ ):
19271927 ret = alloc_like (0 , node .outputs [0 ], fgraph )
19281928 ret .tag .values_eq_approx = values_eq_approx_remove_nan
19291929 return [ret ]
@@ -1936,8 +1936,12 @@ def local_pow_specialize(fgraph, node):
19361936 odtype = node .outputs [0 ].dtype
19371937 xsym = node .inputs [0 ]
19381938 ysym = node .inputs [1 ]
1939- y = get_constant (ysym )
1940- if (y is not None ) and not broadcasted_by (xsym , ysym ):
1939+ try :
1940+ y = get_underlying_scalar_constant_value (ysym , only_process_constants = True )
1941+ except NotScalarConstantError :
1942+ return
1943+
1944+ if not broadcasted_by (xsym , ysym ):
19411945 rval = None
19421946
19431947 if np .all (y == 2 ):
@@ -1971,10 +1975,14 @@ def local_pow_to_nested_squaring(fgraph, node):
19711975 """
19721976
19731977 # the idea here is that we have pow(x, y)
1978+ xsym , ysym = node .inputs
1979+
1980+ try :
1981+ y = get_underlying_scalar_constant_value (ysym , only_process_constants = True )
1982+ except NotScalarConstantError :
1983+ return
1984+
19741985 odtype = node .outputs [0 ].dtype
1975- xsym = node .inputs [0 ]
1976- ysym = node .inputs [1 ]
1977- y = get_constant (ysym )
19781986
19791987 # the next line is needed to fix a strange case that I don't
19801988 # know how to make a separate test.
@@ -1990,7 +1998,7 @@ def local_pow_to_nested_squaring(fgraph, node):
19901998 y = y [0 ]
19911999 except IndexError :
19922000 pass
1993- if ( y is not None ) and not broadcasted_by (xsym , ysym ):
2001+ if not broadcasted_by (xsym , ysym ):
19942002 rval = None
19952003 # 512 is too small for the cpu and too big for some gpu!
19962004 if abs (y ) == int (abs (y )) and abs (y ) <= 512 :
@@ -2057,7 +2065,9 @@ def local_mul_specialize(fgraph, node):
20572065 nb_neg_node += 1
20582066
20592067 # remove special case arguments of 1, -1 or 0
2060- y = get_constant (inp )
2068+ y = get_underlying_scalar_constant_value (
2069+ inp , only_process_constants = True , raise_not_constant = False
2070+ )
20612071 if y == 1.0 :
20622072 nb_cst += 1
20632073 elif y == - 1.0 :
0 commit comments