Skip to content

Commit 894efcf

Browse files
feat: add and use global Const(1) and Const(0) values
1 parent abc207d commit 894efcf

File tree

3 files changed

+47
-27
lines changed

3 files changed

+47
-27
lines changed

src/methods.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -701,7 +701,7 @@ function _copy_broadcast!(buffer::BroadcastBuffer{T}, bc::Broadcast.Broadcasted{
701701
# unknown ndims, assume full shape
702702
limit = sh.ndims == -1 ? ndim : sh.ndims
703703
for i in 1:limit
704-
push!(getindex_args, length(bc.axes[i]) == 1 ? Const{T}(1) : subscripts[i])
704+
push!(getindex_args, length(bc.axes[i]) == 1 ? one_of_vartype(T) : subscripts[i])
705705
end
706706
elseif sh isa ShapeVecT
707707
for (i, (target_ax, cur_ax)) in enumerate(zip(bc.axes, sh))

src/polyform.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ function quick_cancel(x::S, y::S)::Tuple{S, S} where {T <: SymVariant, S <: Basi
258258
elseif opy === (*) && !isconst(x)
259259
return reverse(quick_mul(y, x))
260260
elseif isequal(x, y)
261-
return Const{T}(1), Const{T}(1)
261+
return one_of_vartype(T), one_of_vartype(T)
262262
else
263263
return x, y
264264
end
@@ -269,7 +269,7 @@ function quick_pow(x::S, y::S)::Tuple{S, S} where {T <: SymVariant, S <: BasicSy
269269
base, exp = arguments(x)
270270
exp = unwrap_const(exp)
271271
exp isa Number || return (x, y)
272-
isequal(base, y) && exp >= 1 ? (base ^ (exp - 1), Const{T}(1)) : (x, y)
272+
isequal(base, y) && exp >= 1 ? (base ^ (exp - 1), one_of_vartype(T)) : (x, y)
273273
end
274274

275275
# Double Pow case
@@ -281,17 +281,17 @@ function quick_powpow(x::S, y::S)::Tuple{S, S} where {T <: SymVariant, S <: Basi
281281
exp2 = unwrap_const(exp2)
282282
!(exp1 isa Number && exp2 isa Number) && return (x, y)
283283
if exp1 > exp2
284-
return base1 ^ (exp1 - exp2), Const{T}(1)
284+
return base1 ^ (exp1 - exp2), one_of_vartype(T)
285285
elseif exp1 == exp2
286-
return Const{T}(1), Const{T}(1)
286+
return one_of_vartype(T), one_of_vartype(T)
287287
else # exp1 < exp2
288-
return Const{T}(1), base2 ^ (exp2 - exp1)
288+
return one_of_vartype(T), base2 ^ (exp2 - exp1)
289289
end
290290
end
291291

292292
# ismul(x)
293293
function quick_mul(x::S, y::S)::Tuple{S, S} where {T <: SymVariant, S <: BasicSymbolic{T}}
294-
yy = BSImpl.Term{T}(^, ArgsT{T}((y, Const{T}(1))); type = symtype(y))
294+
yy = BSImpl.Term{T}(^, ArgsT{T}((y, one_of_vartype(T))); type = symtype(y))
295295
newx, newy = quick_mulpow(x, yy)
296296
return isequal(newy, yy) ? (x, y) : (newx, newy)
297297
end
@@ -326,12 +326,12 @@ function quick_mulpow(x::S, y::S)::Tuple{S, S} where {T <: SymVariant, S <: Basi
326326
oldval = args[idx]
327327
if argexp > exp
328328
args[idx] = argbase ^ (argexp - exp)
329-
result = mul_worker(T, args), Const{T}(1)
329+
result = mul_worker(T, args), one_of_vartype(T)
330330
elseif argexp == exp
331-
args[idx] = Const{T}(1)
332-
result = mul_worker(T, args), Const{T}(1)
331+
args[idx] = one_of_vartype(T)
332+
result = mul_worker(T, args), one_of_vartype(T)
333333
else
334-
args[idx] = Const{T}(1)
334+
args[idx] = one_of_vartype(T)
335335
result = mul_worker(T, args), base ^ (exp - argexp)
336336
end
337337
args[idx] = oldval

src/types.jl

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -870,13 +870,33 @@ function hashcons(s::BSImpl.Type)
870870
end true
871871
end
872872

873-
const SMALLV_DEFAULT_SYMREAL = hashcons(BSImpl.Const{SymReal}(0, 0, nothing))
874-
const SMALLV_DEFAULT_SAFEREAL = hashcons(BSImpl.Const{SafeReal}(0, 0, nothing))
875-
const SMALLV_DEFAULT_TREEREAL = hashcons(BSImpl.Const{TreeReal}(0, 0, nothing))
873+
const CONST_ZERO_SYMREAL = hashcons(BSImpl.Const{SymReal}(0, 0, nothing))
874+
const CONST_ZERO_SAFEREAL = hashcons(BSImpl.Const{SafeReal}(0, 0, nothing))
875+
const CONST_ZERO_TREEREAL = hashcons(BSImpl.Const{TreeReal}(0, 0, nothing))
876+
const CONST_ONE_SYMREAL = hashcons(BSImpl.Const{SymReal}(1, 0, nothing))
877+
const CONST_ONE_SAFEREAL = hashcons(BSImpl.Const{SafeReal}(1, 0, nothing))
878+
const CONST_ONE_TREEREAL = hashcons(BSImpl.Const{TreeReal}(1, 0, nothing))
876879

877-
defaultval(::Type{BasicSymbolic{SymReal}}) = SMALLV_DEFAULT_SYMREAL
878-
defaultval(::Type{BasicSymbolic{SafeReal}}) = SMALLV_DEFAULT_SAFEREAL
879-
defaultval(::Type{BasicSymbolic{TreeReal}}) = SMALLV_DEFAULT_TREEREAL
880+
@inline defaultval(::Type{BasicSymbolic{SymReal}}) = CONST_ZERO_SYMREAL
881+
@inline defaultval(::Type{BasicSymbolic{SafeReal}}) = CONST_ZERO_SAFEREAL
882+
@inline defaultval(::Type{BasicSymbolic{TreeReal}}) = CONST_ZERO_TREEREAL
883+
884+
"""
885+
$(TYPEDSIGNATURES)
886+
887+
Return a `Const` representing `0` with the provided `vartype`.
888+
"""
889+
@inline zero_of_vartype(::Type{SymReal}) = CONST_ZERO_SYMREAL
890+
@inline zero_of_vartype(::Type{SafeReal}) = CONST_ZERO_SAFEREAL
891+
@inline zero_of_vartype(::Type{TreeReal}) = CONST_ZERO_TREEREAL
892+
"""
893+
$(TYPEDSIGNATURES)
894+
895+
Return a `Const` representing `1` with the provided `vartype`.
896+
"""
897+
@inline one_of_vartype(::Type{SymReal}) = CONST_ONE_SYMREAL
898+
@inline one_of_vartype(::Type{SafeReal}) = CONST_ONE_SAFEREAL
899+
@inline one_of_vartype(::Type{TreeReal}) = CONST_ONE_TREEREAL
880900

881901
function get_mul_coefficient(x)
882902
iscall(x) && operation(x) === (*) || throw(ArgumentError("$x is not a multiplication"))
@@ -1151,7 +1171,7 @@ end
11511171
if isempty(dict)
11521172
return Const{T}(coeff)
11531173
elseif _iszero(coeff)
1154-
return Const{T}(0)
1174+
return zero_of_vartype(T)
11551175
elseif _isone(coeff) && length(dict) == 1
11561176
k, v = first(dict)
11571177
if _isone(v)
@@ -2057,7 +2077,7 @@ function (awb::AddWorkerBuffer{T})(terms::Union{Tuple{Vararg{BasicSymbolic{T}}},
20572077
if !all(_numeric_or_arrnumeric_symtype, terms)
20582078
throw(MethodError(+, Tuple(terms)))
20592079
end
2060-
isempty(terms) && return Const{T}(0)
2080+
isempty(terms) && return zero_of_vartype(T)
20612081
if isone(length(terms))
20622082
return Const{T}(only(terms))
20632083
end
@@ -2246,7 +2266,7 @@ end
22462266

22472267
function _split_arrterm_scalar_coeff(::Type{T}, ex::BasicSymbolic{T}) where {T}
22482268
sh = shape(ex)
2249-
_is_array_shape(sh) || return ex, Const{T}(1)
2269+
_is_array_shape(sh) || return ex, one_of_vartype(T)
22502270
@match ex begin
22512271
BSImpl.Term(; f, args, type) && if f === (*) && !_is_array_shape(shape(first(args))) end => begin
22522272
if length(args) == 2
@@ -2264,7 +2284,7 @@ function _split_arrterm_scalar_coeff(::Type{T}, ex::BasicSymbolic{T}) where {T}
22642284
coeff, rest = @match expr begin
22652285
BSImpl.Term(; f, args, type, shape) && if f === (*) end => begin
22662286
if query!(isequal(idxs_for_arrayop(T)), args[1])
2267-
Const{T}(1), expr
2287+
one_of_vartype(T), expr
22682288
elseif length(args) == 2
22692289
args[1], args[2]
22702290
else
@@ -2274,7 +2294,7 @@ function _split_arrterm_scalar_coeff(::Type{T}, ex::BasicSymbolic{T}) where {T}
22742294
_coeff, BSImpl.Term{T}(*, newargs; type, shape)
22752295
end
22762296
end
2277-
_ => (Const{T}(1), expr)
2297+
_ => (one_of_vartype(T), expr)
22782298
end
22792299
if term === nothing
22802300
termrest = nothing
@@ -2284,15 +2304,15 @@ function _split_arrterm_scalar_coeff(::Type{T}, ex::BasicSymbolic{T}) where {T}
22842304
end
22852305
return coeff, BSImpl.ArrayOp{T}(output_idx, rest, reduce, termrest, ranges; shape, type)
22862306
end
2287-
_ => (Const{T}(1), ex)
2307+
_ => (one_of_vartype(T), ex)
22882308
end
22892309
end
2290-
_split_arrterm_scalar_coeff(::Type{T}, ex) where {T} = Const{T}(1), Const{T}(ex)
2310+
_split_arrterm_scalar_coeff(::Type{T}, ex) where {T} = one_of_vartype(T), Const{T}(ex)
22912311

22922312
function _as_base_exp(term::BasicSymbolic{T}) where {T}
22932313
@match term begin
22942314
BSImpl.Term(; f, args) && if f === (^) && isconst(args[2]) end => (args[1], args[2])
2295-
_ => (term, Const{T}(1))
2315+
_ => (term, one_of_vartype(T))
22962316
end
22972317
end
22982318

@@ -2355,7 +2375,7 @@ function (mwb::MulWorkerBuffer{T})(terms) where {T}
23552375
if !all(x -> _is_array_of_symbolics(x) || _numeric_or_arrnumeric_symtype(x), terms)
23562376
throw(MethodError(*, Tuple(terms)))
23572377
end
2358-
isempty(terms) && return Const{T}(1)
2378+
isempty(terms) && return one_of_vartype(T)
23592379
length(terms) == 1 && return Const{T}(terms[1])
23602380
empty!(mwb)
23612381
newshape = _multiplied_terms_shape(terms)
@@ -2827,7 +2847,7 @@ function ^(a::BasicSymbolic{T}, b) where {T <: Union{SymReal, SafeReal}}
28272847
return Term{T}(^, ArgsT{T}((a, Const{T}(b))); type, shape = newshape)::BasicSymbolic{T}
28282848
end
28292849
if b isa Number
2830-
iszero(b) && return Const{T}(1)
2850+
iszero(b) && return one_of_vartype(T)
28312851
isone(b) && return Const{T}(a)
28322852
end
28332853
if b isa Real && b < 0

0 commit comments

Comments
 (0)