@@ -870,13 +870,33 @@ function hashcons(s::BSImpl.Type)
870870 end true
871871end
872872
873- const SMALLV_DEFAULT_SYMREAL = hashcons (BSImpl. Const {SymReal} (0 , 0 , nothing ))
874- const SMALLV_DEFAULT_SAFEREAL = hashcons (BSImpl. Const {SafeReal} (0 , 0 , nothing ))
875- const SMALLV_DEFAULT_TREEREAL = hashcons (BSImpl. Const {TreeReal} (0 , 0 , nothing ))
873+ const CONST_ZERO_SYMREAL = hashcons (BSImpl. Const {SymReal} (0 , 0 , nothing ))
874+ const CONST_ZERO_SAFEREAL = hashcons (BSImpl. Const {SafeReal} (0 , 0 , nothing ))
875+ const CONST_ZERO_TREEREAL = hashcons (BSImpl. Const {TreeReal} (0 , 0 , nothing ))
876+ const CONST_ONE_SYMREAL = hashcons (BSImpl. Const {SymReal} (1 , 0 , nothing ))
877+ const CONST_ONE_SAFEREAL = hashcons (BSImpl. Const {SafeReal} (1 , 0 , nothing ))
878+ const CONST_ONE_TREEREAL = hashcons (BSImpl. Const {TreeReal} (1 , 0 , nothing ))
876879
877- defaultval (:: Type{BasicSymbolic{SymReal}} ) = SMALLV_DEFAULT_SYMREAL
878- defaultval (:: Type{BasicSymbolic{SafeReal}} ) = SMALLV_DEFAULT_SAFEREAL
879- defaultval (:: Type{BasicSymbolic{TreeReal}} ) = SMALLV_DEFAULT_TREEREAL
880+ @inline defaultval (:: Type{BasicSymbolic{SymReal}} ) = CONST_ZERO_SYMREAL
881+ @inline defaultval (:: Type{BasicSymbolic{SafeReal}} ) = CONST_ZERO_SAFEREAL
882+ @inline defaultval (:: Type{BasicSymbolic{TreeReal}} ) = CONST_ZERO_TREEREAL
883+
884+ """
885+ $(TYPEDSIGNATURES)
886+
887+ Return a `Const` representing `0` with the provided `vartype`.
888+ """
889+ @inline zero_of_vartype (:: Type{SymReal} ) = CONST_ZERO_SYMREAL
890+ @inline zero_of_vartype (:: Type{SafeReal} ) = CONST_ZERO_SAFEREAL
891+ @inline zero_of_vartype (:: Type{TreeReal} ) = CONST_ZERO_TREEREAL
892+ """
893+ $(TYPEDSIGNATURES)
894+
895+ Return a `Const` representing `1` with the provided `vartype`.
896+ """
897+ @inline one_of_vartype (:: Type{SymReal} ) = CONST_ONE_SYMREAL
898+ @inline one_of_vartype (:: Type{SafeReal} ) = CONST_ONE_SAFEREAL
899+ @inline one_of_vartype (:: Type{TreeReal} ) = CONST_ONE_TREEREAL
880900
881901function get_mul_coefficient (x)
882902 iscall (x) && operation (x) === (* ) || throw (ArgumentError (" $x is not a multiplication" ))
@@ -1151,7 +1171,7 @@ end
11511171 if isempty (dict)
11521172 return Const {T} (coeff)
11531173 elseif _iszero (coeff)
1154- return Const {T} ( 0 )
1174+ return zero_of_vartype (T )
11551175 elseif _isone (coeff) && length (dict) == 1
11561176 k, v = first (dict)
11571177 if _isone (v)
@@ -2057,7 +2077,7 @@ function (awb::AddWorkerBuffer{T})(terms::Union{Tuple{Vararg{BasicSymbolic{T}}},
20572077 if ! all (_numeric_or_arrnumeric_symtype, terms)
20582078 throw (MethodError (+ , Tuple (terms)))
20592079 end
2060- isempty (terms) && return Const {T} ( 0 )
2080+ isempty (terms) && return zero_of_vartype (T )
20612081 if isone (length (terms))
20622082 return Const {T} (only (terms))
20632083 end
@@ -2246,7 +2266,7 @@ end
22462266
22472267function _split_arrterm_scalar_coeff (:: Type{T} , ex:: BasicSymbolic{T} ) where {T}
22482268 sh = shape (ex)
2249- _is_array_shape (sh) || return ex, Const {T} ( 1 )
2269+ _is_array_shape (sh) || return ex, one_of_vartype (T )
22502270 @match ex begin
22512271 BSImpl. Term (; f, args, type) && if f === (* ) && ! _is_array_shape (shape (first (args))) end => begin
22522272 if length (args) == 2
@@ -2264,7 +2284,7 @@ function _split_arrterm_scalar_coeff(::Type{T}, ex::BasicSymbolic{T}) where {T}
22642284 coeff, rest = @match expr begin
22652285 BSImpl. Term (; f, args, type, shape) && if f === (* ) end => begin
22662286 if query! (isequal (idxs_for_arrayop (T)), args[1 ])
2267- Const {T} ( 1 ), expr
2287+ one_of_vartype (T ), expr
22682288 elseif length (args) == 2
22692289 args[1 ], args[2 ]
22702290 else
@@ -2274,7 +2294,7 @@ function _split_arrterm_scalar_coeff(::Type{T}, ex::BasicSymbolic{T}) where {T}
22742294 _coeff, BSImpl. Term {T} (* , newargs; type, shape)
22752295 end
22762296 end
2277- _ => (Const {T} ( 1 ), expr)
2297+ _ => (one_of_vartype (T ), expr)
22782298 end
22792299 if term === nothing
22802300 termrest = nothing
@@ -2284,15 +2304,15 @@ function _split_arrterm_scalar_coeff(::Type{T}, ex::BasicSymbolic{T}) where {T}
22842304 end
22852305 return coeff, BSImpl. ArrayOp {T} (output_idx, rest, reduce, termrest, ranges; shape, type)
22862306 end
2287- _ => (Const {T} ( 1 ), ex)
2307+ _ => (one_of_vartype (T ), ex)
22882308 end
22892309end
2290- _split_arrterm_scalar_coeff (:: Type{T} , ex) where {T} = Const {T} ( 1 ), Const {T} (ex)
2310+ _split_arrterm_scalar_coeff (:: Type{T} , ex) where {T} = one_of_vartype (T ), Const {T} (ex)
22912311
22922312function _as_base_exp (term:: BasicSymbolic{T} ) where {T}
22932313 @match term begin
22942314 BSImpl. Term (; f, args) && if f === (^ ) && isconst (args[2 ]) end => (args[1 ], args[2 ])
2295- _ => (term, Const {T} ( 1 ))
2315+ _ => (term, one_of_vartype (T ))
22962316 end
22972317end
22982318
@@ -2355,7 +2375,7 @@ function (mwb::MulWorkerBuffer{T})(terms) where {T}
23552375 if ! all (x -> _is_array_of_symbolics (x) || _numeric_or_arrnumeric_symtype (x), terms)
23562376 throw (MethodError (* , Tuple (terms)))
23572377 end
2358- isempty (terms) && return Const {T} ( 1 )
2378+ isempty (terms) && return one_of_vartype (T )
23592379 length (terms) == 1 && return Const {T} (terms[1 ])
23602380 empty! (mwb)
23612381 newshape = _multiplied_terms_shape (terms)
@@ -2827,7 +2847,7 @@ function ^(a::BasicSymbolic{T}, b) where {T <: Union{SymReal, SafeReal}}
28272847 return Term {T} (^ , ArgsT {T} ((a, Const {T} (b))); type, shape = newshape):: BasicSymbolic{T}
28282848 end
28292849 if b isa Number
2830- iszero (b) && return Const {T} ( 1 )
2850+ iszero (b) && return one_of_vartype (T )
28312851 isone (b) && return Const {T} (a)
28322852 end
28332853 if b isa Real && b < 0
0 commit comments