@@ -12,12 +12,15 @@ function Adapt.adapt_structure(to, ps::ParamWrapper{P, T}) where {P, T}
         adapt(to, ps.data))
 end
 
+# The reparameterization is adapted from: https://github.com/rtqichen/torchdiffeq/issues/122#issuecomment-738978844
 @kernel function gpu_kernel(f, du, @Const(u),
         @Const(params::AbstractArray{ParamWrapper{P, T}}),
         @Const(t)) where {P, T}
     i = @index(Global, Linear)
     @inbounds p = params[i].params
     @inbounds tspan = params[i].data
+    # reparameterization t->(t_0, t_f) from t->(0, 1).
+    t = (tspan[2] - tspan[1]) * t + tspan[1]
     @views @inbounds f(du[:, i], u[:, i], p, t)
     @inbounds for j in 1:size(du, 1)
         du[j, i] = du[j, i] * (tspan[2] - tspan[1])
@@ -30,7 +33,8 @@
     i = @index(Global, Linear)
     @inbounds p = params[i].params
     @inbounds tspan = params[i].data
-
+    # reparameterization
+    t = (tspan[2] - tspan[1]) * t + tspan[1]
     @views @inbounds x = f(u[:, i], p, t)
     @inbounds for j in 1:size(du, 1)
         du[j, i] = x[j] * (tspan[2] - tspan[1])
@@ -66,6 +70,9 @@
     @inbounds p = params[i + 1].params
     @inbounds tspan = params[i + 1].data
 
+    # reparameterization
+    t = (tspan[2] - tspan[1]) * t + tspan[1]
+
     @views @inbounds f(J[section, section], u[:, i + 1], p, t)
     @inbounds for j in section, k in section
         J[k, j] = J[k, j] * (tspan[2] - tspan[1])
@@ -81,6 +88,9 @@
     @inbounds p = params[i + 1].params
     @inbounds tspan = params[i + 1].data
 
+    # reparameterization
+    t = (tspan[2] - tspan[1]) * t + tspan[1]
+
     @views @inbounds x = f(u[:, i + 1], p, t)
 
     @inbounds for j in section, k in section
@@ -150,6 +160,9 @@
     @inbounds p = params[i].params
     @inbounds tspan = params[i].data
 
+    # reparameterization
+    t = (tspan[2] - tspan[1]) * t + tspan[1]
+
     @views @inbounds jac(_W, u[:, i], p, t)
 
     @inbounds for i in eachindex(_W)
@@ -187,6 +200,9 @@
 
     _W = @inbounds @view(W[:, :, i])
 
+    # reparameterization
+    t = (tspan[2] - tspan[1]) * t + tspan[1]
+
     @views @inbounds x = jac(u[:, i], p, t)
     @inbounds for j in 1:length(_W)
         _W[j] = x[j] * (tspan[2] - tspan[1])
@@ -217,38 +233,51 @@ end
     end
 end
 
-@kernel function Wt_kernel(f::AbstractArray{T}, W, @Const(u), @Const(p), @Const(gamma),
-        @Const(t)) where {T}
+@kernel function Wt_kernel(
+        jac, W, @Const(u), @Const(params::AbstractArray{ParamWrapper{P, T}}),
+        @Const(gamma), @Const(t)) where {P, T}
     i = @index(Global, Linear)
     len = size(u, 1)
+    @inbounds p = params[i].params
+    @inbounds tspan = params[i].data
+
+    # reparameterization
+    t = (tspan[2] - tspan[1]) * t + tspan[1]
+
     _W = @inbounds @view(W[:, :, i])
-    @inbounds jac = f[i].tgrad
-    @views @inbounds jac(_W, u[:, i], p[:, i], t)
+    @views @inbounds jac(_W, u[:, i], p, t)
     @inbounds for i in 1:len
-        _W[i, i] = -inv(gamma) + _W[i, i]
+        _W[i, i] = -inv(gamma) + _W[i, i] * (tspan[2] - tspan[1])
     end
 end
 
-@kernel function Wt_kernel(jac, W, @Const(u), @Const(p), @Const(gamma), @Const(t))
+@kernel function Wt_kernel_oop(
+        jac, W, @Const(u), @Const(params::AbstractArray{ParamWrapper{P, T}}),
+        @Const(gamma), @Const(t)) where {P, T}
     i = @index(Global, Linear)
     len = size(u, 1)
+
+    @inbounds p = params[i].params
+    @inbounds tspan = params[i].data
+
+    # reparameterization
+    t = (tspan[2] - tspan[1]) * t + tspan[1]
+
     _W = @inbounds @view(W[:, :, i])
-    @views @inbounds jac(_W, u[:, i], p[:, i], t)
+    @views @inbounds x = jac(u[:, i], p, t)
+    @inbounds for j in 1:length(_W)
+        _W[j] = x[j] * (tspan[2] - tspan[1])
+    end
     @inbounds for i in 1:len
         _W[i, i] = -inv(gamma) + _W[i, i]
     end
 end
 
-@kernel function Wt_kernel_oop(f::AbstractArray{T}, W, @Const(u), @Const(p), @Const(gamma),
-        @Const(t)) where {T}
+@kernel function Wt_kernel(jac, W, @Const(u), @Const(p), @Const(gamma), @Const(t))
     i = @index(Global, Linear)
     len = size(u, 1)
     _W = @inbounds @view(W[:, :, i])
-    @inbounds jac = f[i].tgrad
-    @views @inbounds x = jac(u[:, i], p[:, i], t)
-    @inbounds for j in 1:length(_W)
-        _W[j] = x[j]
-    end
+    @views @inbounds jac(_W, u[:, i], p[:, i], t)
     @inbounds for i in 1:len
         _W[i, i] = -inv(gamma) + _W[i, i]
     end
@@ -277,7 +306,6 @@
         @views @inbounds f(du[:, i], u[:, i], p[i], t)
     end
 end
-
 @kernel function gpu_kernel_oop_tgrad(f::AbstractArray{T}, du, @Const(u), @Const(p),
         @Const(t)) where {T}
     i = @index(Global, Linear)
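
The added `# reparameterization` lines implement the time-rescaling trick referenced in the linked torchdiffeq comment: every trajectory in the batch is integrated over the unit interval, the physical time is recovered as `t = (tspan[2] - tspan[1]) * t + tspan[1]`, and the right-hand side (and correspondingly the Jacobian and W entries) is multiplied by `tspan[2] - tspan[1]` from the chain rule. The following is a minimal, CPU-only sketch of that equivalence, not part of this commit; the toy right-hand side and the helper names `f_reparam` and `euler` are purely illustrative.

# Sketch: an ODE on (t_0, t_f) is equivalent to an ODE on (0, 1) whose RHS is the
# original RHS evaluated at the rescaled time and scaled by (t_f - t_0).
f(u, p, t) = p .* u .* cos(t)                    # toy right-hand side on (t_0, t_f)

function f_reparam(u, p, s, tspan)
    t = (tspan[2] - tspan[1]) * s + tspan[1]     # map s in (0, 1) to t in (t_0, t_f)
    return f(u, p, t) .* (tspan[2] - tspan[1])   # chain rule: du/ds = (t_f - t_0) * du/dt
end

# Fixed-step forward Euler, enough to check that both formulations agree.
function euler(rhs, u0, span, n)
    u = copy(u0)
    h = (span[2] - span[1]) / n
    for k in 0:(n - 1)
        u = u .+ h .* rhs(u, span[1] + k * h)
    end
    return u
end

u0, p, tspan = [1.0], [0.5], (2.0, 5.0)
u_direct   = euler((u, t) -> f(u, p, t), u0, tspan, 1000)
u_rescaled = euler((u, s) -> f_reparam(u, p, s, tspan), u0, (0.0, 1.0), 1000)
@assert isapprox(u_direct, u_rescaled; rtol = 1e-10)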