@@ -26,14 +26,73 @@ all_indices(arg::SimpleArg) = [arg.i]
2626all_indices (arg:: TransformedArg ) = vcat ([all_indices (a) for a in arg. f_args]. .. )
2727
2828# Evaluate user-facing args to concrete values passed to the base distribution
29- eval_arg (x:: Any , args) = x
30- eval_arg (x:: SimpleArg , args) = typecheck_arg (x, args[x. i])
31- eval_arg (x:: TransformedArg , args) =
32- x. arg_passer (x. orig_f, [eval_arg (a, args) for a in x. f_args]. .. )
29+ eval_arg (base_arg:: Any , args) = base_arg
30+ eval_arg (base_arg:: SimpleArg , args) = typecheck_arg (base_arg, args[base_arg. i])
31+ eval_arg (base_arg:: TransformedArg , args) =
32+ base_arg. arg_passer (base_arg. orig_f, (eval_arg (a, args) for a in base_arg. f_args). .. )
33+
34+ # Evaluate gradients of base distribution args with respect to user-facing args
35+ function eval_arg_gradient (base_arg:: Any , base_type:: Type , args)
36+ grads = map (enumerate (args)) do (i, arg)
37+ if arg isa Real || arg isa AbstractArray && eltype (arg) <: Real
38+ zero (arg) # Base arg is always constant with respect to input args
39+ else
40+ nothing
41+ end
42+ end
43+ return grads
44+ end
45+
46+ function eval_arg_gradient (base_arg:: SimpleArg{T} , base_type:: Type , args) where {T}
47+ grads = map (enumerate (args)) do (i, arg)
48+ if arg isa Real # Base arg is either equal to or unaffected by input arg
49+ i == base_arg. i ? one (arg) : zero (arg)
50+ elseif arg isa AbstractArray && eltype (arg) <: Real
51+ N, V = length (arg), eltype (arg)
52+ i == base_arg. i ? Matrix {V} (LinearAlgebra. I, N, N) : zeros (V, N, N)
53+ else
54+ nothing
55+ end
56+ end
57+ return grads
58+ end
59+
60+ # Compute gradients when base arg is a scalar type
61+ function eval_arg_gradient (base_arg:: TransformedArg , base_type:: Type{<:Real} , args)
62+ splice_arg (arg, i) = [args[1 : i- 1 ]. .. , arg, args[i+ 1 : end ]. .. ]
63+ per_arg_eval (arg, i) = eval_arg (base_arg, splice_arg (arg, i))
64+ grads = map (enumerate (args)) do (i, arg)
65+ if arg isa Real
66+ ReverseDiff. gradient (a -> per_arg_eval (a, i), [arg])[1 ]
67+ elseif arg isa AbstractArray && eltype (arg) <: Real
68+ ReverseDiff. gradient (a -> per_arg_eval (a, i), arg)
69+ else
70+ nothing
71+ end
72+ end
73+ return grads
74+ end
75+
76+ # Compute Jacobians when base arg is an array type
77+ function eval_arg_gradient (base_arg:: TransformedArg , base_type:: Type{<:AbstractArray{<:Real}} , args)
78+ splice_arg (arg, i) = [args[1 : i- 1 ]. .. , arg, args[i+ 1 : end ]. .. ]
79+ per_arg_eval (arg, i) = eval_arg (base_arg, splice_arg (arg, i))
80+ grads = map (enumerate (args)) do (i, arg)
81+ if arg isa Real
82+ ReverseDiff. jacobian (a -> per_arg_eval (a, i), [arg])
83+ elseif arg isa AbstractArray && eltype (arg) <: Real
84+ ReverseDiff. jacobian (a -> per_arg_eval (a, i), arg)
85+ else
86+ nothing
87+ end
88+ end
89+ return grads
90+ end
3391
3492# Type of SimpleArg must match arg, otherwise a MethodError will be thrown
35- typecheck_arg (x:: SimpleArg{T} , arg:: T ) where {T} = arg
36- typecheck_arg (x:: SimpleArg{T} , arg:: ReverseDiff.TrackedReal{T} ) where {T <: Real } = arg
93+ typecheck_arg (base_arg:: SimpleArg{T} , arg:: T ) where {T} = arg
94+ typecheck_arg (base_arg:: SimpleArg{T} , arg:: ReverseDiff.TrackedReal{T} ) where {T <: Real } = arg
95+ typecheck_arg (base_arg:: SimpleArg{T} , arg:: ReverseDiff.TrackedArray{V, D, N, T} ) where {V, D, N, T} = arg
3796
3897# DistWithArgs
3998struct DistWithArgs{T}
@@ -72,25 +131,32 @@ function logpdf_grad(d::CompiledDistWithArgs{T}, x::T, args...) where T
72131 concrete_args = [eval_arg (arg, args) for arg in d. arglist]
73132 base_has_arg_grads = has_argument_grads (d. base)
74133 base_grads = logpdf_grad (d. base, x, concrete_args... )
75-
76- base_arg_grads = [g for (i, g) in enumerate (base_grads[2 : end ])
77- if base_has_arg_grads[i]]
78- argvec = collect (args)
79- if ! isempty (argvec)
80- eval_arg_grads = [ReverseDiff. gradient (xs -> eval_arg (arg, xs), argvec) for
81- (i, arg) in enumerate (d. arglist) if base_has_arg_grads[i]]
82- eval_arg_grads = reduce (hcat, eval_arg_grads)
83- end
84-
85- retval = [base_grads[1 ]]
86- for i in 1 : d. n_args
87- if self_has_arg_grads[i]
88- push! (retval, eval_arg_grads[i,:]' * base_arg_grads)
89- else
90- push! (retval, nothing )
134+ base_arg_grads = base_grads[2 : end ]
135+
136+ # Set gradient with respect to output
137+ self_output_grad = base_grads[1 ]
138+
139+ # Backpropagate gradients from base arguments to arguments
140+ self_arg_grads = [self_has_arg_grads[i] ? zero (arg) : nothing
141+ for (i, arg) in enumerate (args)]
142+
143+ for (i, base_arg) in enumerate (d. arglist)
144+ base_has_arg_grads[i] || continue
145+ base_grad = base_arg_grads[i]
146+ base_arg_type = typeof (concrete_args[i])
147+ eval_arg_grad = eval_arg_gradient (base_arg, base_arg_type, args)
148+ for (j, g) in enumerate (eval_arg_grad)
149+ (isnothing (g) || ! self_has_arg_grads[j]) && continue
150+ if base_grad isa AbstractArray
151+ increment = reshape (g' * vec (base_grad), size (self_arg_grads[j]))
152+ else
153+ increment = g * base_grad
154+ end
155+ self_arg_grads[j] = self_arg_grads[j] .+ increment
91156 end
92157 end
93- return Tuple (retval)
158+
159+ return (self_output_grad, self_arg_grads... )
94160end
95161
96162function random (d:: CompiledDistWithArgs{T} , args... ):: T where T
0 commit comments