Closed
Labels: question (Further information is requested)
Description
I have a question about this example (https://docs.sciml.ai/Optimization/stable/optimization_packages/optimization/#Train-NN-with-Sophia):
```julia
using Optimization, Lux, Zygote, MLUtils, Statistics, Plots, Random, ComponentArrays, Printf

x = rand(10000)
y = sin.(x)
data = MLUtils.DataLoader((x, y), batchsize = 100)

# Define the neural network
model = Chain(Dense(1, 32, tanh), Dense(32, 1))
ps, st = Lux.setup(Random.default_rng(), model)
ps_ca = ComponentArray(ps)
smodel = StatefulLuxLayer{true}(model, nothing, st)

function callback(state, l)
    # Print progress every 25 iterations (@printf, not @show, interprets the format string)
    state.iter % 25 == 1 && @printf("Iteration: %5d, Loss: %.6e\n", state.iter, l)
    return l < 1e-1 ## Terminate if loss is small
end

function loss(ps, data)
    ypred = [smodel([data[1][i]], ps)[1] for i in eachindex(data[1])]
    return sum(abs2, ypred .- data[2])
end

optf = OptimizationFunction(loss, AutoZygote())
prob = OptimizationProblem(optf, ps_ca, data)
res = Optimization.solve(prob, Optimization.Sophia(), callback = callback)
```

The example works fine. However, I can only seem to evaluate the cost function inside the optimization run. For example, all of these snippets:
```julia
[smodel([data[1][i]], prob.u0)[1] for i in eachindex(data[1])]
[smodel([data[1][i]], res.u)[1] for i in eachindex(data[1])]
data[1]
loss(prob.u0, data)
optf(res.u, data)
prob.f(res.u, data)
```

yield the same error:

```
ERROR: MethodError: no method matching getindex(::DataLoader{BatchView{…}, Bool, :serial, Val{…}, Tuple{…}, TaskLocalRNG}, ::Int64)
The function `getindex` exists, but no method is defined for this combination of argument types.

Stacktrace:
 [1] loss(ps::ComponentVector{…}, data::DataLoader{…})
   @ Main ~/Desktop/Julia Playground/Environment - Other - Lux/lux_playground.jl:61
 [2] (::OptimizationFunction{…})(::ComponentVector{…}, ::Vararg{…})
   @ SciMLBase ~/.julia/packages/SciMLBase/gwbNe/src/scimlfunctions.jl:4220
 [3] top-level scope
```
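The error can be reproduced with MLUtils alone. A minimal sketch, reusing the `x`, `y` and `DataLoader` setup from the example: a `DataLoader` behaves as an iterator over batches rather than as an indexable array, so `data[1]` throws the `MethodError` above, while `first(data)` (or a `for` loop) yields one batch, a tuple `(x_batch, y_batch)`. This assumes MLUtils' default `shuffle = false`, so the first batch holds the first 100 samples; the variable names `indexing_fails`, `nbatches`, `xb` and `yb` are illustrative.

```julia
using MLUtils

x = rand(10000)
y = sin.(x)
data = MLUtils.DataLoader((x, y), batchsize = 100)

# Indexing fails: DataLoader defines no getindex method (same error as above).
indexing_fails = try
    data[1]
    false
catch err
    err isa MethodError
end

# Iteration works: each element is one batch, a tuple (x_batch, y_batch).
nbatches = length(data)   # 10000 samples / batchsize 100 = 100 batches
xb, yb = first(data)      # first batch; no shuffling by default

println("indexing fails: ", indexing_fails)
println("batches: ", nbatches, ", batch size: ", length(xb))
```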
I wanted to modify the example for my own application, but because of this error I cannot really understand what is going on.
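A sketch of evaluating the loss outside `solve`, assuming (consistent with the example training without error, but not confirmed here) that during `solve` Optimization.jl iterates over the `DataLoader` and passes each batch tuple to `loss` as its second argument. Under that assumption `loss` only ever receives a tuple `(xs, ys)`, so outside the optimizer it can be called with `first(data)` or with `(x, y)` directly. The names `l_batch`, `l_full` and `l_sum` are hypothetical, and `ps_ca` stands in for `res.u`.

```julia
using Lux, MLUtils, Random, ComponentArrays

x = rand(10000)
y = sin.(x)
data = MLUtils.DataLoader((x, y), batchsize = 100)

model = Chain(Dense(1, 32, tanh), Dense(32, 1))
ps, st = Lux.setup(Random.default_rng(), model)
ps_ca = ComponentArray(ps)
smodel = StatefulLuxLayer{true}(model, nothing, st)

# Same loss as in the example: `data` is ONE batch, i.e. a tuple (xs, ys).
function loss(ps, data)
    ypred = [smodel([data[1][i]], ps)[1] for i in eachindex(data[1])]
    return sum(abs2, ypred .- data[2])
end

# Pass a batch (or the raw arrays), never the DataLoader itself.
l_batch = loss(ps_ca, first(data))   # loss on the first batch only
l_full = loss(ps_ca, (x, y))         # loss on the whole dataset

# The loss is a plain sum over samples and the batches partition the data,
# so the per-batch losses add up to the full-data loss (up to rounding).
l_sum = sum(loss(ps_ca, b) for b in data)
println((l_batch, l_full, l_sum))
```

By the same reasoning, `optf(res.u, first(data))` and `prob.f(res.u, first(data))` should also evaluate, since the stack trace shows `OptimizationFunction` forwarding its arguments to `loss`.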