diff --git a/test/diffeqfluxtests.jl b/test/diffeqfluxtests.jl index 2e5142991..6ec24e2cd 100644 --- a/test/diffeqfluxtests.jl +++ b/test/diffeqfluxtests.jl @@ -70,7 +70,7 @@ ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps)) dudt2 = Lux.Chain(x -> x .^ 3, Lux.Dense(2, 50, tanh), Lux.Dense(50, 2)) -prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps) +prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps, abstol = 1e-8, reltol = 1e-8) pp, st = Lux.setup(rng, dudt2) pp = ComponentArray(pp) @@ -99,13 +99,13 @@ prob = Optimization.OptimizationProblem(optprob, pp) result_neuralode = Optimization.solve(prob, OptimizationOptimisers.ADAM(), callback = callback, - maxiters = 300) + maxiters = 1000) @test result_neuralode.objective≈loss_neuralode(result_neuralode.u)[1] rtol=1e-2 prob2 = remake(prob, u0 = result_neuralode.u) result_neuralode2 = Optimization.solve(prob2, BFGS(initial_stepnorm = 0.0001), callback = callback, - maxiters = 100) + maxiters = 300, allow_f_increases = true) @test result_neuralode2.objective≈loss_neuralode(result_neuralode2.u)[1] rtol=1e-2 @test result_neuralode2.objective < 10