Skip to content

Commit 7835c1c

Browse files
committed
Add diff rule on constants
1 parent 2473af1 commit 7835c1c

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

src/differentials.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ function expand_derivatives(O::Operation)
5757

5858
return O
5959
end
60+
expand_derivatives(::Constant) = Constant(0)
6061
expand_derivatives(x) = x
6162

6263
# Don't specialize on the function here
@@ -95,6 +96,7 @@ sin(x())
9596
```
9697
"""
9798
derivative(O::Operation, idx) = derivative(O.op, (O.args...,), Val(idx))
99+
derivative(O::Constant, _) = Constant(0)
98100

99101
# Pre-defined derivatives
100102
import DiffRules, SpecialFunctions, NaNMath

src/direct.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
function gradient(O::Operation, vars::AbstractVector{Operation}; simplify = true)
1+
function gradient(O::Expression, vars::AbstractVector{<:Expression}; simplify = true)
22
out = [expand_derivatives(Differential(v)(O)) for v in vars]
33
simplify ? simplify_constants.(out) : out
44
end
55

6-
function jacobian(ops::AbstractVector{Operation}, vars::AbstractVector{Operation}; simplify = true)
6+
function jacobian(ops::AbstractVector{<:Expression}, vars::AbstractVector{<:Expression}; simplify = true)
77
out = [expand_derivatives(Differential(v)(O)) for O in ops, v in vars]
88
simplify ? simplify_constants.(out) : out
99
end
1010

11-
function hessian(O::Operation, vars::AbstractVector{Operation}; simplify = true)
11+
function hessian(O::Expression, vars::AbstractVector{<:Expression}; simplify = true)
1212
out = [expand_derivatives(Differential(v2)(Differential(v1)(O))) for v1 in vars, v2 in vars]
1313
simplify ? simplify_constants.(out) : out
1414
end
@@ -25,9 +25,7 @@ function simplified_expr(O::Operation)
2525
return Expr(:call, Symbol(O.op), simplified_expr.(O.args)...)
2626
end
2727

28-
function simplified_expr(c::Constant)
29-
c.value
30-
end
28+
simplified_expr(c::Constant) = c.value
3129

3230
function simplified_expr(eq::Equation)
3331
Expr(:(=), simplified_expr(eq.lhs), simplified_expr(eq.rhs))

0 commit comments

Comments
 (0)