@@ -209,29 +209,18 @@ function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
209
209
end
210
210
211
211
# Generic Handling of Krylov Methods for Normal Form Linear Solves
212
- # FIXME : Use MaybeInplace here for efficient matmuls
213
- function __update_JᵀJ! (iip:: Val , cache, sym:: Symbol , J)
214
- return __update_JᵀJ! (iip, cache, sym, getproperty (cache, sym), J)
212
+ function __update_JᵀJ! (cache:: AbstractNonlinearSolveCache )
213
+ if ! (cache. JᵀJ isa KrylovJᵀJ)
214
+ @bb cache. JᵀJ = transpose (cache. J) × cache. J
215
+ end
215
216
end
216
- __update_JᵀJ! (:: Val{false} , cache, sym:: Symbol , _, J) = setproperty! (cache, sym, J' * J)
217
- __update_JᵀJ! (:: Val{true} , cache, sym:: Symbol , _, J) = mul! (getproperty (cache, sym), J' , J)
218
- __update_JᵀJ! (:: Val{false} , cache, sym:: Symbol , H:: KrylovJ ᵀJ, J) = H
219
- __update_JᵀJ! (:: Val{true} , cache, sym:: Symbol , H:: KrylovJ ᵀJ, J) = H
220
217
221
- function __update_Jᵀf! (iip:: Val , cache, sym1:: Symbol , sym2:: Symbol , J, fu)
222
- return __update_Jᵀf! (iip, cache, sym1, sym2, getproperty (cache, sym2), J, fu)
223
- end
224
- function __update_Jᵀf! (:: Val{false} , cache, sym1:: Symbol , sym2:: Symbol , _, J, fu)
225
- return setproperty! (cache, sym1, _restructure (getproperty (cache, sym1), J' * fu))
226
- end
227
- function __update_Jᵀf! (:: Val{true} , cache, sym1:: Symbol , sym2:: Symbol , _, J, fu)
228
- return mul! (_vec (getproperty (cache, sym1)), J' , fu)
229
- end
230
- function __update_Jᵀf! (:: Val{false} , cache, sym1:: Symbol , sym2:: Symbol , H:: KrylovJ ᵀJ, J, fu)
231
- return setproperty! (cache, sym1, _restructure (getproperty (cache, sym1), H. Jᵀ * fu))
232
- end
233
- function __update_Jᵀf! (:: Val{true} , cache, sym1:: Symbol , sym2:: Symbol , H:: KrylovJ ᵀJ, J, fu)
234
- return mul! (_vec (getproperty (cache, sym1)), H. Jᵀ, fu)
218
+ function __update_Jᵀf! (cache:: AbstractNonlinearSolveCache )
219
+ if cache. JᵀJ isa KrylovJᵀJ
220
+ @bb cache. Jᵀf = cache. JᵀJ. Jᵀ × cache. fu
221
+ else
222
+ @bb cache. Jᵀf = transpose (cache. J) × vec (cache. fu)
223
+ end
235
224
end
236
225
237
226
# Left-Right Multiplication
0 commit comments