Skip to content

Commit 11dedcc

Browse files
refactor: improve division performance
1 parent f509a28 commit 11dedcc

File tree

5 files changed

+94
-49
lines changed

5 files changed

+94
-49
lines changed

src/polyform.jl

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,11 @@ function quick_cancel(d::BasicSymbolic{T})::BasicSymbolic{T} where {T}
239239
end
240240

241241
function quick_cancel(x::S, y::S)::Tuple{S, S} where {T <: SymVariant, S <: BasicSymbolic{T}}
242+
isequal(x, y) && return one_of_vartype(T), one_of_vartype(T)
242243
opx = iscall(x) ? operation(x) : nothing
243244
opy = iscall(y) ? operation(y) : nothing
245+
icx = isconst(x)
246+
icy = isconst(y)
244247
if opx === (^) && opy === (^)
245248
return quick_powpow(x, y)
246249
elseif opx === (*) && opy === (^)
@@ -249,16 +252,14 @@ function quick_cancel(x::S, y::S)::Tuple{S, S} where {T <: SymVariant, S <: Basi
249252
return reverse(quick_mulpow(y, x))
250253
elseif opx === (*) && opy === (*)
251254
return quick_mulmul(x, y)
252-
elseif opx === (^) && !isconst(y)
255+
elseif opx === (^) && !icy
253256
return quick_pow(x, y)
254-
elseif opy === (^) && !isconst(x)
257+
elseif opy === (^) && !icx
255258
return reverse(quick_pow(y, x))
256-
elseif opx === (*) && !isconst(y)
259+
elseif opx === (*) && !icy
257260
return quick_mul(x, y)
258-
elseif opy === (*) && !isconst(x)
261+
elseif opy === (*) && !icx
259262
return reverse(quick_mul(y, x))
260-
elseif isequal(x, y)
261-
return one_of_vartype(T), one_of_vartype(T)
262263
else
263264
return x, y
264265
end
@@ -340,20 +341,31 @@ end
340341

341342
# Double mul case
342343
function quick_mulmul(x::S, y::S)::Tuple{S, S} where {T <: SymVariant, S <: BasicSymbolic{T}}
343-
yargs = arguments(y)
344-
for (i, arg) in enumerate(yargs)
345-
newx, newarg = quick_cancel(x, arg)
346-
isequal(arg, newarg) && continue
347-
if yargs isa ROArgsT
348-
yargs = copy(parent(yargs))
344+
@match (x, y) begin
345+
(BSImpl.AddMul(; coeff = c1, dict = d1, type = t1, shape = s1, variant = vr1), BSImpl.AddMul(; coeff = c2, dict = d2, type = t2, shape = s2, variant = vr2)) => begin
346+
newd1 = d1
347+
newd2 = d2
348+
for (k1, v1) in d1
349+
haskey(d2, k1) || continue
350+
v2 = d2[k1]
351+
if newd1 === d1
352+
newd1 = copy(d1)
353+
newd2 = copy(d2)
354+
end
355+
delete!((v1 >= v2) ? newd2 : newd1, k1)
356+
setindex!((v1 >= v2) ? newd1 : newd2, abs(v1 - v2), k1)
357+
end
358+
if newd1 === d1
359+
return x, y
360+
end
361+
filter!(!iszero last, newd1)
362+
filter!(!iszero last, newd2)
363+
xx = Mul{T}(c1, newd1; type = t1, shape = s1)
364+
yy = Mul{T}(c2, newd2; type = t2, shape = s2)
365+
366+
return xx, yy
349367
end
350-
yargs[i] = Const{T}(newarg)
351-
x = newx
352-
end
353-
if yargs isa ROArgsT
354-
return x, y
355-
else
356-
return x, mul_worker(T, yargs)
368+
_ => _unreachable()
357369
end
358370
end
359371

src/types.jl

Lines changed: 55 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,6 +1092,7 @@ end
10921092
end
10931093

10941094
@inline function BSImpl.AddMul{T}(coeff, dict, variant::AddMulVariant.T; metadata = nothing, type, shape = default_shape(type), unsafe = false) where {T}
1095+
@nospecialize coeff
10951096
metadata = parse_metadata(metadata)
10961097
shape = parse_shape(shape)
10971098
dict = parse_dict(T, dict)
@@ -1156,6 +1157,7 @@ struct ArrayOp{T} end
11561157
end
11571158

11581159
@inline function Add{T}(coeff, dict; kw...) where {T}
1160+
@nospecialize coeff kw
11591161
coeff = unwrap(coeff)
11601162
dict = unwrap_dict(dict)
11611163
if isempty(dict)
@@ -1173,6 +1175,7 @@ end
11731175
end
11741176

11751177
@inline function Mul{T}(coeff, dict; kw...) where {T}
1178+
@nospecialize coeff kw
11761179
coeff = unwrap(coeff)
11771180
dict = unwrap_dict(dict)
11781181
if isempty(dict)
@@ -1190,7 +1193,7 @@ end
11901193
k, v = first(dict)
11911194
if _isone(v)
11921195
@match k begin
1193-
BSImpl.AddMul(; coeff = c2, dict = d2, variant) && if variant == AddMulVariant.ADD end => begin
1196+
BSImpl.AddMul(; coeff = c2, dict = d2, variant) && if variant === AddMulVariant.ADD end => begin
11941197
empty!(dict)
11951198
for (k, v) in d2
11961199
dict[k] = -v
@@ -1231,31 +1234,31 @@ end
12311234
Simplify the coefficients of `n` and `d` (numerator and denominator).
12321235
"""
12331236
function simplify_coefficients(n, d)
1234-
if safe_isinteger(n)
1235-
n = Int(n)
1237+
return n, d
1238+
end
1239+
1240+
function safe_div(a::Number, b::Number)::Number
1241+
# @nospecialize a b
1242+
if (!(a isa Integer) && safe_isinteger(a))
1243+
a = Int(a)
12361244
end
1237-
if safe_isinteger(d)
1238-
d = Int(d)
1245+
if (!(b isa Integer) && safe_isinteger(b))
1246+
b = Int(b)
12391247
end
1240-
nrat, nc = ratcoeff(n)
1241-
drat, dc = ratcoeff(d)
1242-
nrat && drat || return n, d
1243-
g = gcd(nc, dc) * sign(dc) # make denominator positive
1244-
invdc = isone(g) ? g : (1 // g)
1245-
n = maybe_integer(invdc * n)
1246-
d = maybe_integer(invdc * d)
1247-
1248-
return n, d
1248+
if a isa Integer && b isa Integer
1249+
return a // b
1250+
end
1251+
return a / b
12491252
end
12501253

12511254
"""
12521255
$(TYPEDSIGNATURES)
12531256
"""
12541257
function Div{T}(n, d, simplified; type = promote_symtype(/, symtype(n), symtype(d)), kw...) where {T}
1255-
n = unwrap(n)
1256-
d = unwrap(d)
1258+
n = Const{T}(unwrap(n))
1259+
d = Const{T}(unwrap(d))
12571260

1258-
if !(type <: Number)
1261+
if !_numeric_or_arrnumeric_type(type)
12591262
_iszero(n) && return Const{T}(n)
12601263
_isone(d) && return Const{T}(n)
12611264
return BSImpl.Div{T}(n, d, simplified; type, kw...)
@@ -1264,8 +1267,6 @@ function Div{T}(n, d, simplified; type = promote_symtype(/, symtype(n), symtype(
12641267
if !(T === SafeReal)
12651268
n, d = quick_cancel(Const{T}(n), Const{T}(d))
12661269
end
1267-
n = unwrap_const(n)
1268-
d = unwrap_const(d)
12691270

12701271
_iszero(n) && return Const{T}(n)
12711272
_isone(d) && return Const{T}(n)
@@ -1279,10 +1280,42 @@ function Div{T}(n, d, simplified; type = promote_symtype(/, symtype(n), symtype(
12791280
return Div{T}(n * d.den, d.num, simplified; type, kw...)
12801281
end
12811282

1282-
d isa Number && _isone(-d) && return Const{T}(-n)
1283-
n isa Rat && d isa Rat && return Const{T}(n // d)
1283+
isconst(d) && _isone(-d) && return Const{T}(-n)
1284+
if isconst(n) && isconst(d)
1285+
nn = unwrap_const(n)
1286+
dd = unwrap_const(d)
1287+
nn isa Rat && dd isa Rat && return Const{T}(nn // dd)
1288+
return Const{T}(nn / dd)
1289+
end
12841290

1285-
n, d = simplify_coefficients(n, d)
1291+
@match (n, d) begin
1292+
(BSImpl.Const(; val = v1), BSImpl.Const(; val = v2)) => Const{T}(safe_div(v1, v2))
1293+
(BSImpl.Const(; val = c1), BSImpl.AddMul(; coeff = c2, dict, type, shape, variant)) && if variant == AddMulVariant.MUL end => begin
1294+
if c1 isa Rat && c2 isa Rat
1295+
tmp = c1 // c2
1296+
c1 = tmp.num
1297+
c2 = tmp.den
1298+
end
1299+
n = Const{T}(c1)
1300+
d = Mul{T}(c2, dict, ; type, shape)
1301+
end
1302+
(BSImpl.AddMul(; coeff, dict, type, shape, variant), BSImpl.Const(; val)) && if variant == AddMulVariant.MUL end => begin
1303+
return Mul{T}(safe_div(coeff, val), dict, ; type, shape)
1304+
end
1305+
(BSImpl.AddMul(; coeff = c1, dict = d1, type = t1, shape = sh1, variant = v1),
1306+
BSImpl.AddMul(; coeff = c2, dict = d2, type = t2, shape = sh2, variant = v2)) &&
1307+
if v1 == AddMulVariant.MUL && v2 == AddMulVariant.MUL end => begin
1308+
1309+
if c1 isa Rat && c2 isa Rat
1310+
tmp = c1 // c2
1311+
c1 = tmp.num
1312+
c2 = tmp.den
1313+
end
1314+
n = Mul{T}(c1, d1, ; type = t1, shape = sh1)
1315+
d = Mul{T}(c2, d2, ; type = t2, shape = sh2)
1316+
end
1317+
_ => nothing
1318+
end
12861319

12871320
_isone(d) && return Const{T}(n)
12881321
_isone(-d) && return Const{T}(-n)

src/utils.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using Base: ImmutableDict
22

3-
safe_isinteger(x::Number) = isinteger(x) && abs(x) < typemax(Int)
3+
safe_isinteger(@nospecialize(x::Number)) = isinteger(x) && abs(x) < typemax(Int)
44
safe_isinteger(x) = false
55

66
pow(x,y) = y==0 ? 1 : y<0 ? inv(x)^(-y) : x^y
@@ -29,13 +29,15 @@ isliteral(::Type{T}) where {T} = x -> x isa T
2929
is_literal_number(x) = isliteral(Number)(unwrap_const(x))
3030

3131
# checking the type directly is faster than dynamic dispatch in type unstable code
32-
function _iszero(x)
32+
@cache function _iszero(x)::Bool
33+
@nospecialize x
3334
x = unwrap_const(unwrap(x))
3435
x isa Number && return iszero(x)::Bool
3536
x isa Array && return iszero(x)::Bool
3637
return false
3738
end
38-
function _isone(x)
39+
@cache function _isone(x)::Bool
40+
@nospecialize x
3941
x = unwrap_const(unwrap(x))
4042
x isa Number && return isone(x)::Bool
4143
x isa Array && return isone(x)::Bool

test/basics.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,8 @@ end
185185
@test isequal(w == 0, Term{SymReal}(==, [w, 0]; type = Bool))
186186

187187
@syms x::Integer a::Integer
188-
@eqtest x // 5 == (1 // 5) * x
188+
@eqtest x // 5 == SymbolicUtils.Div{SymReal}(x, 5, false; type = Real)
189189
@eqtest (1//2 * x) / 5 == (1 // 10) * x
190-
@eqtest x // Int16(5) == Rational{Int16}(1, 5) * x
191190
@eqtest 5 // x == 5 / x
192191
@eqtest x // a == x / a
193192

@@ -1086,8 +1085,7 @@ end
10861085
@test get_mul_coefficient((2x/3y).num) == 2
10871086
@test get_mul_coefficient((2x/3y).den) == 3
10881087
@test unwrap_const(2x/-3x) == -2//3
1089-
@test unwrap_const((2.5x/3x).num) == 2.5
1090-
@test unwrap_const((2.5x/3x).den) == 3
1088+
@test unwrap_const((2.5x/3x)) == 2.5/3
10911089
@test unwrap_const(x/3x) == 1//3
10921090
@test isequal(x / 1, x)
10931091
@test isequal(x / -1, -x)

test/polyform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ end
6464
@test isequal(simplify_fractions(a), 7/expand(-(x-2)^2))
6565

6666
# https://github.com/JuliaSymbolics/Symbolics.jl/issues/968
67-
@eqtest simplify_fractions((x * y + (1//2) * x) / (2 * x)) == (1 + 2y) / 4
67+
@eqtest simplify_fractions((x * y + (1//2) * x) / (2 * x)) == (1//2 + y) / 2
6868
end
6969

7070
@testset "isone iszero" begin

0 commit comments

Comments
 (0)