Skip to content

Commit ab27f15

Browse files
committed
Support AD-safe type-checking in dynamic @gen functions.
1 parent d948cf3 commit ab27f15

File tree

2 files changed

+25
-21
lines changed

2 files changed

+25
-21
lines changed

src/dsl/dsl.jl

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,22 @@ function parse_arg(expr)
6161
arg
6262
end
6363

64+
function resolve_grad_arg(arg, __module__)
65+
# Ensure that differentiable arguments are supported by ReverseDiff
66+
if !(DSL_ARG_GRAD_ANNOTATION in arg.annotations) return arg end
67+
typ = Core.eval(__module__, arg.typ)
68+
if typ <: Real
69+
new_typ = :Real
70+
elseif typ <: AbstractArray{<:Real} && IndexStyle(typ) == IndexLinear()
71+
new_typ = :(AbstractArray{<:Real})
72+
elseif Real <: typ || AbstractArray{<:Real} <: typ
73+
new_typ = arg.typ
74+
else
75+
error("Type of $(arg.name)::$(arg.typ) does not support differentiation.")
76+
end
77+
return Argument(arg.name, new_typ, arg.annotations, arg.default)
78+
end
79+
6480
include("dynamic.jl")
6581
include("static.jl")
6682

@@ -124,23 +140,11 @@ end
124140

125141
function parse_gen_function(ast, annotations, __module__)
126142
ast = MacroTools.longdef(ast)
127-
if ast.head != :function
128-
error("syntax error at $ast in $(ast.head)")
129-
end
130-
if length(ast.args) != 2
131-
error("syntax error at $ast in $(ast.args)")
132-
end
133-
signature = ast.args[1]
134-
if signature.head == :(::)
135-
(call_signature, return_type) = signature.args
136-
elseif signature.head == :call
137-
(call_signature, return_type) = (signature, :Any)
138-
else
139-
error("syntax error at $(signature)")
140-
end
141-
body = preprocess_body(ast.args[2], __module__)
142-
name = call_signature.args[1]
143-
args = map(parse_arg, call_signature.args[2:end])
143+
def = MacroTools.splitdef(ast)
144+
name = def[:name]
145+
args = def[:args] .|> parse_arg .|> a -> resolve_grad_arg(a, __module__)
146+
body = preprocess_body(def[:body], __module__)
147+
return_type = get(def, :rtype, :Any)
144148
static = DSL_STATIC_ANNOTATION in annotations
145149
if static
146150
make_static_gen_function(name, args, body, return_type, annotations)

src/dsl/dynamic.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@ const DYNAMIC_DSL_TRACE = Symbol("@trace")
22

33
"Convert Argument structs to ASTs."
44
function arg_to_ast(arg::Argument)
5-
ast = esc(arg.name)
5+
ast = Expr(:(::), esc(arg.name), esc(arg.typ))
66
if (arg.default != nothing)
77
default = something(arg.default)
88
ast = Expr(:kw, ast, esc(default))
99
end
10-
ast
10+
return ast
1111
end
1212

1313
"Escape argument defaults (if present)."
@@ -41,8 +41,8 @@ function make_dynamic_gen_function(name, args, body, return_type, annotations)
4141
esc(body))
4242
arg_types = map((arg) -> esc(arg.typ), args)
4343
arg_defaults = map(escape_default, args)
44-
has_argument_grads = map(
45-
(arg) -> (DSL_ARG_GRAD_ANNOTATION in arg.annotations), args)
44+
has_argument_grads =
45+
map((arg) -> (DSL_ARG_GRAD_ANNOTATION in arg.annotations), args)
4646
accepts_output_grad = DSL_RET_GRAD_ANNOTATION in annotations
4747

4848
quote

0 commit comments

Comments
 (0)