Skip to content

Commit 2647a28

Browse files
committed
fix: handling of multi-dimensional arrays
1 parent 91662b6 commit 2647a28

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

docs/src/tutorials/large_systems.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ using BenchmarkTools # for @btime
142142
NewtonRaphson(; autodiff = AutoSparse(AutoForwardDiff(; chunksize = 32)),
143143
linsolve = KLUFactorization()));
144144
@btime solve(prob_brusselator_2d,
145-
NewtonRaphson(; autodiff = AutoSparse(AutoForwardDiff(; chunksize = 32)),
145+
NewtonRaphson(; autodiff = AutoForwardDiff(; chunksize = 32),
146146
linsolve = KrylovJL_GMRES()));
147147
nothing # hide
148148
```

lib/SciMLJacobianOperators/src/SciMLJacobianOperators.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -308,11 +308,15 @@ function prepare_vjp(::Val{false}, prob::AbstractNonlinearProblem,
308308
v_fake = copy(fu)
309309
di_extras = DI.prepare_pullback(fₚ, fu_cache, autodiff, u, v_fake)
310310
return @closure (vJ, v, u, p) -> begin
311-
DI.pullback!(fₚ, fu_cache, reshape(vJ, size(u)), autodiff, u, v, di_extras)
311+
DI.pullback!(fₚ, fu_cache, reshape(vJ, size(u)), autodiff,
312+
u, reshape(v, size(fu_cache)), di_extras)
313+
return
312314
end
313315
else
314316
di_extras = DI.prepare_pullback(fₚ, autodiff, u, fu)
315-
return @closure (v, u, p) -> DI.pullback(fₚ, autodiff, u, v, di_extras)
317+
return @closure (v, u, p) -> begin
318+
return DI.pullback(fₚ, autodiff, u, reshape(v, size(fu)), di_extras)
319+
end
316320
end
317321
end
318322

@@ -351,12 +355,15 @@ function prepare_jvp(::Val{false}, prob::AbstractNonlinearProblem,
351355
di_extras = DI.prepare_pushforward(fₚ, fu_cache, autodiff, u, u)
352356
return @closure (Jv, v, u, p) -> begin
353357
DI.pushforward!(
354-
fₚ, fu_cache, reshape(Jv, size(fu_cache)), autodiff, u, v, di_extras)
358+
fₚ, fu_cache, reshape(Jv, size(fu_cache)),
359+
autodiff, u, reshape(v, size(u)), di_extras)
355360
return
356361
end
357362
else
358363
di_extras = DI.prepare_pushforward(fₚ, autodiff, u, u)
359-
return @closure (v, u, p) -> DI.pushforward(fₚ, autodiff, u, v, di_extras)
364+
return @closure (v, u, p) -> begin
365+
return DI.pushforward(fₚ, autodiff, u, reshape(v, size(u)), di_extras)
366+
end
360367
end
361368
end
362369

0 commit comments

Comments
 (0)