@@ -125,24 +125,6 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
125125 return consts , origconsts , nonconsts
126126
127127
128- def get_constant (v ):
129- """
130-
131- Returns
132- -------
133- object
134- A numeric constant if v is a Constant or, well, a
135- numeric constant. If v is a plain Variable, returns None.
136-
137- """
138- if isinstance (v , TensorConstant ):
139- return v .unique_value
140- elif isinstance (v , Variable ):
141- return None
142- else :
143- return v
144-
145-
146128@register_canonicalize
147129@register_stabilize
148130@node_rewriter ([Dot ])
@@ -1021,8 +1003,8 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
10211003 """
10221004 Find all constants and put them together into a single constant.
10231005
1024- Finds all constants in orig_num and orig_denum (using
1025- get_constant) and puts them together into a single
1006+ Finds all constants in orig_num and orig_denum
1007+ and puts them together into a single
10261008 constant. The constant is inserted as the first element of the
10271009 numerator. If the constant is the neutral element, it is
10281010 removed from the numerator.
@@ -1043,17 +1025,15 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
10431025 numct , denumct = [], []
10441026
10451027 for v in orig_num :
1046- ct = get_constant (v )
1047- if ct is not None :
1028+ if isinstance (v , TensorConstant ) and v .unique_value is not None :
10481029 # We found a constant in the numerator!
10491030 # We add it to numct
1050- numct .append (ct )
1031+ numct .append (v . unique_value )
10511032 else :
10521033 num .append (v )
10531034 for v in orig_denum :
1054- ct = get_constant (v )
1055- if ct is not None :
1056- denumct .append (ct )
1035+ if isinstance (v , TensorConstant ) and v .unique_value is not None :
1036+ denumct .append (v .unique_value )
10571037 else :
10581038 denum .append (v )
10591039
@@ -1077,11 +1057,13 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
10771057
10781058 if orig_num and len (numct ) == 1 and len (denumct ) == 0 and ct :
10791059 # In that case we should only have one constant in `ct`.
1080- assert len (ct ) == 1
1081- first_num_ct = get_constant (orig_num [0 ])
1082- if first_num_ct is not None and ct [0 ].type .values_eq (
1083- ct [0 ].data , first_num_ct
1084- ):
1060+ [var_ct ] = ct
1061+
1062+ num_ct = None
1063+ if isinstance (var_ct , TensorConstant ):
1064+ num_ct = var_ct .unique_value
1065+
1066+ if num_ct is not None and var_ct .type .values_eq (var_ct .data , num_ct ):
10851067 # This is an important trick :( if it so happens that:
10861068 # * there's exactly one constant on the numerator and none on
10871069 # the denominator
@@ -1864,9 +1846,12 @@ def local_add_neg_to_sub(fgraph, node):
18641846 return [new_out ]
18651847
18661848 # Check if it is a negative constant
1867- const = get_constant (second )
1868- if const is not None and const < 0 :
1869- new_out = sub (first , np .abs (const ))
1849+ if (
1850+ isinstance (second , TensorConstant )
1851+ and second .unique_value is not None
1852+ and second .unique_value < 0
1853+ ):
1854+ new_out = sub (first , np .abs (second .data ))
18701855 return [new_out ]
18711856
18721857
@@ -1895,7 +1880,12 @@ def local_mul_zero(fgraph, node):
18951880@register_specialize
18961881@node_rewriter ([true_div ])
18971882def local_div_to_reciprocal (fgraph , node ):
1898- if np .all (get_constant (node .inputs [0 ]) == 1.0 ):
1883+ if (
1884+ get_underlying_scalar_constant_value (
1885+ node .inputs [0 ], only_process_constants = True , raise_not_constant = False
1886+ )
1887+ == 1.0
1888+ ):
18991889 out = node .outputs [0 ]
19001890 new_out = reciprocal (local_mul_canonizer .merge_num_denum (node .inputs [1 :], []))
19011891 # The ones could have forced upcasting
@@ -1916,7 +1906,9 @@ def local_reciprocal_canon(fgraph, node):
19161906@register_canonicalize
19171907@node_rewriter ([pt_pow ])
19181908def local_pow_canonicalize (fgraph , node ):
1919- cst = get_constant (node .inputs [1 ])
1909+ cst = get_underlying_scalar_constant_value (
1910+ node .inputs [1 ], only_process_constants = True , raise_not_constant = False
1911+ )
19201912 if cst == 0 :
19211913 return [alloc_like (1 , node .outputs [0 ], fgraph )]
19221914 if cst == 1 :
@@ -1947,7 +1939,12 @@ def local_intdiv_by_one(fgraph, node):
19471939@node_rewriter ([int_div , true_div ])
19481940def local_zero_div (fgraph , node ):
19491941 """0 / x -> 0"""
1950- if get_constant (node .inputs [0 ]) == 0 :
1942+ if (
1943+ get_underlying_scalar_constant_value (
1944+ node .inputs [0 ], only_process_constants = True , raise_not_constant = False
1945+ )
1946+ == 0
1947+ ):
19511948 ret = alloc_like (0 , node .outputs [0 ], fgraph )
19521949 ret .tag .values_eq_approx = values_eq_approx_remove_nan
19531950 return [ret ]
@@ -1960,7 +1957,9 @@ def local_pow_specialize(fgraph, node):
19601957 odtype = node .outputs [0 ].dtype
19611958 xsym = node .inputs [0 ]
19621959 ysym = node .inputs [1 ]
1963- y = get_constant (ysym )
1960+ y = get_underlying_scalar_constant_value (
1961+ ysym , only_process_constants = True , raise_not_constant = False
1962+ )
19641963 if (y is not None ) and not broadcasted_by (xsym , ysym ):
19651964 rval = None
19661965
@@ -1998,7 +1997,9 @@ def local_pow_to_nested_squaring(fgraph, node):
19981997 odtype = node .outputs [0 ].dtype
19991998 xsym = node .inputs [0 ]
20001999 ysym = node .inputs [1 ]
2001- y = get_constant (ysym )
2000+ y = get_underlying_scalar_constant_value (
2001+ ysym , only_process_constants = True , raise_not_constant = False
2002+ )
20022003
20032004 # the next line is needed to fix a strange case that I don't
20042005 # know how to make a separate test.
@@ -2081,7 +2082,9 @@ def local_mul_specialize(fgraph, node):
20812082 nb_neg_node += 1
20822083
20832084 # remove special case arguments of 1, -1 or 0
2084- y = get_constant (inp )
2085+ y = get_underlying_scalar_constant_value (
2086+ inp , raise_not_constant = False , only_process_constants = True
2087+ )
20852088 if y == 1.0 :
20862089 nb_cst += 1
20872090 elif y == - 1.0 :
0 commit comments