
Conversation

@DhairyaLGandhi (Member) commented Apr 8, 2022:

Currently, the flat vector is stored inside the Restructure struct, so it assumes that incoming parameters also match the eltype of the model. This fails in mixed-mode AD, when using structs of arrays, etc. To address that, this PR reallocates a buffer that can actually hold the new parameters. Note that with mixed precision we also pay for conversion (and therefore allocation) on every operation. cc @ChrisRackauckas

MWE:

using DiffEqFlux, Flux, OrdinaryDiffEq, Test

u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)

function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u .^ 3)'true_A)'
end
t = range(tspan[1], tspan[2], length=datasize)
prob = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob, Tsit5(), saveat=t))

model = Chain(x -> x .^ 3,
              Dense(2, 50, tanh),
              Dense(50, 2))
neuralde = NeuralODE(model, tspan, Rodas5(), saveat=t, reltol=1e-7, abstol=1e-9)

function predict_n_ode()
  neuralde(u0)
end
loss_n_ode() = sum(abs2, ode_data .- predict_n_ode())

data = Iterators.repeated((), 10)
opt = ADAM(0.1)
cb = function () #callback function to observe training
   display(loss_n_ode())
end

# Display the ODE with the initial parameter values.
cb()

neuralde = NeuralODE(model, tspan, Rodas5(), saveat=t, reltol=1e-7, abstol=1e-9)
ps = Flux.params(neuralde)
loss1 = loss_n_ode()
Flux.train!(loss_n_ode, ps, data, opt, cb=cb)

It might be good to update the actual struct as well; otherwise it would lie about the actual contents of the parameters. This would mean making Restructure mutable.

This still needs tests before merging, and some test failures are expected since we still have to accumulate the gradients properly.
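
For context, a minimal sketch of the idea (not the actual diff in this PR; restructure_oop is a hypothetical helper): instead of writing the incoming parameters into the stored Float32 buffer, allocate a fresh buffer whose eltype can hold them, and rebuild from that.

using Flux

# Hypothetical helper sketching the out-of-place reallocation; the real
# change lives inside Flux's Restructure / _restructure internals.
function restructure_oop(re, flat_stored, p_new)
    # Pick an eltype wide enough for the incoming parameters (e.g. Float64
    # duals against a stored Float32 vector) instead of silently converting.
    T = promote_type(eltype(flat_stored), eltype(p_new))
    buf = similar(flat_stored, T)
    copyto!(buf, p_new)
    return re(buf)            # rebuild the model from the new buffer
end

m = Chain(Dense(2, 50, tanh), Dense(50, 2))
p0, re = Flux.destructure(m)                  # p0 is a flat Float32 vector
m64 = restructure_oop(re, p0, Float64.(p0))   # rebuilt model keeps Float64 parameters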

@DhairyaLGandhi (Member, Author) commented:

Of course, doing it out of place is less efficient, but at least some of the tests show that there were indeed cases of implicit conversion going on. It may also be the more correct implementation, since we can't assume that the types of the primal and the pullback will always match.
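
As a rough illustration of why the incoming eltype cannot be assumed to match the stored one (assuming the restructure follows the incoming eltype; this is not part of the PR's test suite): forward-mode AD pushes dual numbers through re even though the stored flat vector is Float32.

using Flux, ForwardDiff

m = Chain(Dense(2, 50, tanh), Dense(50, 2))
p0, re = Flux.destructure(m)          # stored flat vector is Float32

# ForwardDiff feeds Dual numbers through the same Restructure, so the
# rebuilt model has to adopt the incoming eltype rather than assume Float32.
loss(p) = sum(abs2, re(p)(Float32[2.0, 0.0]))
g = ForwardDiff.gradient(loss, Float64.(p0))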

@ChrisRackauckas (Member) commented:

Fixes SciML/DiffEqFlux.jl#699

@CarloLucibello (Member) commented:

Solved in #66
