Skip to content

Commit a9f3a99

Browse files
committed
fix pow on termcombination and use similarterm to avoid defining <_e
1 parent 9db1674 commit a9f3a99

File tree

4 files changed

+15
-8
lines changed

4 files changed

+15
-8
lines changed

src/ModelingToolkit.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ using RecursiveArrayTools
2525
import SymbolicUtils
2626
import SymbolicUtils: Term, Add, Mul, Pow, Sym, to_symbolic, FnType, @rule, Rewriters, substitute, similarterm
2727

28+
import SymbolicUtils.Rewriters: Chain, Postwalk, Prewalk, Fixpoint
29+
2830
using LinearAlgebra: LU, BlasInt
2931

3032
import LightGraphs: SimpleDiGraph, add_edge!
@@ -88,7 +90,7 @@ end
8890
<(s::Num, x::Num) = value(s) <value(x)
8991

9092
for T in (Integer, Rational)
91-
@eval Base.:(^)(n::Num, i::$T) = Num(Term{symtype(n)}(^, [value(n),i]))
93+
@eval Base.:(^)(n::Num, i::$T) = Num(value(n)^i)
9294
end
9395

9496
macro num_method(f, expr, Ts=nothing)

src/direct.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ let
128128

129129
_scalar = one(TermCombination)
130130

131-
linearity_propagator = [
131+
simterm(t, f, args) = Term{Any}(f, args)
132+
linearity_rules = [
132133
@rule +(~~xs) => reduce(+, filter(isidx, ~~xs), init=_scalar)
133134
@rule *(~~xs) => reduce(*, filter(isidx, ~~xs), init=_scalar)
134135
@rule (~f)(~x::(!isidx)) => _scalar
@@ -146,7 +147,8 @@ let
146147
else
147148
error("Function of unknown linearity used: ", ~f)
148149
end
149-
end] |> Rewriters.Chain |> Rewriters.Postwalk |> Rewriters.Fixpoint
150+
end]
151+
linearity_propagator = Fixpoint(Postwalk(Chain(linearity_rules); similarterm=simterm))
150152

151153
global hessian_sparsity
152154

@@ -164,7 +166,7 @@ let
164166
u = map(value, u)
165167
idx(i) = TermCombination(Set([Dict(i=>1)]))
166168
dict = Dict(u .=> idx.(1:length(u)))
167-
f = Rewriters.Prewalk(x->haskey(dict, x) ? dict[x] : x)(f)
169+
f = Rewriters.Prewalk(x->haskey(dict, x) ? dict[x] : x; similarterm=simterm)(f)
168170
lp = linearity_propagator(f)
169171
_sparse(lp, length(u))
170172
end

src/linearity.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,14 @@ function Base.:(==)(comb1::TermCombination, comb2::TermCombination)
7878
end
7979
=#
8080

81-
# we don't care about the ordering in this case
82-
SymbolicUtils.:<(comb1::TermCombination, comb2::TermCombination) = true
8381
# to make Mul and Add work
8482
Base.:*(::Number, comb::TermCombination) = comb
85-
Base.:^(comb::TermCombination, ::Number) = comb
83+
function Base.:^(comb::TermCombination, ::Number)
84+
isone(comb) && return comb
85+
iszero(comb) && return _scalar
86+
return comb * comb
87+
end
88+
8689
function Base.:+(comb1::TermCombination, comb2::TermCombination)
8790
if isone(comb1) && !iszero(comb2)
8891
return comb2

test/direct.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ reference_hes = ModelingToolkit.hessian(rr, X)
6060

6161
sp_hess = ModelingToolkit.sparsehessian(rr, X)
6262
@test findnz(sparse(reference_hes))[1:2] == findnz(sp_hess)[1:2]
63-
# @test isequal(map(spoly, findnz(sparse(reference_hes))[3]), map(spoly, findnz(sp_hess)[3]))
63+
@test isequal(map(spoly, findnz(sparse(reference_hes))[3]), map(spoly, findnz(sp_hess)[3]))
6464

6565
Joop, Jiip = eval.(ModelingToolkit.build_function(∂,[x,y,z],[σ,ρ,β],t))
6666
J = Joop([1.0,2.0,3.0],[1.0,2.0,3.0],1.0)

0 commit comments

Comments
 (0)