@@ -126,24 +126,6 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
126126 return consts , origconsts , nonconsts
127127
128128
129- def get_constant (v ):
130- """
131-
132- Returns
133- -------
134- object
135- A numeric constant if v is a Constant or, well, a
136- numeric constant. If v is a plain Variable, returns None.
137-
138- """
139- if isinstance (v , TensorConstant ):
140- return v .unique_value
141- elif isinstance (v , Variable ):
142- return None
143- else :
144- return v
145-
146-
147129@register_canonicalize
148130@register_stabilize
149131@node_rewriter ([Dot ])
@@ -994,8 +976,8 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
994976 """
995977 Find all constants and put them together into a single constant.
996978
997- Finds all constants in orig_num and orig_denum (using
998- get_constant) and puts them together into a single
979+ Finds all constants in orig_num and orig_denum
980+ and puts them together into a single
999981 constant. The constant is inserted as the first element of the
1000982 numerator. If the constant is the neutral element, it is
1001983 removed from the numerator.
@@ -1016,17 +998,15 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
1016998 numct , denumct = [], []
1017999
10181000 for v in orig_num :
1019- ct = get_constant (v )
1020- if ct is not None :
1001+ if isinstance (v , TensorConstant ) and v .unique_value is not None :
10211002 # We found a constant in the numerator!
10221003 # We add it to numct
1023- numct .append (ct )
1004+ numct .append (v . unique_value )
10241005 else :
10251006 num .append (v )
10261007 for v in orig_denum :
1027- ct = get_constant (v )
1028- if ct is not None :
1029- denumct .append (ct )
1008+ if isinstance (v , TensorConstant ) and v .unique_value is not None :
1009+ denumct .append (v .unique_value )
10301010 else :
10311011 denum .append (v )
10321012
@@ -1050,11 +1030,13 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
10501030
10511031 if orig_num and len (numct ) == 1 and len (denumct ) == 0 and ct :
10521032 # In that case we should only have one constant in `ct`.
1053- assert len (ct ) == 1
1054- first_num_ct = get_constant (orig_num [0 ])
1055- if first_num_ct is not None and ct [0 ].type .values_eq (
1056- ct [0 ].data , first_num_ct
1057- ):
1033+ [var_ct ] = ct
1034+
1035+ num_ct = None
1036+ if isinstance (var_ct , TensorConstant ):
1037+ num_ct = var_ct .unique_value
1038+
1039+ if num_ct is not None and var_ct .type .values_eq (var_ct .data , num_ct ):
10581040 # This is an important trick :( if it so happens that:
10591041 # * there's exactly one constant on the numerator and none on
10601042 # the denominator
@@ -1839,9 +1821,12 @@ def local_add_neg_to_sub(fgraph, node):
18391821 return [new_out ]
18401822
18411823 # Check if it is a negative constant
1842- const = get_constant (second )
1843- if const is not None and const < 0 :
1844- new_out = sub (first , np .abs (const ))
1824+ if (
1825+ isinstance (second , TensorConstant )
1826+ and second .unique_value is not None
1827+ and second .unique_value < 0
1828+ ):
1829+ new_out = sub (first , np .abs (second .data ))
18451830 return [new_out ]
18461831
18471832
@@ -1870,7 +1855,12 @@ def local_mul_zero(fgraph, node):
18701855@register_specialize
18711856@node_rewriter ([true_div ])
18721857def local_div_to_reciprocal (fgraph , node ):
1873- if np .all (get_constant (node .inputs [0 ]) == 1.0 ):
1858+ if (
1859+ get_underlying_scalar_constant_value (
1860+ node .inputs [0 ], only_process_constants = True , raise_not_constant = False
1861+ )
1862+ == 1.0
1863+ ):
18741864 out = node .outputs [0 ]
18751865 new_out = reciprocal (local_mul_canonizer .merge_num_denum (node .inputs [1 :], []))
18761866 # The ones could have forced upcasting
@@ -1891,7 +1881,9 @@ def local_reciprocal_canon(fgraph, node):
18911881@register_canonicalize
18921882@node_rewriter ([pt_pow ])
18931883def local_pow_canonicalize (fgraph , node ):
1894- cst = get_constant (node .inputs [1 ])
1884+ cst = get_underlying_scalar_constant_value (
1885+ node .inputs [1 ], only_process_constants = True , raise_not_constant = False
1886+ )
18951887 if cst == 0 :
18961888 return [alloc_like (1 , node .outputs [0 ], fgraph )]
18971889 if cst == 1 :
@@ -1922,7 +1914,12 @@ def local_intdiv_by_one(fgraph, node):
19221914@node_rewriter ([int_div , true_div ])
19231915def local_zero_div (fgraph , node ):
19241916 """0 / x -> 0"""
1925- if get_constant (node .inputs [0 ]) == 0 :
1917+ if (
1918+ get_underlying_scalar_constant_value (
1919+ node .inputs [0 ], only_process_constants = True , raise_not_constant = False
1920+ )
1921+ == 0
1922+ ):
19261923 ret = alloc_like (0 , node .outputs [0 ], fgraph )
19271924 ret .tag .values_eq_approx = values_eq_approx_remove_nan
19281925 return [ret ]
@@ -1935,8 +1932,12 @@ def local_pow_specialize(fgraph, node):
19351932 odtype = node .outputs [0 ].dtype
19361933 xsym = node .inputs [0 ]
19371934 ysym = node .inputs [1 ]
1938- y = get_constant (ysym )
1939- if (y is not None ) and not broadcasted_by (xsym , ysym ):
1935+ try :
1936+ y = get_underlying_scalar_constant_value (ysym , only_process_constants = True )
1937+ except NotScalarConstantError :
1938+ return
1939+
1940+ if not broadcasted_by (xsym , ysym ):
19401941 rval = None
19411942
19421943 if np .all (y == 2 ):
@@ -1970,10 +1971,14 @@ def local_pow_to_nested_squaring(fgraph, node):
19701971 """
19711972
19721973 # the idea here is that we have pow(x, y)
1974+ xsym , ysym = node .inputs
1975+
1976+ try :
1977+ y = get_underlying_scalar_constant_value (ysym , only_process_constants = True )
1978+ except NotScalarConstantError :
1979+ return
1980+
19731981 odtype = node .outputs [0 ].dtype
1974- xsym = node .inputs [0 ]
1975- ysym = node .inputs [1 ]
1976- y = get_constant (ysym )
19771982
19781983 # the next line is needed to fix a strange case that I don't
19791984 # know how to make a separate test.
@@ -1989,7 +1994,7 @@ def local_pow_to_nested_squaring(fgraph, node):
19891994 y = y [0 ]
19901995 except IndexError :
19911996 pass
1992- if ( y is not None ) and not broadcasted_by (xsym , ysym ):
1997+ if not broadcasted_by (xsym , ysym ):
19931998 rval = None
19941999 # 512 is too small for the cpu and too big for some gpu!
19952000 if abs (y ) == int (abs (y )) and abs (y ) <= 512 :
@@ -2056,7 +2061,9 @@ def local_mul_specialize(fgraph, node):
20562061 nb_neg_node += 1
20572062
20582063 # remove special case arguments of 1, -1 or 0
2059- y = get_constant (inp )
2064+ y = get_underlying_scalar_constant_value (
2065+ inp , only_process_constants = True , raise_not_constant = False
2066+ )
20602067 if y == 1.0 :
20612068 nb_cst += 1
20622069 elif y == - 1.0 :
0 commit comments