@@ -39,33 +39,17 @@ function Derivative(O::Operation,idx)
39
39
# This calls the Derivative dispatch from the user or pre-defined code
40
40
Derivative (O. op, O. args, Val (idx))
41
41
end
42
+ Derivative (op, args, idx) = Derivative (op, (args... ,), idx)
42
43
43
44
# Pre-defined derivatives
44
45
import DiffRules, SpecialFunctions, NaNMath
45
- for (modu, fun, arity) in DiffRules. diffrules ()
46
- if arity == 1 && ! (fun in (:- , :+ )) # :+ and :- are both unary and binary operators
47
- @eval begin
48
- function Derivative (:: typeof ($ modu.$ fun), arg, :: Type{Val{1}} )
49
- M, f = $ (modu, fun)
50
- @assert length (arg) == 1 " $M .$f is a unary function!"
51
- dx = DiffRules. diffrule (M, f, arg[1 ])
52
- parse (Operation,dx)
53
- end
54
- end
55
- elseif arity == 2
56
- for i in 1 : 2
57
- @eval begin
58
- function Derivative (:: typeof ($ modu.$ fun), args, :: Type{Val{$i}} )
59
- M, f = $ (modu, fun)
60
- if f in (:- , :+ )
61
- @assert length (args) in (1 , 2 ) " $M .$f is a unary or a binary function!"
62
- else
63
- @assert length (args) == 2 " $M .$f is a binary function!"
64
- end
65
- dx = DiffRules. diffrule (M, f, args[1 ], args[2 ])[$ i]
66
- parse (Operation,dx)
67
- end
68
- end
46
+ for (modu, fun, arity) ∈ DiffRules. diffrules ()
47
+ for i ∈ 1 : arity
48
+ @eval function Derivative (:: typeof ($ modu.$ fun), args:: NTuple{$arity,Any} , :: Val{$i} )
49
+ M, f = $ (modu, fun)
50
+ partials = DiffRules. diffrule (M, f, args... )
51
+ dx = @static $ arity == 1 ? partials : partials[$ i]
52
+ parse (Operation,dx)
69
53
end
70
54
end
71
55
end
0 commit comments