@@ -248,7 +248,7 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
248
248
return u0, p, true
249
249
end
250
250
251
- initdata:: OverrideInitData = f. initialization_data
251
+ initdata:: OverrideInitData = ChainRulesCore . @ignore_derivatives f. initialization_data
252
252
initprob = initdata. initializeprob
253
253
254
254
if initdata. update_initializeprob! != = nothing
@@ -258,7 +258,7 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
258
258
initdata. update_initializeprob! (initprob, valp)
259
259
end
260
260
end
261
-
261
+
262
262
if is_trivial_initialization (initdata)
263
263
nlsol = initprob
264
264
success = true
@@ -294,19 +294,14 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
294
294
end
295
295
end
296
296
297
- if initdata. initializeprobmap != = nothing
298
- u02 = initdata. initializeprobmap (nlsol)
297
+ u0 = if initdata. initializeprobmap != = nothing
298
+ initdata. initializeprobmap (nlsol)
299
299
end
300
- if initdata. initializeprobpmap != = nothing
301
- p2 = initdata. initializeprobpmap (valp, nlsol)
300
+ p = if initdata. initializeprobpmap != = nothing
301
+ initdata. initializeprobpmap (valp, nlsol)
302
302
end
303
303
304
- # specifically needs to be written this way for Zygote
305
- # See https://github.com/SciML/ModelingToolkit.jl/pull/3585#issuecomment-2883919162
306
- u03 = isnothing (initdata. initializeprobmap) ? u0 : u02
307
- p3 = isnothing (initdata. initializeprobpmap) ? p : p2
308
-
309
- return u03, p3, success
304
+ return u0, p, success
310
305
end
311
306
312
307
"""
0 commit comments