Skip to content

Commit 337beb3

Browse files
authored
Merge pull request #314 from ztangent/typed-gen-fns
Support AD-safe type checking in dynamic @gen functions
2 parents 3662ced + 0b09cdf commit 337beb3

File tree

3 files changed

+28
-23
lines changed

3 files changed

+28
-23
lines changed

src/dsl/dsl.jl

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,22 @@ function parse_arg(expr)
6161
arg
6262
end
6363

64+
function resolve_grad_arg(arg, __module__)
65+
# Ensure that differentiable arguments are supported by ReverseDiff
66+
if !(DSL_ARG_GRAD_ANNOTATION in arg.annotations) return arg end
67+
typ = Core.eval(__module__, arg.typ)
68+
if typ <: Real
69+
new_typ = :Real
70+
elseif typ <: AbstractArray{<:Real} && IndexStyle(typ) == IndexLinear()
71+
new_typ = :(AbstractArray{<:Real})
72+
elseif Real <: typ || AbstractArray{<:Real} <: typ
73+
new_typ = arg.typ
74+
else
75+
error("Type of $(arg.name)::$(arg.typ) does not support differentiation.")
76+
end
77+
return Argument(arg.name, new_typ, arg.annotations, arg.default)
78+
end
79+
6480
include("dynamic.jl")
6581
include("static.jl")
6682

@@ -124,27 +140,16 @@ end
124140

125141
function parse_gen_function(ast, annotations, __module__)
126142
ast = MacroTools.longdef(ast)
127-
if ast.head != :function
128-
error("syntax error at $ast in $(ast.head)")
129-
end
130-
if length(ast.args) != 2
131-
error("syntax error at $ast in $(ast.args)")
132-
end
133-
signature = ast.args[1]
134-
if signature.head == :(::)
135-
(call_signature, return_type) = signature.args
136-
elseif signature.head == :call
137-
(call_signature, return_type) = (signature, :Any)
138-
else
139-
error("syntax error at $(signature)")
140-
end
141-
body = preprocess_body(ast.args[2], __module__)
142-
name = call_signature.args[1]
143-
args = map(parse_arg, call_signature.args[2:end])
143+
def = MacroTools.splitdef(ast)
144+
name = def[:name]
145+
args = map(parse_arg, def[:args])
146+
body = preprocess_body(def[:body], __module__)
147+
return_type = get(def, :rtype, :Any)
144148
static = DSL_STATIC_ANNOTATION in annotations
145149
if static
146150
make_static_gen_function(name, args, body, return_type, annotations)
147151
else
152+
args = map(a -> resolve_grad_arg(a, __module__), args)
148153
make_dynamic_gen_function(name, args, body, return_type, annotations)
149154
end
150155
end

src/dsl/dynamic.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@ const DYNAMIC_DSL_TRACE = Symbol("@trace")
22

33
"Convert Argument structs to ASTs."
44
function arg_to_ast(arg::Argument)
5-
ast = esc(arg.name)
5+
ast = Expr(:(::), esc(arg.name), esc(arg.typ))
66
if (arg.default != nothing)
77
default = something(arg.default)
88
ast = Expr(:kw, ast, esc(default))
99
end
10-
ast
10+
return ast
1111
end
1212

1313
"Escape argument defaults (if present)."
@@ -41,8 +41,8 @@ function make_dynamic_gen_function(name, args, body, return_type, annotations)
4141
esc(body))
4242
arg_types = map((arg) -> esc(arg.typ), args)
4343
arg_defaults = map(escape_default, args)
44-
has_argument_grads = map(
45-
(arg) -> (DSL_ARG_GRAD_ANNOTATION in arg.annotations), args)
44+
has_argument_grads =
45+
map((arg) -> (DSL_ARG_GRAD_ANNOTATION in arg.annotations), args)
4646
accepts_output_grad = DSL_RET_GRAD_ANNOTATION in annotations
4747

4848
quote

test/inference/particle_filter.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ end
3131
# test the hmm_forward_alg on a hand-calculated example
3232
prior = [0.4, 0.6]
3333
emission_dists = [0.1 0.9; 0.7 0.3]'
34-
transition_dists = [0.5 0.5; 0.2 0.8']
34+
transition_dists = [0.5 0.5; 0.2 0.8]'
3535
obs = [2, 1]
3636
expected_marg_lik = 0.
3737
# z = [1, 1]
@@ -114,7 +114,7 @@ end
114114

115115
# do particle filter steps
116116

117-
@gen function step_proposal(prev_trace, T::Int, x::Float64)
117+
@gen function step_proposal(prev_trace, T::Int, x::Int)
118118
@assert T > 1
119119
choices = get_choices(prev_trace)
120120
if T > 2

0 commit comments

Comments
 (0)