17 changes: 9 additions & 8 deletions docs/src/tutorials/basic_mnist_deq.md
@@ -48,8 +48,8 @@ Construct the Lux Neural Network containing a DEQ layer.

 ```@example basic_mnist_deq
 function construct_model(solver; model_type::Symbol=:deq)
-    down = Chain(Conv((3, 3), 1 => 64, gelu; stride=1), GroupNorm(64, 64),
-        Conv((4, 4), 64 => 64; stride=2, pad=1))
+    down = Chain(
+        Conv((3, 3), 1 => 64, gelu; stride=1), GroupNorm(64, 64), Conv((4, 4), 64 => 64; stride=2, pad=1))
 
     # The input layer of the DEQ
     deq_model = Chain(
@@ -72,8 +72,7 @@ function construct_model(solver; model_type::Symbol=:deq)
     deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose=false,
         linsolve_kwargs=(; maxiters=10), maxiters=10)
 
-    classifier = Chain(
-        GroupNorm(64, 64, relu), GlobalMeanPool(), FlattenLayer(), Dense(64, 10))
+    classifier = Chain(GroupNorm(64, 64, relu), GlobalMeanPool(), FlattenLayer(), Dense(64, 10))
 
     model = Chain(; down, deq, classifier)
 
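The hunk above also shows `model = Chain(; down, deq, classifier)`, Lux's keyword form of `Chain` that keeps each sub-network addressable by name. A minimal sketch of that form, with stand-in `Dense` layers rather than the tutorial's real sub-networks:

```julia
using Lux, Random

down = Dense(4 => 8, relu)         # stand-ins for the tutorial's sub-networks
deq = Dense(8 => 8, tanh)
classifier = Dense(8 => 2)

model = Chain(; down, deq, classifier)   # named layers: model.layers.down, ...
ps, st = Lux.setup(Random.default_rng(), model)
y, _ = model(rand(Float32, 4, 3), ps, st)   # 2×3 output
```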
@@ -133,8 +132,9 @@ function train_model(solver, model_type)
     @set! tstate.states = Lux.update_state(tstate.states, :fixed_depth, Val(5))
 
     for _ in 1:2, (i, (x, y)) in enumerate(train_dataloader)
-        _, loss, _, tstate = Training.single_train_step!(
-            AutoZygote(), loss_function, (x, y), tstate)
+
+        _, loss,
+        _, tstate = Training.single_train_step!(AutoZygote(), loss_function, (x, y), tstate)
         if i % 10 == 1
             @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(train_dataloader) loss
         end
@@ -147,8 +147,9 @@ function train_model(solver, model_type)

     for epoch in 1:3
         for (i, (x, y)) in enumerate(train_dataloader)
-            _, loss, _, tstate = Training.single_train_step!(
-                AutoZygote(), loss_function, (x, y), tstate)
+            _, loss,
+            _,
+            tstate = Training.single_train_step!(AutoZygote(), loss_function, (x, y), tstate)
             if i % 10 == 1
                 @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(train_dataloader) loss
             end
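Both training hunks in this file re-wrap the same call, `Training.single_train_step!`, which returns a 4-tuple `(gradients, loss, stats, train_state)`. A self-contained sketch of the pattern, assuming a recent Lux that exports `Training`, `AutoZygote`, and `MSELoss` (the toy model and data are illustrative, not from the tutorial):

```julia
using Lux, Optimisers, Random, Zygote

rng = Random.default_rng()
model = Dense(2 => 1)                      # stand-in for the tutorial's DEQ model
ps, st = Lux.setup(rng, model)
tstate = Training.TrainState(model, ps, st, Adam(0.01f0))

x, y = rand(rng, Float32, 2, 8), rand(rng, Float32, 1, 8)

# Keep only the loss and the updated state; gradients and stats are discarded,
# matching the destructuring re-wrapped in the hunks above.
_, loss, _, tstate = Training.single_train_step!(
    AutoZygote(), MSELoss(), (x, y), tstate)
```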
25 changes: 13 additions & 12 deletions docs/src/tutorials/reduced_dim_deq.md
@@ -46,23 +46,22 @@ function construct_model(solver; model_type::Symbol=:regdeq)
     # The input layer of the DEQ
     deq_model = Chain(
         Parallel(+,
-            Dense(
-                128 => 64, tanh; use_bias=false, init_weight=truncated_normal(; std=0.01)), # Reduced dim of `128`
-            Dense(
-                512 => 64, tanh; use_bias=false, init_weight=truncated_normal(; std=0.01))), # Original dim of `512`
+            Dense(128 => 64, tanh; use_bias=false, init_weight=truncated_normal(;
+                std=0.01)), # Reduced dim of `128`
+            Dense(512 => 64, tanh; use_bias=false, init_weight=truncated_normal(;
+                std=0.01))), # Original dim of `512`
         Dense(64 => 64, tanh; use_bias=false, init_weight=truncated_normal(; std=0.01)),
         Dense(64 => 128; use_bias=false, init_weight=truncated_normal(; std=0.01))) # Return the reduced dim of `128`
 
     if model_type === :skipdeq
-        init = Dense(
-            512 => 128, tanh; use_bias=false, init_weight=truncated_normal(; std=0.01))
+        init = Dense(512 => 128, tanh; use_bias=false, init_weight=truncated_normal(;
+            std=0.01))
     elseif model_type === :regdeq
         error(":regdeq is not supported for reduced dim models")
     else
         # This should preferably be done via `ChainRulesCore.@ignore_derivatives`. But here
         # we are only using Zygote so this is fine.
-        init = WrappedFunction(x -> Zygote.@ignore(fill!(
-            similar(x, 128, size(x, 2)), false)))
+        init = WrappedFunction(x -> Zygote.@ignore(fill!(similar(x, 128, size(x, 2)), false)))
     end
 
     deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose=false,
@@ -128,8 +127,9 @@ function train_model(solver, model_type)
     @set! tstate.states = Lux.update_state(tstate.states, :fixed_depth, Val(5))
 
     for _ in 1:2, (i, (x, y)) in enumerate(train_dataloader)
-        _, loss, _, tstate = Training.single_train_step!(
-            AutoZygote(), loss_function, (x, y), tstate)
+
+        _, loss,
+        _, tstate = Training.single_train_step!(AutoZygote(), loss_function, (x, y), tstate)
         if i % 10 == 1
             @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(train_dataloader) loss
         end
@@ -142,8 +142,9 @@ function train_model(solver, model_type)

     for epoch in 1:3
         for (i, (x, y)) in enumerate(train_dataloader)
-            _, loss, _, tstate = Training.single_train_step!(
-                AutoZygote(), loss_function, (x, y), tstate)
+            _, loss,
+            _,
+            tstate = Training.single_train_step!(AutoZygote(), loss_function, (x, y), tstate)
             if i % 10 == 1
                 @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(train_dataloader) loss
             end
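The zero initial-guess hunk in this file compresses the `WrappedFunction` line onto one line. As a standalone sketch of the idiom, assuming only Zygote (128 and 512 are the tutorial's reduced and original dimensions):

```julia
using Zygote

# Allocate a 128×batch matrix of zeros with x's element type, hidden from
# Zygote so the allocation itself is not differentiated through.
zero_init(x) = Zygote.@ignore fill!(similar(x, 128, size(x, 2)), false)

x = rand(Float32, 512, 4)   # original dim 512, batch of 4
u0 = zero_init(x)           # 128×4 zeros, same eltype (and device) as x
```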
21 changes: 9 additions & 12 deletions src/layers.jl
@@ -39,12 +39,12 @@ end

 function Base.show(io::IO, sol::DeepEquilibriumSolution)
     println(io, "DeepEquilibriumSolution")
-    println(io, " * Initial Guess: ",
-        sprint(print, sol.u0; context=(:compact => true, :limit => true)))
-    println(io, " * Steady State: ",
-        sprint(print, sol.z_star; context=(:compact => true, :limit => true)))
-    println(io, " * Residual: ",
-        sprint(print, sol.residual; context=(:compact => true, :limit => true)))
+    println(io, " * Initial Guess: ", sprint(print, sol.u0; context=(
+        :compact => true, :limit => true)))
+    println(io, " * Steady State: ", sprint(print, sol.z_star; context=(
+        :compact => true, :limit => true)))
+    println(io, " * Residual: ", sprint(print, sol.residual; context=(
+        :compact => true, :limit => true)))
     println(io, " * Jacobian Loss: ",
         sprint(print, sol.jacobian_loss; context=(:compact => true, :limit => true)))
     print(io, " * NFE: ", sol.nfe)
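The hunk above only re-wraps repeated `sprint(print, ...; context=...)` calls. This is stock Julia: `sprint` forwards the `context` pairs to an `IOContext`, so `:limit` truncates long collections and `:compact` shortens numbers. For example:

```julia
v = collect(1.0:10_000.0)
s = sprint(print, v; context=(:compact => true, :limit => true))
println(s)   # a truncated rendering like "[1.0, 2.0  …  9999.0, 10000.0]"
```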
@@ -171,8 +171,7 @@ function DeepEquilibriumNetwork(
         model, solver; init=missing, jacobian_regularization=nothing,
         problem_type::Type=SteadyStateProblem{false}, kwargs...)
     if init === missing # Regular DEQ
-        init = WrappedFunction(Base.Fix1(
-            zeros_init, LuxOps.getproperty(model, Val(:scales))))
+        init = WrappedFunction(Base.Fix1(zeros_init, LuxOps.getproperty(model, Val(:scales))))
     elseif init === nothing # SkipRegDEQ
         init = NoOpLayer()
     elseif !(init isa AbstractLuxLayer)
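`Base.Fix1` in the hunk above is also stock Julia: it partially applies the first argument of a callable, here pre-binding the model's scales to `zeros_init`. A tiny illustration:

```julia
# Fix1(f, a) behaves like x -> f(a, x)
scaled = Base.Fix1(*, 10)
scaled(4)   # 40
```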
@@ -254,8 +253,7 @@ function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Ma
     if post_fuse_layer === nothing
         model = MultiScaleInputLayer(Chain(l1, l2), split_idxs, scales)
     else
-        model = MultiScaleInputLayer(
-            Chain(l1, l2, Parallel(nothing, post_fuse_layer...)), split_idxs, scales)
+        model = MultiScaleInputLayer(Chain(l1, l2, Parallel(nothing, post_fuse_layer...)), split_idxs, scales)
     end
 
     return DeepEquilibriumNetwork(model, solver; kwargs...)
@@ -291,8 +289,7 @@ Same arguments as [`MultiScaleDeepEquilibriumNetwork`](@ref) but sets `problem_t
 `ODEProblem{false}`.
 """
 function MultiScaleNeuralODE(args...; kwargs...)
-    return MultiScaleDeepEquilibriumNetwork(
-        args...; kwargs..., problem_type=ODEProblem{false})
+    return MultiScaleDeepEquilibriumNetwork(args...; kwargs..., problem_type=ODEProblem{false})
 end
 
 ## Generate Initial Condition
7 changes: 3 additions & 4 deletions src/utils.jl
@@ -1,5 +1,5 @@
-@generated function split_and_reshape(
-        x::AbstractMatrix, ::Val{idxs}, ::Val{shapes}) where {idxs, shapes}
+@generated function split_and_reshape(x::AbstractMatrix, ::Val{idxs}, ::Val{shapes}) where {
+        idxs, shapes}
     dims = [reshape((idxs[i] + 1):idxs[i + 1], shapes[i]...) for i in 1:(length(idxs) - 1)]
     varnames = map(_ -> gensym("x_view"), dims)
     calls = [:($(varnames[i]) = x[$(dims[i]), :]) for i in eachindex(dims)]
@@ -15,8 +15,7 @@ function split_and_reshape(y::AbstractMatrix, x)
     szs = [prod(size(xᵢ)[1:(end - 1)]) for xᵢ in x]
     counters = vcat(0, cumsum(szs)[1:(end - 1)])
     # Make the data contiguous
-    return map((sz, c, xᵢ) -> copy(reshape(view(y, (c + 1):(c + sz), :), size(xᵢ))),
-        szs, counters, x)
+    return map((sz, c, xᵢ) -> copy(reshape(view(y, (c + 1):(c + sz), :), size(xᵢ))), szs, counters, x)
 end
 
 flatten(x::AbstractVector) = reshape(x, length(x), 1)
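For context, `split_and_reshape` slices a stacked feature matrix back into one array per scale; the `@generated` method builds one indexing expression per scale from `idxs` and `shapes`. Roughly, for hypothetical `idxs = (0, 4, 8)` and `shapes = ((4,), (2, 2))`, the generated body expands to something like:

```julia
x = Float32.(reshape(1:8, 8, 1))   # 8 stacked features, batch size 1

# First scale: rows 1:4, per-sample shape (4,)   -> 4×1 slice
x_view_1 = x[reshape(1:4, 4), :]
# Second scale: rows 5:8, per-sample shape (2, 2) -> 2×2×1 slice
x_view_2 = x[reshape(5:8, 2, 2), :]
```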
19 changes: 9 additions & 10 deletions test/layers_tests.jl
@@ -34,17 +34,16 @@ end
     jacobian_regularizations = ongpu ? _jacobian_regularizations[1:(end - 1)] :
                                _jacobian_regularizations
 
-    @testset "Solver: $(nameof(typeof(solver))) | Model Type: $(mtype) | Jac. Reg: $(jacobian_regularization)" for solver in SOLVERS,
-        mtype in model_type,
-        jacobian_regularization in jacobian_regularizations
+    @testset "Solver: $(nameof(typeof(solver))) | Model Type: $(mtype) | Jac. Reg: $(jacobian_regularization)" for solver in
+                                                                                                                   SOLVERS,
+        mtype in model_type, jacobian_regularization in jacobian_regularizations
 
-        @testset "x_size: $(x_size)" for (base_model, init_model, x_size) in zip(
-            base_models, init_models, x_sizes)
+        @testset "x_size: $(x_size)" for (base_model, init_model, x_size) in
+                                         zip(base_models, init_models, x_sizes)
             model = if mtype === :deq
                 DeepEquilibriumNetwork(base_model, solver; jacobian_regularization)
             elseif mtype === :skipdeq
-                SkipDeepEquilibriumNetwork(
-                    base_model, init_model, solver; jacobian_regularization)
+                SkipDeepEquilibriumNetwork(base_model, init_model, solver; jacobian_regularization)
             elseif mtype === :skipregdeq
                 SkipDeepEquilibriumNetwork(base_model, solver; jacobian_regularization)
             end
@@ -112,10 +111,10 @@ end

@testset "$mode" for (mode, aType, dev, ongpu) in MODES
@testset "Solver: $(nameof(typeof(solver)))" for solver in SOLVERS,
mtype in model_type,
jacobian_regularization in jacobian_regularizations
mtype in model_type, jacobian_regularization in jacobian_regularizations

@testset "x_size: $(x_size)" for (main_layer, mapping_layer, init_layer, x_size, scale) in zip(
@testset "x_size: $(x_size)" for (
main_layer, mapping_layer, init_layer, x_size, scale) in zip(
main_layers, mapping_layers, init_layers, x_sizes, scales)
model = if mtype === :deq
MultiScaleDeepEquilibriumNetwork(main_layer, mapping_layer, nothing,
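All of the test hunks re-wrap Julia's `@testset ... for` syntax, which creates one nested test set per combination of iterates. A minimal sketch of the construct, with placeholder solver and model-type symbols rather than the suite's real fixtures:

```julia
using Test

SOLVERS = (:newton, :broyden)    # placeholders, not the suite's solver objects
model_type = (:deq, :skipdeq)

@testset "Solver: $(s) | Model Type: $(m)" for s in SOLVERS, m in model_type
    @test true   # one test set per (solver, model type) pair
end
```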