@@ -308,11 +308,15 @@ function prepare_vjp(::Val{false}, prob::AbstractNonlinearProblem,
308
308
v_fake = copy (fu)
309
309
di_extras = DI. prepare_pullback (fₚ, fu_cache, autodiff, u, v_fake)
310
310
return @closure (vJ, v, u, p) -> begin
311
- DI. pullback! (fₚ, fu_cache, reshape (vJ, size (u)), autodiff, u, v, di_extras)
311
+ DI. pullback! (fₚ, fu_cache, reshape (vJ, size (u)), autodiff,
312
+ u, reshape (v, size (fu_cache)), di_extras)
313
+ return
312
314
end
313
315
else
314
316
di_extras = DI. prepare_pullback (fₚ, autodiff, u, fu)
315
- return @closure (v, u, p) -> DI. pullback (fₚ, autodiff, u, v, di_extras)
317
+ return @closure (v, u, p) -> begin
318
+ return DI. pullback (fₚ, autodiff, u, reshape (v, size (fu)), di_extras)
319
+ end
316
320
end
317
321
end
318
322
@@ -351,12 +355,15 @@ function prepare_jvp(::Val{false}, prob::AbstractNonlinearProblem,
351
355
di_extras = DI. prepare_pushforward (fₚ, fu_cache, autodiff, u, u)
352
356
return @closure (Jv, v, u, p) -> begin
353
357
DI. pushforward! (
354
- fₚ, fu_cache, reshape (Jv, size (fu_cache)), autodiff, u, v, di_extras)
358
+ fₚ, fu_cache, reshape (Jv, size (fu_cache)),
359
+ autodiff, u, reshape (v, size (u)), di_extras)
355
360
return
356
361
end
357
362
else
358
363
di_extras = DI. prepare_pushforward (fₚ, autodiff, u, u)
359
- return @closure (v, u, p) -> DI. pushforward (fₚ, autodiff, u, v, di_extras)
364
+ return @closure (v, u, p) -> begin
365
+ return DI. pushforward (fₚ, autodiff, u, reshape (v, size (u)), di_extras)
366
+ end
360
367
end
361
368
end
362
369
0 commit comments