Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "QuestBase"
uuid = "7e80f742-43d6-403d-a9ea-981410111d43"
authors = ["Orjan Ameye <[email protected]>", "Jan Kosata <[email protected]>", "Javier del Pino <[email protected]>"]
version = "0.3.1"
version = "0.3.2"

[deps]
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Expand Down
11 changes: 6 additions & 5 deletions src/Symbolics/Symbolics_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,19 +88,20 @@
end

"Return all the terms contained in `x`"
get_all_terms(x::Num) = unique(_get_all_terms(Symbolics.expand(x).val))
get_all_terms(x::Num) = Num.(unique(_get_all_terms(Symbolics.expand(x).val)))
get_all_terms(x::BasicSymbolic) = unique(_get_all_terms(Symbolics.expand(x)))

Check warning on line 92 in src/Symbolics/Symbolics_utils.jl

View check run for this annotation

Codecov / codecov/patch

src/Symbolics/Symbolics_utils.jl#L92

Added line #L92 was not covered by tests
function get_all_terms(x::Equation)
return unique(cat(get_all_terms(Num(x.lhs)), get_all_terms(Num(x.rhs)); dims=1))
end
function _get_all_terms(x::BasicSymbolic)
@compactified x::BasicSymbolic begin
Add => vcat([_get_all_terms(term) for term in SymbolicUtils.arguments(x)]...)
Mul => Num.(SymbolicUtils.arguments(x))
Div => Num.([_get_all_terms(x.num)..., _get_all_terms(x.den)...])
_ => Num(x)
Mul => SymbolicUtils.arguments(x)
Div => [_get_all_terms(x.num)..., _get_all_terms(x.den)...]
_ => [x]
end
end
_get_all_terms(x) = Num(x)
_get_all_terms(x) = x

function is_harmonic(x::Num, t::Num)::Bool
all_terms = get_all_terms(x)
Expand Down
37 changes: 34 additions & 3 deletions src/Symbolics/fourier.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@
x = simplify_exp_products(x) # simplify products of exps
x = exp_to_trig(x)
x = Num(simplify_complex(expand(x)))
return x# simplify_fractions(x)# (a*c^2 + b*c)/c^2 = (a*c + b)/c
return x # simplify_fractions(x)# (a*c^2 + b*c)/c^2 = (a*c + b)/c
end

"Return true if `f` is a sin or cos."
function is_trig(f::Num)
f = ispow(f.val) ? f.val.base : f.val
is_trig(f::Num) = is_trig(f.val)
is_trig(f) = false
function is_trig(f::BasicSymbolic)
f = ispow(f) ? f.base : f
isterm(f) && SymbolicUtils.operation(f) ∈ [cos, sin] && return true
return false
end
Expand Down Expand Up @@ -148,6 +150,35 @@
convert_to_Num(x::Complex{Num})::Num = Num(first(x.re.val.arguments))
convert_to_Num(x::Num)::Num = x

"""
trig_to_exp(x::BasicSymbolic)

Convert all trigonometric terms (sin, cos) in expression `x` to their exponential form
using Euler's formula: ``\\exp(ix) = \\cos(x) + i*\\sin(x)``.
"""
function trig_to_exp(x::BasicSymbolic)
all_terms = get_all_terms(x)
trigs = filter(z -> is_trig(z), all_terms)

Check warning on line 161 in src/Symbolics/fourier.jl

View check run for this annotation

Codecov / codecov/patch

src/Symbolics/fourier.jl#L159-L161

Added lines #L159 - L161 were not covered by tests

rules = []
for trig in trigs
is_pow = ispow(trig) # trig is either a trig or a power of a trig
power = is_pow ? trig.exp : 1
arg = is_pow ? arguments(trig.base)[1] : arguments(trig)[1]
type = is_pow ? operation(trig.base) : operation(trig)

Check warning on line 168 in src/Symbolics/fourier.jl

View check run for this annotation

Codecov / codecov/patch

src/Symbolics/fourier.jl#L163-L168

Added lines #L163 - L168 were not covered by tests

if type == cos
term = (exp(im * arg) + exp(-im * arg))^power * (1 // 2)^power
elseif type == sin
term =

Check warning on line 173 in src/Symbolics/fourier.jl

View check run for this annotation

Codecov / codecov/patch

src/Symbolics/fourier.jl#L170-L173

Added lines #L170 - L173 were not covered by tests
(1 * im^power) * ((exp(-im * arg) - exp(im * arg)))^power * (1 // 2)^power
end

append!(rules, [trig => term])
end
return Symbolics.substitute(x, Dict(rules))

Check warning on line 179 in src/Symbolics/fourier.jl

View check run for this annotation

Codecov / codecov/patch

src/Symbolics/fourier.jl#L177-L179

Added lines #L177 - L179 were not covered by tests
end

"""
exp_to_trig(x::BasicSymbolic)
exp_to_trig(x)
Expand Down
52 changes: 36 additions & 16 deletions test/symbolics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ end
@eqtest max_power(a^2 + b, a) == 2
@eqtest max_power(a * ((a + b)^4)^2 + a, a) == 9
@eqtest max_power([a * ((a + b)^4)^2 + a, a^2], a) == 9
@eqtest max_power(a + im*a^2, a) == 2
@eqtest max_power(a + im * a^2, a) == 2

@eqtest drop_powers(a^2 + b, a, 1) == b
@eqtest drop_powers((a + b)^2, a, 1) == b^2
Expand All @@ -52,29 +52,49 @@ end
# eq = drop_powers(a^2 + a ~ b, [a, b], 2) # broken
@eqtest [eq.lhs, eq.rhs] == [a, a]
eq = drop_powers(a^2 + a + b ~ a, a, 2)
@test string(eq.rhs) == "a" broken=true
@test string(eq.rhs) == "a" broken = true

@eqtest drop_powers([a^2 + a + b, b], a, 2) == [a + b, b]
@eqtest drop_powers([a^2 + a + b, b], [a, b], 2) == [a + b, b]
end

@testset "trig_to_exp and trig_to_exp" begin
using QuestBase: expand_all, trig_to_exp, exp_to_trig
@variables f t
cos_euler(x) = (exp(im * x) + exp(-im * x)) / 2
sin_euler(x) = (exp(im * x) - exp(-im * x)) / (2 * im)

# automatic conversion between trig and exp form
trigs = [cos(f * t), sin(f * t)]
for (i, trig) in pairs(trigs)
z = trig_to_exp(trig)
@eqtest expand(exp_to_trig(z)) == trig
end
trigs′ = [cos_euler(f * t), sin_euler(f * t)]
for (i, trig) in pairs(trigs′)
z = trig_to_exp(trig)
@eqtest expand(exp_to_trig(z)) == trigs[i]
@testset "Num" begin
@variables f t
cos_euler(x) = (exp(im * x) + exp(-im * x)) / 2
sin_euler(x) = (exp(im * x) - exp(-im * x)) / (2 * im)

# automatic conversion between trig and exp form
trigs = [cos(f * t), sin(f * t)]
for (i, trig) in pairs(trigs)
z = trig_to_exp(trig)
@eqtest expand(exp_to_trig(z)) == trig
end
trigs′ = [cos_euler(f * t), sin_euler(f * t)]
for (i, trig) in pairs(trigs′)
z = trig_to_exp(trig)
@eqtest expand(exp_to_trig(z)) == trigs[i]
end
end

# @testset "BasicSymbolic" begin
# @syms f t
# cos_euler(x) = (exp(im * x) + exp(-im * x)) / 2
# sin_euler(x) = (exp(im * x) - exp(-im * x)) / (2 * im)

# # automatic conversion between trig and exp form
# trigs = [cos(f * t), sin(f * t)]
# for (i, trig) in pairs(trigs)
# z = trig_to_exp(trig)
# @eqtest expand(exp_to_trig(z)) == trig
# end
# trigs′ = [cos_euler(f * t), sin_euler(f * t)]
# for (i, trig) in pairs(trigs′)
# z = trig_to_exp(trig)
# @eqtest expand(exp_to_trig(z)) == trigs[i]
# end
# end
end

@testset "harmonic" begin
Expand Down
Loading