@@ -61,6 +61,22 @@ function parse_arg(expr)
6161 arg
6262end
6363
64+ function resolve_grad_arg (arg, __module__)
65+ # Ensure that differentiable arguments are supported by ReverseDiff
66+ if ! (DSL_ARG_GRAD_ANNOTATION in arg. annotations) return arg end
67+ typ = Core. eval (__module__, arg. typ)
68+ if typ <: Real
69+ new_typ = :Real
70+ elseif typ <: AbstractArray{<:Real} && IndexStyle (typ) == IndexLinear ()
71+ new_typ = :(AbstractArray{<: Real })
72+ elseif Real <: typ || AbstractArray{<: Real } <: typ
73+ new_typ = arg. typ
74+ else
75+ error (" Type of $(arg. name) ::$(arg. typ) does not support differentiation." )
76+ end
77+ return Argument (arg. name, new_typ, arg. annotations, arg. default)
78+ end
79+
6480include (" dynamic.jl" )
6581include (" static.jl" )
6682
@@ -124,27 +140,16 @@ end
124140
125141function parse_gen_function (ast, annotations, __module__)
126142 ast = MacroTools. longdef (ast)
127- if ast. head != :function
128- error (" syntax error at $ast in $(ast. head) " )
129- end
130- if length (ast. args) != 2
131- error (" syntax error at $ast in $(ast. args) " )
132- end
133- signature = ast. args[1 ]
134- if signature. head == :(:: )
135- (call_signature, return_type) = signature. args
136- elseif signature. head == :call
137- (call_signature, return_type) = (signature, :Any )
138- else
139- error (" syntax error at $(signature) " )
140- end
141- body = preprocess_body (ast. args[2 ], __module__)
142- name = call_signature. args[1 ]
143- args = map (parse_arg, call_signature. args[2 : end ])
143+ def = MacroTools. splitdef (ast)
144+ name = def[:name ]
145+ args = map (parse_arg, def[:args ])
146+ body = preprocess_body (def[:body ], __module__)
147+ return_type = get (def, :rtype , :Any )
144148 static = DSL_STATIC_ANNOTATION in annotations
145149 if static
146150 make_static_gen_function (name, args, body, return_type, annotations)
147151 else
152+ args = map (a -> resolve_grad_arg (a, __module__), args)
148153 make_dynamic_gen_function (name, args, body, return_type, annotations)
149154 end
150155end
0 commit comments