
Error in trying to use Optimization.jl for LSTM training based on Lux.jl #860

@chooron

Describe the bug 🐞

Hello, I am trying to use Optimization.jl to optimize a model built with Lux.jl. Here is my code.

Minimal Reproducible Example 👇

using Lux
using Zygote
using StableRNGs
using ComponentArrays
using Optimization
using OptimizationOptimisers

function LSTMCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T,2} where {T}
        x = reshape(x, size(x)..., 1)
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        output = [vec(classifier(y))]
        for x in x_rest
            y, carry = lstm_cell((x, carry))
            output = vcat(output, [vec(classifier(y))])
        end
        @return hcat(output...)
    end
end

model = LSTMCompact(3, 10, 1)
ps, st = Lux.setup(StableRNGs.LehmerRNG(1234), model)
ps_axes = getaxes(ComponentVector(ps))
# stateless forward pass; the layer states st are captured as constants
model_func = (x, ps) -> Lux.apply(model, x, ps, st)
x = rand(3, 10)
y = rand(1, 10)

# objective: rebuild the ComponentVector from the flat parameter vector u
# and return the sum of squared errors against y
function object(u, p)
    ps = ComponentVector(u, ps_axes)
    sum((model_func(x, ps)[1] .- y) .^ 2)
end

opt_func = Optimization.OptimizationFunction(object, Optimization.AutoZygote())
opt_prob = Optimization.OptimizationProblem(opt_func, Vector(ComponentVector(ps)))
opt_sol = Optimization.solve(opt_prob, OptimizationOptimisers.Adam(0.1), maxiters=1000)
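
As a debugging aid, the failure can presumably be reproduced without Optimization.jl by differentiating the objective with Zygote directly; the sketch below (the names u0 and grads are illustrative) should hit the same pullback and therefore the same MethodError:

# isolate the failure: differentiate the objective directly with Zygote,
# bypassing the Optimization.jl machinery entirely
u0 = Vector(ComponentVector(ps))
grads = Zygote.gradient(u -> object(u, nothing), u0)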

Error & Stacktrace ⚠️
The code works when AutoForwardDiff is used as the AD type, but with AutoZygote it fails with the following error:

ERROR: MethodError: no method matching (::ChainRulesCore.ProjectTo{AbstractArray, @NamedTuple{…}})(::NTuple{9, Vector{…}})

Closest candidates are:
  (::ChainRulesCore.ProjectTo{T})(::ChainRulesCore.NotImplemented) where T
   @ ChainRulesCore D:\Julia\Julia-1.10.4\packages\packages\ChainRulesCore\6Pucz\src\projection.jl:121
  (::ChainRulesCore.ProjectTo{T})(::ChainRulesCore.AbstractZero) where T
   @ ChainRulesCore D:\Julia\Julia-1.10.4\packages\packages\ChainRulesCore\6Pucz\src\projection.jl:120
  (::ChainRulesCore.ProjectTo{AbstractArray})(::ChainRulesCore.Tangent)
   @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\chainrules.jl:200
  ...

Stacktrace:
  [1] (::ChainRules.var"#480#485"{ChainRulesCore.ProjectTo{…}, Tuple{…}, ChainRulesCore.Tangent{…}})()
    @ ChainRules D:\Julia\Julia-1.10.4\packages\packages\ChainRules\hShjJ\src\rulesets\Base\array.jl:314
  [2] unthunk
    @ D:\Julia\Julia-1.10.4\packages\packages\ChainRulesCore\6Pucz\src\tangent_types\thunks.jl:205 [inlined]
  [3] unthunk(x::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{…}, ChainRules.var"#479#484"{…}})
    @ ChainRulesCore D:\Julia\Julia-1.10.4\packages\packages\ChainRulesCore\6Pucz\src\tangent_types\thunks.jl:238
  [4] wrap_chainrules_output
    @ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\chainrules.jl:110 [inlined]
  [5] map
    @ .\tuple.jl:293 [inlined]
  [6] wrap_chainrules_output
    @ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\chainrules.jl:111 [inlined]
  [7] (::Zygote.ZBack{ChainRules.var"#vcat_pullback#481"{Tuple{…}, Tuple{…}, Val{…}}})(dy::NTuple{10, Vector{Float64}})
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\chainrules.jl:211
  [8] #21
    @ D:\Julia\Julia-1.10.4\packages\packages\Lux\PsW4M\src\helpers\compact.jl:0 [inlined]
  [9] (::Zygote.Pullback{Tuple{…}, Any})(Ξ”::Tuple{Matrix{…}, Nothing})
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [10] CompactLuxLayer
    @ D:\Julia\Julia-1.10.4\packages\packages\Lux\PsW4M\src\helpers\compact.jl:366 [inlined]
 [11] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Ξ”::Tuple{Matrix{…}, Nothing})
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [12] apply
    @ D:\Julia\Julia-1.10.4\packages\packages\LuxCore\kYVM5\src\LuxCore.jl:171 [inlined]
 [13] #23
    @ e:\JlCode\HydroModels\temp\train_lstm_in_opt.jl:27 [inlined]
 [14] object
    @ e:\JlCode\HydroModels\temp\train_lstm_in_opt.jl:33 [inlined]
 [15] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Ξ”::Float64)
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [16] #291
    @ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\lib\lib.jl:206 [inlined]
 [17] #2169#back
    @ D:\Julia\Julia-1.10.4\packages\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
 [18] OptimizationFunction
    @ D:\Julia\Julia-1.10.4\packages\packages\SciMLBase\nftrI\src\scimlfunctions.jl:3812 [inlined]
 [19] #291
    @ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\lib\lib.jl:206 [inlined]
 [20] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{…}, Zygote.Pullback{…}}})(Ξ”::Float64)
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\ZygoteRules\M4xmc\src\adjoint.jl:72
 [21] #37
    @ D:\Julia\Julia-1.10.4\packages\packages\OptimizationBase\ni8lU\ext\OptimizationZygoteExt.jl:94 [inlined]
 [22] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Ξ”::Float64)
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [23] #291
    @ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\lib\lib.jl:206 [inlined]
 [24] #2169#back
    @ D:\Julia\Julia-1.10.4\packages\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
 [25] #39
    @ D:\Julia\Julia-1.10.4\packages\packages\OptimizationBase\ni8lU\ext\OptimizationZygoteExt.jl:97 [inlined]
 [26] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Ξ”::Float64)
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [27] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Ξ”::Float64)
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface.jl:91
 [28] gradient(f::Function, args::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}})
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface.jl:148
 [29] (::OptimizationZygoteExt.var"#38#56"{…})(::ComponentVector{…}, ::ComponentVector{…})
    @ OptimizationZygoteExt D:\Julia\Julia-1.10.4\packages\packages\OptimizationBase\ni8lU\ext\OptimizationZygoteExt.jl:97
 [30] macro expansion
    @ D:\Julia\Julia-1.10.4\packages\packages\OptimizationOptimisers\AOkbT\src\OptimizationOptimisers.jl:68 [inlined]
 [31] macro expansion
    @ D:\Julia\Julia-1.10.4\packages\packages\Optimization\fPKIF\src\utils.jl:32 [inlined]
 [32] __solve(cache::OptimizationCache{…})
    @ OptimizationOptimisers D:\Julia\Julia-1.10.4\packages\packages\OptimizationOptimisers\AOkbT\src\OptimizationOptimisers.jl:66
 [33] solve!(cache::OptimizationCache{…})
    @ SciMLBase D:\Julia\Julia-1.10.4\packages\packages\SciMLBase\nftrI\src\solve.jl:188
 [34] solve(::OptimizationProblem{…}, ::Adam; kwargs::@Kwargs{…})
    @ SciMLBase D:\Julia\Julia-1.10.4\packages\packages\SciMLBase\nftrI\src\solve.jl:96
 [35] top-level scope
    @ REPL[2]:1
Some type information was truncated. Use `show(err)` to see complete types.

This issue seems to occur only with recurrent neural networks such as LSTMs, not with regular fully connected networks. Is there a way to optimize Lux.jl's LSTMCell and other RNN models with Optimization.jl?
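
For reference, the stacktrace above points at the pullback of the vcat call that grows the Vector of per-step outputs inside the @compact block. A possible workaround, sketched below and untested, is to accumulate the per-step outputs in a tuple and concatenate once; the name LSTMCompactTuple is only illustrative, and whether this actually sidesteps the Zygote error has not been verified:

function LSTMCompactTuple(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T,2} where {T}
        x = reshape(x, size(x)..., 1)
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        # accumulate per-step outputs in a tuple instead of vcat on a Vector of Vectors
        outputs = (vec(classifier(y)),)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
            outputs = (outputs..., vec(classifier(y)))
        end
        @return hcat(outputs...)
    end
end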

Environment (please complete the following information):

  • Output of using Pkg; Pkg.status()
  [7d9f7c33] Accessors v0.1.38
⌃ [4c88cf16] Aqua v0.8.7
  [6e4b80f9] BenchmarkTools v1.5.0
⌃ [336ed68f] CSV v0.10.14
⌃ [052768ef] CUDA v5.4.3
  [d360d2e6] ChainRulesCore v1.25.0
⌃ [b0b7db55] ComponentArrays v0.15.16
⌃ [a93c6f00] DataFrames v1.6.1
⌃ [82cc6244] DataInterpolations v6.1.0
⌅ [459566f4] DiffEqCallbacks v3.7.0
  [ffbed154] DocStringExtensions v0.9.3
⌃ [f6369f11] ForwardDiff v0.10.36
⌃ [86223c79] Graphs v1.11.2
  [cde335eb] HydroErrors v0.1.0 `D:\Julia\Julia-1.10.4\packages\dev\HydroErrors`
  [de52edbc] Integrals v4.5.0
  [a98d9a8b] Interpolations v0.15.1
  [c8e1da08] IterTools v1.10.0
⌃ [7ed4a6bd] LinearSolve v2.34.0
⌃ [b2108857] Lux v0.5.65
⌅ [bb33d45b] LuxCore v0.1.24
⌃ [961ee093] ModelingToolkit v9.32.0
⌃ [872c559c] NNlib v0.9.22
  [d9ec5142] NamedTupleTools v0.14.3
⌅ [7f7a1694] Optimization v3.27.0
⌃ [3e6eede4] OptimizationBBO v0.3.0
⌃ [42dfb2eb] OptimizationOptimisers v0.2.1
⌃ [1dea7af3] OrdinaryDiffEq v6.87.0
⌃ [d7d3b36b] ParameterSchedulers v0.4.2
⌃ [91a5bcdd] Plots v1.40.5
  [92933f4c] ProgressMeter v1.10.2
⌃ [731186ca] RecursiveArrayTools v3.27.0
  [189a3867] Reexport v1.2.2
  [7e49a35a] RuntimeGeneratedFunctions v0.5.13
⌃ [0bca4576] SciMLBase v2.50.0
⌃ [c0aeaf25] SciMLOperators v0.3.11
⌃ [1ed8b502] SciMLSensitivity v7.64.0
  [860ef19b] StableRNGs v1.0.2
⌃ [90137ffa] StaticArrays v1.9.7
⌅ [d1185830] SymbolicUtils v2.1.2
⌅ [0c5d862f] Symbolics v5.36.0
⌃ [e88e6eb3] Zygote v0.6.70
  [ade2ca70] Dates
  [37e2e46d] LinearAlgebra
  [9a3f8284] Random
  [2f01184e] SparseArrays v1.10.0
  [10745b16] Statistics v1.10.0
  [fa267f1f] TOML v1.0.3
  • Output of versioninfo()
Julia Version 1.10.4
Commit 48d4fd4843 (2024-06-04 10:41 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: 24 × 12th Gen Intel(R) Core(TM) i9-12900HX
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, alderlake)
Threads: 1 default, 0 interactive, 1 GC (on 24 virtual cores)
Environment:
  JULIA_DEPOT_PATH = D:\Julia\Julia-1.10.4\packages
  JULIA_PKG_SERVER = https://mirrors.pku.edu.cn/julia/
  JULIA_EDITOR = code
  JULIA_NUM_THREADS =

