@@ -126,24 +126,6 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
126126 return consts , origconsts , nonconsts
127127
128128
129- def get_constant (v ):
130- """
131-
132- Returns
133- -------
134- object
135- A numeric constant if v is a Constant or, well, a
136- numeric constant. If v is a plain Variable, returns None.
137-
138- """
139- if isinstance (v , TensorConstant ):
140- return v .unique_value
141- elif isinstance (v , Variable ):
142- return None
143- else :
144- return v
145-
146-
147129@register_canonicalize
148130@register_stabilize
149131@node_rewriter ([Dot ])
@@ -980,8 +962,8 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
980962 """
981963 Find all constants and put them together into a single constant.
982964
983- Finds all constants in orig_num and orig_denum (using
984- get_constant) and puts them together into a single
965+ Finds all constants in orig_num and orig_denum
966+ and puts them together into a single
985967 constant. The constant is inserted as the first element of the
986968 numerator. If the constant is the neutral element, it is
987969 removed from the numerator.
@@ -1002,17 +984,15 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
1002984 numct , denumct = [], []
1003985
1004986 for v in orig_num :
1005- ct = get_constant (v )
1006- if ct is not None :
987+ if isinstance (v , TensorConstant ) and v .unique_value is not None :
1007988 # We found a constant in the numerator!
1008989 # We add it to numct
1009- numct .append (ct )
990+ numct .append (v . unique_value )
1010991 else :
1011992 num .append (v )
1012993 for v in orig_denum :
1013- ct = get_constant (v )
1014- if ct is not None :
1015- denumct .append (ct )
994+ if isinstance (v , TensorConstant ) and v .unique_value is not None :
995+ denumct .append (v .unique_value )
1016996 else :
1017997 denum .append (v )
1018998
@@ -1036,11 +1016,13 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
10361016
10371017 if orig_num and len (numct ) == 1 and len (denumct ) == 0 and ct :
10381018 # In that case we should only have one constant in `ct`.
1039- assert len (ct ) == 1
1040- first_num_ct = get_constant (orig_num [0 ])
1041- if first_num_ct is not None and ct [0 ].type .values_eq (
1042- ct [0 ].data , first_num_ct
1043- ):
1019+ [var_ct ] = ct
1020+
1021+ num_ct = None
1022+ if isinstance (var_ct , TensorConstant ):
1023+ num_ct = var_ct .unique_value
1024+
1025+ if num_ct is not None and var_ct .type .values_eq (var_ct .data , num_ct ):
10441026 # This is an important trick :( if it so happens that:
10451027 # * there's exactly one constant on the numerator and none on
10461028 # the denominator
@@ -1825,9 +1807,12 @@ def local_add_neg_to_sub(fgraph, node):
18251807 return [new_out ]
18261808
18271809 # Check if it is a negative constant
1828- const = get_constant (second )
1829- if const is not None and const < 0 :
1830- new_out = sub (first , np .abs (const ))
1810+ if (
1811+ isinstance (second , TensorConstant )
1812+ and second .unique_value is not None
1813+ and second .unique_value < 0
1814+ ):
1815+ new_out = sub (first , np .abs (second .data ))
18311816 return [new_out ]
18321817
18331818
@@ -1856,7 +1841,12 @@ def local_mul_zero(fgraph, node):
18561841@register_specialize
18571842@node_rewriter ([true_div ])
18581843def local_div_to_reciprocal (fgraph , node ):
1859- if np .all (get_constant (node .inputs [0 ]) == 1.0 ):
1844+ if (
1845+ get_underlying_scalar_constant_value (
1846+ node .inputs [0 ], only_process_constants = True , raise_not_constant = False
1847+ )
1848+ == 1.0
1849+ ):
18601850 out = node .outputs [0 ]
18611851 new_out = reciprocal (local_mul_canonizer .merge_num_denum (node .inputs [1 :], []))
18621852 # The ones could have forced upcasting
@@ -1877,7 +1867,9 @@ def local_reciprocal_canon(fgraph, node):
18771867@register_canonicalize
18781868@node_rewriter ([pt_pow ])
18791869def local_pow_canonicalize (fgraph , node ):
1880- cst = get_constant (node .inputs [1 ])
1870+ cst = get_underlying_scalar_constant_value (
1871+ node .inputs [1 ], only_process_constants = True , raise_not_constant = False
1872+ )
18811873 if cst == 0 :
18821874 return [alloc_like (1 , node .outputs [0 ], fgraph )]
18831875 if cst == 1 :
@@ -1908,7 +1900,12 @@ def local_intdiv_by_one(fgraph, node):
19081900@node_rewriter ([int_div , true_div ])
19091901def local_zero_div (fgraph , node ):
19101902 """0 / x -> 0"""
1911- if get_constant (node .inputs [0 ]) == 0 :
1903+ if (
1904+ get_underlying_scalar_constant_value (
1905+ node .inputs [0 ], only_process_constants = True , raise_not_constant = False
1906+ )
1907+ == 0
1908+ ):
19121909 ret = alloc_like (0 , node .outputs [0 ], fgraph )
19131910 ret .tag .values_eq_approx = values_eq_approx_remove_nan
19141911 return [ret ]
@@ -1921,8 +1918,12 @@ def local_pow_specialize(fgraph, node):
19211918 odtype = node .outputs [0 ].dtype
19221919 xsym = node .inputs [0 ]
19231920 ysym = node .inputs [1 ]
1924- y = get_constant (ysym )
1925- if (y is not None ) and not broadcasted_by (xsym , ysym ):
1921+ try :
1922+ y = get_underlying_scalar_constant_value (ysym , only_process_constants = True )
1923+ except NotScalarConstantError :
1924+ return
1925+
1926+ if not broadcasted_by (xsym , ysym ):
19261927 rval = None
19271928
19281929 if np .all (y == 2 ):
@@ -1956,10 +1957,14 @@ def local_pow_to_nested_squaring(fgraph, node):
19561957 """
19571958
19581959 # the idea here is that we have pow(x, y)
1960+ xsym , ysym = node .inputs
1961+
1962+ try :
1963+ y = get_underlying_scalar_constant_value (ysym , only_process_constants = True )
1964+ except NotScalarConstantError :
1965+ return
1966+
19591967 odtype = node .outputs [0 ].dtype
1960- xsym = node .inputs [0 ]
1961- ysym = node .inputs [1 ]
1962- y = get_constant (ysym )
19631968
19641969 # the next line is needed to fix a strange case that I don't
19651970 # know how to make a separate test.
@@ -1975,7 +1980,7 @@ def local_pow_to_nested_squaring(fgraph, node):
19751980 y = y [0 ]
19761981 except IndexError :
19771982 pass
1978- if ( y is not None ) and not broadcasted_by (xsym , ysym ):
1983+ if not broadcasted_by (xsym , ysym ):
19791984 rval = None
19801985 # 512 is too small for the cpu and too big for some gpu!
19811986 if abs (y ) == int (abs (y )) and abs (y ) <= 512 :
@@ -2042,7 +2047,9 @@ def local_mul_specialize(fgraph, node):
20422047 nb_neg_node += 1
20432048
20442049 # remove special case arguments of 1, -1 or 0
2045- y = get_constant (inp )
2050+ y = get_underlying_scalar_constant_value (
2051+ inp , only_process_constants = True , raise_not_constant = False
2052+ )
20462053 if y == 1.0 :
20472054 nb_cst += 1
20482055 elif y == - 1.0 :
0 commit comments