Skip to content

Commit c28bb73

Browse files
multiargument chain rule
1 parent 1f52f2d commit c28bb73

File tree

5 files changed

+24
-22
lines changed

5 files changed

+24
-22
lines changed

src/operations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ end
66

77
# Recursive ==
88
function Base.:(==)(x::Operation,y::Operation)
9-
x.op == y.op && all(isequal.(x.args,y.args))
9+
x.op == y.op && length(x.args) == length(y.args) && all(isequal.(x.args,y.args))
1010
end
1111
Base.:(==)(x::Operation,y::Number) = false
1212
Base.:(==)(x::Number,y::Operation) = false

src/operators.jl

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,19 +26,9 @@ Variable(x::Variable,D::Differential) = Variable(x.name,x.value,x.value_type,
2626

2727
function expand_derivatives(O::Operation)
2828
if O.op == Derivative
29-
#=
30-
diff_idxs = find(x->isequal(x,by.x),O.args)
31-
(diff_idxs != nothing || length(diff_idxs) > 1) && error("Derivatives of multi-argument functions require matching a unique argument.")
32-
idx = first(diff_idxs)
33-
=#
34-
i = 1
35-
if typeof(O.args[1].args[i]) == typeof(O.args[2].x) && isequal(O.args[1].args[i],O.args[2].x)
36-
Derivative(O.args[1],i)
37-
else
38-
D = Differential(O.args[2].x)
39-
cr_exp = D*O.args[1].args[i]
40-
Derivative(O.args[1],i) * expand_derivatives(cr_exp)
41-
end
29+
D = O.args[2]
30+
o = O.args[1]
31+
simplify_constants(sum(i->Derivative(o,i)*expand_derivatives(D*o.args[i]),1:length(o.args)))
4232
else
4333
for i in 1:length(O.args)
4434
O.args[i] = expand_derivatives(O.args[i])

src/simplify.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ function _simplify_constants(O)
3636
idxs = find(x->typeof(x)<:Variable && isequal(x,Constant(0)),O.args)
3737
_O = Operation(O.op,O.args[1:length(O.args) .∉ (idxs,)])
3838
if isempty(_O.args)
39-
return _O
39+
return Constant(0)
4040
elseif length(_O.args) == 1
4141
return _O.args[1]
4242
else

src/variables.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,14 @@ function Base.:(==)(x::Variable,y::Variable)
6565
x.value_type == y.value_type && x.diff == y.diff
6666
end
6767

68+
function Base.:(==)(x::Variable,y::Number)
69+
x == Constant(y)
70+
end
71+
72+
function Base.:(==)(x::Number,y::Variable)
73+
Constant(x) == y
74+
end
75+
6876
function Base.Expr(x::Variable)
6977
if x.diff == nothing
7078
return :($(x.name))

test/derivatives.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,22 @@ dpow2 = Derivative(^,[x, y],Val{2})
2121
@test dpow1 == y*x^(y-1)
2222
@test dpow2 == x^y*log(x)
2323

24+
d1 = D*(sin(t)*t)
25+
d2 = D*(sin(t)*cos(t))
26+
@test expand_derivatives(d1) == t*cos(t)+sin(t)
27+
@test expand_derivatives(d2) == cos(t)*cos(t)+sin(t)*-sin(t)
2428

2529
eqs = [0 ~ σ*(y-x),
2630
0 ~ x*-z)-y,
2731
0 ~ x*y - β*z]
2832
sys = NonlinearSystem(eqs,[x,y,z],[σ,ρ,β])
2933
jac = SciCompDSL.generate_nlsys_jacobian(sys)
30-
@test_broken jac[1,1] == -σ
31-
@test_broken jac[1,2] == σ
32-
@test_broken jac[1,3] == 0
34+
@test jac[1,1] == σ*-1
35+
@test jac[1,2] == σ
36+
@test jac[1,3] == 0
3337
@test jac[2,1] == ρ-z
34-
@test_broken jac[2,2] == -1
35-
@test_broken jac[2,3] == -x
38+
@test jac[2,2] == -1
39+
@test jac[2,3] == x*-1
3640
@test jac[3,1] == y
37-
@test_broken jac[3,2] == x
38-
@test_broken jac[3,3] == -β
41+
@test jac[3,2] == x
42+
@test jac[3,3] == -1*β

0 commit comments

Comments
 (0)