@@ -78,3 +78,87 @@ Zygote.gradient(0.1ones(length(SS.canonicalize(SS.Tunable(), p)[1]))) do tunable
7878 sol = solve (newprob, Tsit5 ())
7979 return sum (sol. u[end ])
8080end
81+
82+ using OrdinaryDiffEq
83+ using StableRNGs, Lux
84+ using ComponentArrays
85+ using SciMLSensitivity
86+ import SciMLStructures as SS
87+ using Zygote
88+ using Parameters
89+ using ADTypes
90+ using Test
91+
92+ mutable struct myparam{M,P,S}
93+ model:: M
94+ ps :: P
95+ st :: S
96+ α :: Float64
97+ β :: Float64
98+ γ :: Float64
99+ end
100+
101+ SS. isscimlstructure (:: myparam ) = true
102+ SS. ismutablescimlstructure (:: myparam ) = true
103+ SS. hasportion (:: SS.Tunable , :: myparam ) = true
104+ function SS. canonicalize (:: SS.Tunable , p:: myparam )
105+ buffer = copy (p. ps)
106+ repack = let p = p
107+ function repack (newbuffer)
108+ SS. replace (SS. Tunable (), p, newbuffer)
109+ end
110+ end
111+ return buffer, repack, false
112+ end
113+ function SS. replace (:: SS.Tunable , p:: myparam , newbuffer)
114+ return myparam (p. model, newbuffer, p. st, p. α, p. β, p. γ)
115+ end
116+ function SS. replace! (:: SS.Tunable , p:: myparam , newbuffer)
117+ p. ps = newbuffer
118+ return p
119+ end
120+ function initialize ()
121+ # Defining the neural network
122+ U = Lux. Chain (Lux. Dense (3 ,30 ,tanh),Lux. Dense (30 ,30 ,tanh),Lux. Dense (30 ,1 ))
123+ rng = StableRNG (1111 )
124+ _para,st = Lux. setup (rng,U)
125+ _para = ComponentArray (_para)
126+ # Setting the parameters
127+ α = 0.5
128+ β = 0.1
129+ γ = 0.01
130+ return myparam (U,_para,st,α,β,γ)
131+ end
132+ function UDE_model! (du, u, p, t)
133+ # Extracting parameters
134+ Parameters. @unpack model, ps, st, α, β, γ = p
135+ o = model (u,ps, st)[1 ][1 ]
136+ du[1 ] = o * α * u[1 ] + β * u[2 ] + γ * u[3 ]
137+ du[2 ] = - α * u[1 ] + β * u[2 ] - γ * u[3 ]
138+ du[3 ] = α * u[1 ] - β * u[2 ] + γ * u[3 ]
139+ nothing
140+ end
141+
142+ p = initialize ()
143+ function run_diff (ps)
144+ u01 = [1.0 , 0.0 , 0.0 ]
145+ tspan = (0.0 , 10.0 )
146+ prob = ODEProblem (UDE_model!, u01, tspan, ps)
147+ sol = solve (prob, Rosenbrock23 (), saveat = 0.1 )
148+ return sol. u |> last |> sum
149+ end
150+
151+ run_diff (initialize ())
152+ @test ! iszero (Zygote. gradient (run_diff, initialize ())[1 ]. ps)
153+
154+ function run_diff (ps,sensealg)
155+ u01 = [1.0 , 0.0 , 0.0 ]
156+ tspan = (0.0 , 10.0 )
157+ prob = ODEProblem (UDE_model!, u01, tspan, ps)
158+ sol = solve (prob, Rosenbrock23 (), saveat = 0.1 , sensealg= sensealg)
159+ return sol. u |> last |> sum
160+ end
161+
162+ run_diff (initialize ())
163+ @test ! iszero (Zygote. gradient (run_diff, initialize (), GaussAdjoint ())[1 ]. ps)
164+ @test ! iszero (Zygote. gradient (run_diff, initialize (), GaussAdjoint (autojacvec= false ))[1 ]. ps)
0 commit comments