diff --git a/docs/src/tutorials/basic_mnist_deq.md b/docs/src/tutorials/basic_mnist_deq.md
index 4027c42d..df2d4984 100644
--- a/docs/src/tutorials/basic_mnist_deq.md
+++ b/docs/src/tutorials/basic_mnist_deq.md
@@ -48,8 +48,8 @@ Construct the Lux Neural Network containing a DEQ layer.
 
 ```@example basic_mnist_deq
 function construct_model(solver; model_type::Symbol=:deq)
-    down = Chain(Conv((3, 3), 1 => 64, gelu; stride=1), GroupNorm(64, 64),
-        Conv((4, 4), 64 => 64; stride=2, pad=1))
+    down = Chain(
+        Conv((3, 3), 1 => 64, gelu; stride=1), GroupNorm(64, 64), Conv((4, 4), 64 => 64; stride=2, pad=1))
 
     # The input layer of the DEQ
     deq_model = Chain(
@@ -72,8 +72,7 @@ function construct_model(solver; model_type::Symbol=:deq)
     deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose=false,
         linsolve_kwargs=(; maxiters=10), maxiters=10)
 
-    classifier = Chain(
-        GroupNorm(64, 64, relu), GlobalMeanPool(), FlattenLayer(), Dense(64, 10))
+    classifier = Chain(GroupNorm(64, 64, relu), GlobalMeanPool(), FlattenLayer(), Dense(64, 10))
 
     model = Chain(; down, deq, classifier)
 
@@ -133,8 +132,9 @@ function train_model(solver, model_type)
     @set! tstate.states = Lux.update_state(tstate.states, :fixed_depth, Val(5))
 
     for _ in 1:2, (i, (x, y)) in enumerate(train_dataloader)
-        _, loss, _, tstate = Training.single_train_step!(
-            AutoZygote(), loss_function, (x, y), tstate)
+
+        _, loss,
+        _, tstate = Training.single_train_step!(AutoZygote(), loss_function, (x, y), tstate)
         if i % 10 == 1
             @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(train_dataloader) loss
         end
@@ -147,8 +147,9 @@ function train_model(solver, model_type)
 
     for epoch in 1:3
         for (i, (x, y)) in enumerate(train_dataloader)
-            _, loss, _, tstate = Training.single_train_step!(
-                AutoZygote(), loss_function, (x, y), tstate)
+            _, loss,
+            _,
+            tstate = Training.single_train_step!(AutoZygote(), loss_function, (x, y), tstate)
             if i % 10 == 1
                 @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(train_dataloader) loss
             end
diff --git a/docs/src/tutorials/reduced_dim_deq.md b/docs/src/tutorials/reduced_dim_deq.md
index df344e08..6ba21d66 100644
--- a/docs/src/tutorials/reduced_dim_deq.md
+++ b/docs/src/tutorials/reduced_dim_deq.md
@@ -46,23 +46,22 @@ function construct_model(solver; model_type::Symbol=:regdeq)
     # The input layer of the DEQ
     deq_model = Chain(
         Parallel(+,
-            Dense(
-                128 => 64, tanh; use_bias=false, init_weight=truncated_normal(; std=0.01)), # Reduced dim of `128`
-            Dense(
-                512 => 64, tanh; use_bias=false, init_weight=truncated_normal(; std=0.01))), # Original dim of `512`
+            Dense(128 => 64, tanh; use_bias=false, init_weight=truncated_normal(;
+                std=0.01)), # Reduced dim of `128`
+            Dense(512 => 64, tanh; use_bias=false, init_weight=truncated_normal(;
+                std=0.01))), # Original dim of `512`
         Dense(64 => 64, tanh; use_bias=false, init_weight=truncated_normal(; std=0.01)),
         Dense(64 => 128; use_bias=false, init_weight=truncated_normal(; std=0.01))) # Return the reduced dim of `128`
 
     if model_type === :skipdeq
-        init = Dense(
-            512 => 128, tanh; use_bias=false, init_weight=truncated_normal(; std=0.01))
+        init = Dense(512 => 128, tanh; use_bias=false, init_weight=truncated_normal(;
+            std=0.01))
     elseif model_type === :regdeq
        error(":regdeq is not supported for reduced dim models")
     else
         # This should preferably done via `ChainRulesCore.@ignore_derivatives`. But here
         # we are only using Zygote so this is fine.
-        init = WrappedFunction(x -> Zygote.@ignore(fill!(
-            similar(x, 128, size(x, 2)), false)))
+        init = WrappedFunction(x -> Zygote.@ignore(fill!(similar(x, 128, size(x, 2)), false)))
     end
 
     deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose=false,
@@ -128,8 +127,9 @@ function train_model(solver, model_type)
     @set! tstate.states = Lux.update_state(tstate.states, :fixed_depth, Val(5))
 
     for _ in 1:2, (i, (x, y)) in enumerate(train_dataloader)
-        _, loss, _, tstate = Training.single_train_step!(
-            AutoZygote(), loss_function, (x, y), tstate)
+
+        _, loss,
+        _, tstate = Training.single_train_step!(AutoZygote(), loss_function, (x, y), tstate)
         if i % 10 == 1
             @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(train_dataloader) loss
         end
@@ -142,8 +142,9 @@ function train_model(solver, model_type)
 
     for epoch in 1:3
         for (i, (x, y)) in enumerate(train_dataloader)
-            _, loss, _, tstate = Training.single_train_step!(
-                AutoZygote(), loss_function, (x, y), tstate)
+            _, loss,
+            _,
+            tstate = Training.single_train_step!(AutoZygote(), loss_function, (x, y), tstate)
             if i % 10 == 1
                 @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(train_dataloader) loss
             end
diff --git a/src/layers.jl b/src/layers.jl
index 4d12f807..5dbb13bb 100644
--- a/src/layers.jl
+++ b/src/layers.jl
@@ -39,12 +39,12 @@ end
 
 function Base.show(io::IO, sol::DeepEquilibriumSolution)
     println(io, "DeepEquilibriumSolution")
-    println(io, " * Initial Guess: ",
-        sprint(print, sol.u0; context=(:compact => true, :limit => true)))
-    println(io, " * Steady State: ",
-        sprint(print, sol.z_star; context=(:compact => true, :limit => true)))
-    println(io, " * Residual: ",
-        sprint(print, sol.residual; context=(:compact => true, :limit => true)))
+    println(io, " * Initial Guess: ", sprint(print, sol.u0; context=(
+        :compact => true, :limit => true)))
+    println(io, " * Steady State: ", sprint(print, sol.z_star; context=(
+        :compact => true, :limit => true)))
+    println(io, " * Residual: ", sprint(print, sol.residual; context=(
+        :compact => true, :limit => true)))
     println(io, " * Jacobian Loss: ",
         sprint(print, sol.jacobian_loss; context=(:compact => true, :limit => true)))
     print(io, " * NFE: ", sol.nfe)
@@ -171,8 +171,7 @@ function DeepEquilibriumNetwork(
         model, solver; init=missing, jacobian_regularization=nothing,
         problem_type::Type=SteadyStateProblem{false}, kwargs...)
     if init === missing # Regular DEQ
-        init = WrappedFunction(Base.Fix1(
-            zeros_init, LuxOps.getproperty(model, Val(:scales))))
+        init = WrappedFunction(Base.Fix1(zeros_init, LuxOps.getproperty(model, Val(:scales))))
     elseif init === nothing # SkipRegDEQ
         init = NoOpLayer()
     elseif !(init isa AbstractLuxLayer)
@@ -254,8 +253,7 @@ function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Ma
     if post_fuse_layer === nothing
         model = MultiScaleInputLayer(Chain(l1, l2), split_idxs, scales)
     else
-        model = MultiScaleInputLayer(
-            Chain(l1, l2, Parallel(nothing, post_fuse_layer...)), split_idxs, scales)
+        model = MultiScaleInputLayer(Chain(l1, l2, Parallel(nothing, post_fuse_layer...)), split_idxs, scales)
     end
 
     return DeepEquilibriumNetwork(model, solver; kwargs...)
@@ -291,8 +289,7 @@ Same arguments as [`MultiScaleDeepEquilibriumNetwork`](@ref) but sets `problem_t
 `ODEProblem{false}`.
 """
 function MultiScaleNeuralODE(args...; kwargs...)
-    return MultiScaleDeepEquilibriumNetwork(
-        args...; kwargs..., problem_type=ODEProblem{false})
+    return MultiScaleDeepEquilibriumNetwork(args...; kwargs..., problem_type=ODEProblem{false})
 end
 
 ## Generate Initial Condition
diff --git a/src/utils.jl b/src/utils.jl
index e4277a3e..fed5b55d 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,5 +1,5 @@
-@generated function split_and_reshape(
-        x::AbstractMatrix, ::Val{idxs}, ::Val{shapes}) where {idxs, shapes}
+@generated function split_and_reshape(x::AbstractMatrix, ::Val{idxs}, ::Val{shapes}) where {
+        idxs, shapes}
     dims = [reshape((idxs[i] + 1):idxs[i + 1], shapes[i]...) for i in 1:(length(idxs) - 1)]
     varnames = map(_ -> gensym("x_view"), dims)
     calls = [:($(varnames[i]) = x[$(dims[i]), :]) for i in eachindex(dims)]
@@ -15,8 +15,7 @@ function split_and_reshape(y::AbstractMatrix, x)
     szs = [prod(size(xᵢ)[1:(end - 1)]) for xᵢ in x]
     counters = vcat(0, cumsum(szs)[1:(end - 1)])
     # Make the data contiguous
-    return map((sz, c, xᵢ) -> copy(reshape(view(y, (c + 1):(c + sz), :), size(xᵢ))),
-        szs, counters, x)
+    return map((sz, c, xᵢ) -> copy(reshape(view(y, (c + 1):(c + sz), :), size(xᵢ))), szs, counters, x)
 end
 
 flatten(x::AbstractVector) = reshape(x, length(x), 1)
diff --git a/test/layers_tests.jl b/test/layers_tests.jl
index 127a84f0..6f66b640 100644
--- a/test/layers_tests.jl
+++ b/test/layers_tests.jl
@@ -34,17 +34,16 @@ end
         jacobian_regularizations = ongpu ? _jacobian_regularizations[1:(end - 1)] :
                                            _jacobian_regularizations
 
-        @testset "Solver: $(nameof(typeof(solver))) | Model Type: $(mtype) | Jac. Reg: $(jacobian_regularization)" for solver in SOLVERS,
-            mtype in model_type,
-            jacobian_regularization in jacobian_regularizations
+        @testset "Solver: $(nameof(typeof(solver))) | Model Type: $(mtype) | Jac. Reg: $(jacobian_regularization)" for solver in
+                                                                                                                       SOLVERS,
+            mtype in model_type, jacobian_regularization in jacobian_regularizations
 
-            @testset "x_size: $(x_size)" for (base_model, init_model, x_size) in zip(
-                base_models, init_models, x_sizes)
+            @testset "x_size: $(x_size)" for (base_model, init_model, x_size) in
+                                             zip(base_models, init_models, x_sizes)
                 model = if mtype === :deq
                     DeepEquilibriumNetwork(base_model, solver; jacobian_regularization)
                 elseif mtype === :skipdeq
-                    SkipDeepEquilibriumNetwork(
-                        base_model, init_model, solver; jacobian_regularization)
+                    SkipDeepEquilibriumNetwork(base_model, init_model, solver; jacobian_regularization)
                 elseif mtype === :skipregdeq
                     SkipDeepEquilibriumNetwork(base_model, solver; jacobian_regularization)
                 end
@@ -112,10 +111,10 @@
 
     @testset "$mode" for (mode, aType, dev, ongpu) in MODES
         @testset "Solver: $(nameof(typeof(solver)))" for solver in SOLVERS,
-            mtype in model_type,
-            jacobian_regularization in jacobian_regularizations
+            mtype in model_type, jacobian_regularization in jacobian_regularizations
 
-            @testset "x_size: $(x_size)" for (main_layer, mapping_layer, init_layer, x_size, scale) in zip(
+            @testset "x_size: $(x_size)" for (
+                main_layer, mapping_layer, init_layer, x_size, scale) in zip(
                 main_layers, mapping_layers, init_layers, x_sizes, scales)
                 model = if mtype === :deq
                     MultiScaleDeepEquilibriumNetwork(main_layer, mapping_layer, nothing,