diff --git a/examples/usage.jl b/examples/usage.jl index c6838e0f..655d1dd1 100644 --- a/examples/usage.jl +++ b/examples/usage.jl @@ -57,7 +57,7 @@ using DataFrames, MLJBase, Zygote, ADTypes, OptimizationOptimisers df = DataFrame(transpose(r), :auto) model = ICNFModel( icnf; - optimizers = (Lion(),), + optimizers = (OptimiserChain(SignDecay(), WeightDecay(), Adam()),), n_epochs = 300, adtype = AutoZygote(), batch_size = 512, diff --git a/src/exts/mlj_ext/core_cond_icnf.jl b/src/exts/mlj_ext/core_cond_icnf.jl index 223c2a03..e400adfe 100644 --- a/src/exts/mlj_ext/core_cond_icnf.jl +++ b/src/exts/mlj_ext/core_cond_icnf.jl @@ -13,7 +13,13 @@ end function CondICNFModel( m::AbstractICNF, loss::Function = loss; - optimizers::Tuple = (Optimisers.Lion(),), + optimizers::Tuple = ( + Optimisers.OptimiserChain( + Optimisers.SignDecay(), + Optimisers.WeightDecay(), + Optimisers.Adam(), + ), + ), n_epochs::Int = 300, adtype::ADTypes.AbstractADType = ADTypes.AutoZygote(), batch_size::Int = 32, diff --git a/src/exts/mlj_ext/core_icnf.jl b/src/exts/mlj_ext/core_icnf.jl index 7b31b0b1..91ff7c2c 100644 --- a/src/exts/mlj_ext/core_icnf.jl +++ b/src/exts/mlj_ext/core_icnf.jl @@ -13,7 +13,13 @@ end function ICNFModel( m::AbstractICNF, loss::Function = loss; - optimizers::Tuple = (Optimisers.Lion(),), + optimizers::Tuple = ( + Optimisers.OptimiserChain( + Optimisers.SignDecay(), + Optimisers.WeightDecay(), + Optimisers.Adam(), + ), + ), n_epochs::Int = 300, adtype::ADTypes.AbstractADType = ADTypes.AutoZygote(), batch_size::Int = 32,