Skip to content

Commit b629a0b

Browse files
committed
Add jacobian_wrt_vars
1 parent 3212f1a commit b629a0b

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

src/utils.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -710,3 +710,22 @@ function Base.iterate(it::StatefulBFS, queue = (eltype(it)[(0, it.t)]))
710710
end
711711
return (lv, t), queue
712712
end
713+
714+
function jacobian_wrt_vars(pf::F, p, input_idxs, chunk::C) where {F, C}
715+
dualtype = ForwardDiff.Dual{ForwardDiff.Tag{F, eltype(p)},
716+
eltype(p), ForwardDiff.chunksize(chunk)}
717+
p_big = similar(p, dualtype)
718+
copyto!(p_big, p)
719+
p_closure = let pf = pf,
720+
input_idxs = input_idxs,
721+
p_big = p_big
722+
723+
function (p_small_inner)
724+
p_big[input_idxs] .= p_small_inner
725+
pf(p_big)
726+
end
727+
end
728+
cfg = ForwardDiff.JacobianConfig(p_closure, p_small, chunk)
729+
p_small = p[input_idxs]
730+
ForwardDiff.jacobian(p_closure, p_small, cfg)
731+
end

0 commit comments

Comments
 (0)