Skip to content

Commit b7a88c2

Browse files
Clean up Derivative generation from DiffRules
1 parent f7f45e2 commit b7a88c2

File tree

1 file changed

+8
-24
lines changed

1 file changed

+8
-24
lines changed

src/operators.jl

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -39,33 +39,17 @@ function Derivative(O::Operation,idx)
3939
# This calls the Derivative dispatch from the user or pre-defined code
4040
Derivative(O.op, O.args, Val(idx))
4141
end
42+
Derivative(op, args, idx) = Derivative(op, (args...,), idx)
4243

4344
# Pre-defined derivatives
4445
import DiffRules, SpecialFunctions, NaNMath
45-
for (modu, fun, arity) in DiffRules.diffrules()
46-
if arity == 1 && !(fun in (:-, :+)) # :+ and :- are both unary and binary operators
47-
@eval begin
48-
function Derivative(::typeof($modu.$fun), arg, ::Type{Val{1}})
49-
M, f = $(modu, fun)
50-
@assert length(arg) == 1 "$M.$f is a unary function!"
51-
dx = DiffRules.diffrule(M, f, arg[1])
52-
parse(Operation,dx)
53-
end
54-
end
55-
elseif arity == 2
56-
for i in 1:2
57-
@eval begin
58-
function Derivative(::typeof($modu.$fun), args, ::Type{Val{$i}})
59-
M, f = $(modu, fun)
60-
if f in (:-, :+)
61-
@assert length(args) in (1, 2) "$M.$f is a unary or a binary function!"
62-
else
63-
@assert length(args) == 2 "$M.$f is a binary function!"
64-
end
65-
dx = DiffRules.diffrule(M, f, args[1], args[2])[$i]
66-
parse(Operation,dx)
67-
end
68-
end
46+
for (modu, fun, arity) DiffRules.diffrules()
47+
for i 1:arity
48+
@eval function Derivative(::typeof($modu.$fun), args::NTuple{$arity,Any}, ::Val{$i})
49+
M, f = $(modu, fun)
50+
partials = DiffRules.diffrule(M, f, args...)
51+
dx = @static $arity == 1 ? partials : partials[$i]
52+
parse(Operation,dx)
6953
end
7054
end
7155
end

0 commit comments

Comments
 (0)