@@ -108,8 +108,8 @@ function GenericForwardDiffADJtprod(
108108)
109109 return GenericForwardDiffADJtprod ()
110110end
111- function Jtprod! (:: GenericForwardDiffADJtprod , Jtv:: S , f, x:: S , v:: S , :: Val ) where {S}
112- Jtv .= ForwardDiff. gradient (x -> dot (f (x), v), x):: S
111+ function Jtprod! (:: GenericForwardDiffADJtprod , Jtv, f, x, v, :: Val )
112+ Jtv .= ForwardDiff. gradient (x -> dot (f (x), v), x)
113113 return Jtv
114114end
115115
@@ -138,8 +138,8 @@ function ForwardDiffADJtprod(
138138 c! (cx, x)
139139 dot (cx, u)
140140 end
141- # tagψ = ForwardDiff.Tag(ψ, T)
142- cfg = ForwardDiff. GradientConfig (ψ, temp)
141+ tagψ = ForwardDiff. Tag (ψ, T)
142+ cfg = ForwardDiff. GradientConfig (ψ, temp, ForwardDiff . Chunk (temp), tagψ )
143143
144144 return ForwardDiffADJtprod (cfg, ψ, temp, sol)
145145end
@@ -151,7 +151,7 @@ function Jtprod!(b::ForwardDiffADJtprod{Tag, GT, S}, Jtv, c!, x, v, ::Val) where
151151 b. sol[1 : ncon] .= 0
152152 b. sol[(ncon + 1 ): (ncon + nvar)] .= x
153153 b. sol[(ncon + nvar + 1 ): (2 * ncon + nvar)] .= v
154- ForwardDiff. gradient! (b. temp, b. ψ, b. sol)
154+ ForwardDiff. gradient! (b. temp, b. ψ, b. sol, b . cfg )
155155 Jtv .= view (b. temp, (ncon + 1 ): (nvar + ncon))
156156 return Jtv
157157end
0 commit comments