diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index b85bfce8..747b614b 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -29,8 +29,11 @@ jobs: ${{ runner.os }}-test- ${{ runner.os }}- - name: Install dependencies - run: julia --project=. -e 'using Pkg; Pkg.add(url="https://github.com/avik-pal/FluxExperimental.jl"); Pkg.add(url="https://github.com/SciML/DiffEqSensitivity.jl", rev="ap/fastdeq"); Pkg.instantiate()' + # FIXME: Remove once Lux.jl is registered + run: julia --project=. -e 'using Pkg; Pkg.add(url="https://github.com/avik-pal/Lux.jl"); Pkg.instantiate()' - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v2 with: - coverage: false \ No newline at end of file + files: lcov.info \ No newline at end of file diff --git a/.github/workflows/Documentation.yml b/.github/workflows/Documentation.yml index fe19c974..456ea6a9 100644 --- a/.github/workflows/Documentation.yml +++ b/.github/workflows/Documentation.yml @@ -4,7 +4,7 @@ on: push: branches: - main - tags: '*' + tags: "*" pull_request: jobs: @@ -14,9 +14,10 @@ jobs: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 with: - version: '1' + version: "1" - name: Install dependencies - run: julia --project=docs -e 'using Pkg; Pkg.add(url="https://github.com/avik-pal/FluxExperimental.jl"); Pkg.add(url="https://github.com/SciML/DiffEqSensitivity.jl", rev="ap/fastdeq"); Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()' + # FIXME: Remove once Lux.jl is registered + run: julia --project=docs -e 'using Pkg; Pkg.add(url="https://github.com/avik-pal/Lux.jl"); Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()' - name: Build and deploy env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token diff --git a/.gitignore b/.gitignore index 3ebae6b8..b346baea 100644 --- a/.gitignore +++ b/.gitignore @@ -2,5 +2,9 @@ wandb/ .vscode data/ -/Manifest.toml -build \ No newline at end of file +Manifest.toml +build +statprof +profs +logs +benchmarking \ No newline at end of file diff --git a/Project.toml b/Project.toml index f133446f..d1f70bb1 100644 --- a/Project.toml +++ b/Project.toml @@ -5,19 +5,24 @@ version = "0.1.0" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DataLoaders = "2e981812-ef13-4a9c-bfa0-ab13047b12a9" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" DiffEqSensitivity = "41bf760c-e81c-5289-8e54-58b1f1f8abe2" -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -FluxExperimental = "c0d22e4d-7f3e-44a4-9c97-37045f84daf2" -FluxMPI = "acf642fa-ee0e-513a-b9d5-bcd150f7ec3b" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" @@ -25,14 +30,17 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] CUDA = "3" +ChainRulesCore = "1" DiffEqBase = "6" DiffEqCallbacks = "2.20.1" DiffEqSensitivity = "6.64" -Flux = "0.12" -FluxMPI = "0.1.1" +Functors = "0.2" LinearSolve = "1" +Lux = "0.4" +MLUtils = "0.2" OrdinaryDiffEq = "6" SciMLBase = "1.19" +Setfield = "0.8, 1" SteadyStateDiffEq = "1.6" UnPack = "1" Zygote = "0.6.34" @@ -40,11 +48,10 @@ julia = "1.7" [extras] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -FluxExperimental = "c0d22e4d-7f3e-44a4-9c97-37045f84daf2" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["CUDA", "Flux", "FluxExperimental", "LinearAlgebra", "Random", "Test"] \ No newline at end of file +test = ["CUDA", "LinearAlgebra", "Lux", "Random", "Test"] diff --git a/README.md b/README.md index 0e239780..11402f15 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,9 @@ # FastDEQ -![Dynamics Overview](assets/dynamics_overview.gif) [![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://fastdeq.sciml.ai/dev/) [![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://fastdeq.sciml.ai/stable/) +[![codecov](https://codecov.io/gh/SciML/FastDEQ.jl/branch/main/graph/badge.svg?token=plksEh6pUG)](https://codecov.io/gh/SciML/FastDEQ.jl) +[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) + +Deep Equilibrium Networks using [Lux.jl](https://lux.csail.mit.edu/dev) and [DifferentialEquations.jl](https://diffeq.sciml.ai/stable/) diff --git a/docs/make.jl b/docs/make.jl index 475c377e..7a372005 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -14,12 +14,12 @@ makedocs( canonical="https://fastdeq.sciml.ai/stable/"), pages = [ "FastDEQ: Fast Deep Equilibrium Networks" => "index.md", - "API" => [ - "Dynamical Systems" => "api/solvers.md", - "Non Linear Solvers" => "api/nlsolve.md", - "General Purpose Layers" => "api/layers.md", - "DEQ Layers" => "api/deqs.md", - "Miscellaneous" => "api/misc.md", + "Manual" => [ + "Dynamical Systems" => "manual/solvers.md", + "Non Linear Solvers" => "manual/nlsolve.md", + "General Purpose Layers" => "manual/layers.md", + "DEQ Layers" => "manual/deqs.md", + "Miscellaneous" => "manual/misc.md", ], "References" => "references.md", ] diff --git a/docs/src/api/misc.md b/docs/src/api/misc.md deleted file mode 100644 index 00b93abc..00000000 --- a/docs/src/api/misc.md +++ /dev/null @@ -1,10 +0,0 @@ -# Miscellaneous - -```@docs -SteadyStateAdjoint -DeepEquilibriumSolution -get_and_clear_nfe! -compute_deq_jacobian_loss -NormalInitializer -SupervisedLossContainer -``` \ No newline at end of file diff --git a/docs/src/api/solvers.md b/docs/src/api/solvers.md deleted file mode 100644 index 13cd0d05..00000000 --- a/docs/src/api/solvers.md +++ /dev/null @@ -1,15 +0,0 @@ -# Dynamical System Variants - -[baideep2019](@cite) introduced Discrete Deep Equilibrium Models which drives a Discrete Dynamical System to its steady-state. [pal2022mixing](@cite) extends this framework to Continuous Dynamical Systems which converge to the steady-stable in a more stable fashion. For a detailed discussion refer to [pal2022mixing](@cite). - -## Continuous DEQs (Infinite Time Neural ODEs) - -```@docs -ContinuousDEQSolver -``` - -## Discrete DEQs - -```@docs -DiscreteDEQSolver -``` \ No newline at end of file diff --git a/docs/src/index.md b/docs/src/index.md index b45cc7c2..150f7276 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,14 +1,13 @@ # FastDEQ: (Fast) Deep Equlibrium Networks -FastDEQ.jl is a framework built on top of [DifferentialEquations.jl](https://diffeq.sciml.ai/stable/) and [Flux.jl](https://fluxml.ai) enabling the efficient training and inference for Deep Equilibrium Networks (Infinitely Deep Neural Networks). +FastDEQ.jl is a framework built on top of [DifferentialEquations.jl](https://diffeq.sciml.ai/stable/) and [Lux.jl](https://lux.csail.mit.edu/dev/) enabling the efficient training and inference for Deep Equilibrium Networks (Infinitely Deep Neural Networks). ## Installation Currently the package is not registered and requires manually installing a few dependencies. We are working towards upstream fixes which will make installation easier ```julia -] add https://github.com/SciML/DiffEqSensitivity.jl.git#ap/fastdeq -] add https://github.com/avik-pal/FluxExperimental.jl.git#main +] add https://github.com/avik-pal/Lux.jl.git#main ] add https://github.com/SciML/FastDEQ.jl ``` @@ -27,14 +26,4 @@ If you are using this project for research or other academic purposes consider c } ``` -For specific algorithms, check the respective documentations and cite the corresponding papers. - -## FAQs - -#### How do I reproduce the experiments in the paper -- *Mixing Implicit and Explicit Deep Learning with Skip DEQs and Infinite Time Neural ODEs (Continuous DEQs)*? - -Check out the `ap/paper` branch for the code corresponding to that paper. - -#### Are there some tutorials? - -We are working on adding some in the near future. In the meantime, please checkout the `experiments` directory in the `ap/paper` branch. You can also check `test/runtests.jl` for some simple examples. \ No newline at end of file +For specific algorithms, check the respective documentations and cite the corresponding papers. \ No newline at end of file diff --git a/docs/src/api/deqs.md b/docs/src/manual/deqs.md similarity index 100% rename from docs/src/api/deqs.md rename to docs/src/manual/deqs.md diff --git a/docs/src/api/layers.md b/docs/src/manual/layers.md similarity index 64% rename from docs/src/api/layers.md rename to docs/src/manual/layers.md index c9029ca9..cf15fbcf 100644 --- a/docs/src/api/layers.md +++ b/docs/src/manual/layers.md @@ -2,5 +2,4 @@ ```@docs DEQChain -MultiParallelNet ``` \ No newline at end of file diff --git a/docs/src/manual/misc.md b/docs/src/manual/misc.md new file mode 100644 index 00000000..8a3363ed --- /dev/null +++ b/docs/src/manual/misc.md @@ -0,0 +1,7 @@ +# Miscellaneous + +```@docs +DeepEquilibriumAdjoint +DeepEquilibriumSolution +NormalInitializer +``` \ No newline at end of file diff --git a/docs/src/api/nlsolve.md b/docs/src/manual/nlsolve.md similarity index 98% rename from docs/src/api/nlsolve.md rename to docs/src/manual/nlsolve.md index 4140c1cf..28a41d0c 100644 --- a/docs/src/api/nlsolve.md +++ b/docs/src/manual/nlsolve.md @@ -9,4 +9,4 @@ We provide the following NonLinear Solvers for DEQs. These are compatible with G ```@docs BroydenSolver LimitedMemoryBroydenSolver -``` \ No newline at end of file +``` diff --git a/docs/src/manual/solvers.md b/docs/src/manual/solvers.md new file mode 100644 index 00000000..ff0ee4c4 --- /dev/null +++ b/docs/src/manual/solvers.md @@ -0,0 +1,38 @@ +# Dynamical System Variants + +[baideep2019](@cite) introduced Discrete Deep Equilibrium Models which drives a Discrete Dynamical System to its steady-state. [pal2022mixing](@cite) extends this framework to Continuous Dynamical Systems which converge to the steady-stable in a more stable fashion. For a detailed discussion refer to [pal2022mixing](@cite). + +## Continuous DEQs + +```@docs +ContinuousDEQSolver +``` + +## Discrete DEQs + +```@docs +DiscreteDEQSolver +``` + +## Termination Conditions + +#### Termination on Absolute Tolerance + +* `:abs`: Terminates if ``all \left( | \frac{\partial u}{\partial t} | \leq abstol \right)`` +* `:abs_norm`: Terminates if ``\| \frac{\partial u}{\partial t} \| \leq abstol`` +* `:abs_deq_default`: Essentially `abs_norm` + terminate if there has been no improvement for the last 30 steps + terminate if the solution blows up (diverges) +* `:abs_deq_best`: Same as `:abs_deq_default` but uses the best solution found so far, i.e. deviates only if the solution has not converged + +#### Termination on Relative Tolerance + +* `:rel`: Terminates if ``all \left(| \frac{\partial u}{\partial t} | \leq reltol \times | u | \right)`` +* `:rel_norm`: Terminates if ``\| \frac{\partial u}{\partial t} \| \leq reltol \times \| \frac{\partial u}{\partial t} + u \|`` +* `:rel_deq_default`: Essentially `rel_norm` + terminate if there has been no improvement for the last 30 steps + terminate if the solution blows up (diverges) +* `:rel_deq_best`: Same as `:rel_deq_default` but uses the best solution found so far, i.e. deviates only if the solution has not converged + +#### Termination using both Absolute and Relative Tolerances + +* `:norm`: Terminates if ``\| \frac{\partial u}{\partial t} \| \leq reltol \times \| \frac{\partial u}{\partial t} + u \|`` & + ``\| \frac{\partial u}{\partial t} \| \leq abstol`` +* `fallback`: Check if all values of the derivative is close to zero wrt both relative and absolute tolerance. This is usable for small problems + but doesn't scale well for neural networks, and should be avoided unless absolutely necessary \ No newline at end of file diff --git a/examples/Project.toml b/examples/Project.toml new file mode 100644 index 00000000..a7d4144a --- /dev/null +++ b/examples/Project.toml @@ -0,0 +1,57 @@ +name = "FastDEQExperiments" +uuid = "5aa64bb0-ce80-4310-96b1-36313c344f92" +authors = ["Avik Pal "] +version = "0.1.0" + +[deps] +ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63" +Augmentor = "02898b10-1f73-11ea-317c-6393d7073e15" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e" +DataLoaders = "2e981812-ef13-4a9c-bfa0-ab13047b12a9" +Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" +FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" +FastDEQ = "6748aba7-0e9b-415e-a410-ae3cc0ecb334" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +FluxMPI = "acf642fa-ee0e-513a-b9d5-bcd150f7ec3b" +Format = "1fa38f19-a742-5d3f-a2b9-30dd87b9d5f8" +Formatting = "59287772-0a20-5a39-b81b-1366585eb4c0" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +MLDataPattern = "9920b226-0b2a-5f5f-9153-9aa70a013f8b" +MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d" +MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +ValueHistories = "98cad3c8-aec3-5f06-8e41-884608649ab7" +Wandb = "ad70616a-06c9-5745-b1f1-6a5f42545108" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[compat] +CUDA = "3" +DataLoaders = "0.1" +Flux = "0.13" +FluxMPI = "0.5.3" +Format = "1.3" +Lux = "0.4" +MLDatasets = "0.5" +MLUtils = "0.2" +MPI = "0.19" +NNlib = "0.8" +Optimisers = "0.2" +OrdinaryDiffEq = "6" +ParameterSchedulers = "0.3" +Setfield = "0.8, 1" +Wandb = "0.4.3" +Zygote = "0.6" +julia = "1.6" diff --git a/examples/cifar10/main.jl b/examples/cifar10/main.jl new file mode 100644 index 00000000..1a8f8aea --- /dev/null +++ b/examples/cifar10/main.jl @@ -0,0 +1,407 @@ +# Adapted from https://github.com/avik-pal/Lux.jl/tree/main/examples/ImageNet/main.jl + +using ArgParse # Parse Arguments from Commandline +using DataAugmentation # Image Augmentation +using CUDA # GPUs <3 +using DataLoaders # Pytorch like DataLoaders +using Dates # Printing current time +using FastDEQ # Deep Equilibrium Model +using FastDEQExperiments # Models built using FastDEQ +using FluxMPI # Distibuted Training +using Formatting # Pretty Printing +using Functors # Parameter Manipulation +using Images # Image Processing +using LinearAlgebra # Linear Algebra +using Lux # Neural Network Framework +using MLDataPattern # Data Pattern +using MLDatasets # CIFAR10 +using MLDataUtils # Shuffling and Splitting Data +using MLUtils # Data Processing +using NNlib # Neural Network Backend +using OneHotArrays # One Hot Encoding +using Optimisers # Collection of Gradient Based Optimisers +using ParameterSchedulers # Collection of Schedulers for Parameter Updates +using Random # Make things less Random +using Serialization # Serialize Models +using Setfield # Easy Parameter Manipulation +using Statistics # Statistics +using ValueHistories # Storing Value Histories +using Wandb # Logging to Weights and Biases +using Zygote # Our AD Engine + +# Distributed Training +# FluxMPI.Init(; verbose=true) +CUDA.allowscalar(false) + +# Training Options +include("options.jl") + +function get_experiment_config(args) + return get_experiment_configuration( + Val(:CIFAR10), + Val(Symbol(args["model-size"])); + model_type=Symbol(args["model-type"]), + continuous=!args["discrete"], + abstol=args["abstol"], + reltol=args["reltol"], + jfb=args["jfb"], + train_batchsize=args["train-batchsize"], + eval_batchsize=args["eval-batchsize"], + seed=args["seed"], + w_skip=args["w-skip"], + ) +end + +create_model(expt_config, args) = get_model(expt_config; device=gpu, warmup=true, loss_function=get_loss_function(args)) + +function get_loss_function(args) + if args["model-type"] == "VANILLA" + function loss_function_closure_vanilla(x, y, model, ps, st, w_skip=args["w-skip"]) + (ŷ, soln), st_ = model(x, ps, st) + celoss = logitcrossentropy(ŷ, y) + skiploss = FastDEQExperiments.mae(soln.u₀, soln.z_star) + loss = celoss + return loss, st_, (ŷ, soln.nfe, celoss, skiploss, soln.residual) + end + return loss_function_closure_vanilla + else + function loss_function_closure_skip(x, y, model, ps, st, w_skip=args["w-skip"]) + (ŷ, soln), st_ = model(x, ps, st) + celoss = logitcrossentropy(ŷ, y) + skiploss = FastDEQExperiments.mae(soln.u₀, soln.z_star) + loss = celoss + w_skip * skiploss + return loss, st_, (ŷ, soln.nfe, celoss, skiploss, soln.residual) + end + return loss_function_closure_skip + end +end + +# Checkpointing +function save_checkpoint(state, is_best, filename) + if should_log() + isdir(dirname(filename)) || mkpath(dirname(filename)) + serialize(filename, state) + is_best && cp(filename, joinpath(dirname(filename), "model_best.jls"); force=true) + end +end + +# DataLoading +struct CIFARDataContainer + images + labels + transform +end + +function get_dataloaders(expt_config::NamedTuple) + x_train, y_train = CIFAR10.traindata(Float32) + x_test, y_test = CIFAR10.testdata(Float32) + + x_train_images = map(x -> Image(colorview(RGB, permutedims(x, (3, 2, 1)))), eachslice(x_train; dims=4)) + y_train = collect(eachslice(Float32.(onehotbatch(y_train, 0:9)); dims=2)) + + x_test_images = map(x -> Image(colorview(RGB, permutedims(x, (3, 2, 1)))), eachslice(x_test; dims=4)) + y_test = collect(eachslice(Float32.(onehotbatch(y_test, 0:9)); dims=2)) + + base_transform = ImageToTensor() |> Normalize((0.4914f0, 0.4822f0, 0.4465f0), (0.2023f0, 0.1994f0, 0.2010f0)) + + if expt_config.augment + train_transform = Maybe(FlipX()) |> ScaleKeepAspect((36, 36)) |> RandomResizeCrop((32, 32)) |> base_transform + else + train_transform = base_transform + end + + train_dataset = MLUtils.shuffleobs(CIFARDataContainer(x_train_images, y_train, train_transform)) + train_dataset = is_distributed() ? DistributedDataContainer(train_dataset) : train_dataset + test_dataset = CIFARDataContainer(x_test_images, y_test, base_transform) + test_dataset = is_distributed() ? DistributedDataContainer(test_dataset) : test_dataset + + return ( + DataLoaders.DataLoader(train_dataset, expt_config.train_batchsize), + DataLoaders.DataLoader(test_dataset, expt_config.eval_batchsize), + ) +end + +Base.length(d::CIFARDataContainer) = length(d.images) +Base.getindex(d::CIFARDataContainer, i::Int) = (Array(itemdata(apply(d.transform, d.images[i]))), d.labels[i]) +MLDataPattern.getobs(d::CIFARDataContainer, i::Int64) = MLUtils.getobs(d, i) + +# Validation +function validate(val_loader, model, ps, st, loss_function, args) + batch_time = AverageMeter("Batch Time", "6.3f") + data_time = AverageMeter("Data Time", "6.3f") + losses = AverageMeter("Net Loss", "6.3f") + loss1 = AverageMeter("Cross Entropy Loss", "6.3e") + loss2 = AverageMeter("Skip Loss", "6.3e") + residual = AverageMeter("Residual", "6.3e") + top1 = AverageMeter("Accuracy", "3.2f") + nfe = AverageMeter("NFE", "3.2f") + + progress = ProgressMeter( + length(val_loader), (batch_time, data_time, losses, loss1, loss2, residual, top1, nfe), "Test:" + ) + + st_ = Lux.testmode(st) + t = time() + for (i, (x, y)) in enumerate(CUDA.functional() ? CuIterator(val_loader) : val_loader) + B = size(x, ndims(x)) + data_time(time() - t, B) + + # Compute Output + loss, st_, (ŷ, nfe_, celoss, skiploss, resi) = loss_function(x, y, model, ps, st_) + st_ = Lux.update_state(st_, :update_mask, Val(true)) + + # Measure Elapsed Time + batch_time(time() - t, B) + + # Metrics + acc1 = accuracy(cpu(ŷ), cpu(y)) + top1(acc1, B) + nfe(nfe_, B) + losses(loss, B) + loss1(celoss, B) + loss2(skiploss, B) + residual(norm(resi), B) + + # Print Progress + if i % args["print-freq"] == 0 || i == length(val_loader) + should_log() && print_meter(progress, i) + end + i == length(val_loader) - 1 && invoke_gc() # Needed since the last batch size is different + + t = time() + end + + return ( + batch_time.sum, + data_time.sum, + loss1.sum, + loss2.sum, + losses.sum, + nfe.sum, + top1.sum, + residual.sum, + top1.count, + ) +end + +# Training +function train_one_epoch(train_loader, model, ps, st, optimiser_state, epoch, loss_function, w_skip, args) + batch_time = AverageMeter("Batch Time", "6.3f") + data_time = AverageMeter("Data Time", "6.3f") + forward_pass_time = AverageMeter("Forward Pass Time", "6.3f") + backward_pass_time = AverageMeter("Backward Pass Time", "6.3f") + losses = AverageMeter("Net Loss", "6.3f") + loss1 = AverageMeter("Cross Entropy Loss", "6.3e") + loss2 = AverageMeter("Skip Loss", "6.3e") + residual = AverageMeter("Residual", "6.3e") + top1 = AverageMeter("Accuracy", "6.2f") + nfe = AverageMeter("NFE", "6.2f") + + progress = ProgressMeter( + length(train_loader), + (batch_time, data_time, forward_pass_time, backward_pass_time, losses, loss1, loss2, residual, top1, nfe), + "Epoch: [$epoch]", + ) + + st = Lux.trainmode(st) + t = time() + for (i, (x, y)) in enumerate(CuIterator(train_loader)) + B = size(x, ndims(x)) + data_time(time() - t, B) + + # Gradients and Update + _t = time() + (loss, st, (ŷ, nfe_, celoss, skiploss, resi)), back = Zygote.pullback( + p -> loss_function(x, y, model, p, st, w_skip), ps + ) + forward_pass_time(time() - _t, B) + _t = time() + gs = back((one(loss), nothing, nothing))[1] + backward_pass_time(time() - _t, B) + st = Lux.update_state(st, :update_mask, Val(true)) + if is_distributed() + gs = allreduce_gradients(gs) + end + optimiser_state, ps = Optimisers.update(optimiser_state, ps, gs) + + # Measure Elapsed Time + batch_time(time() - t, B) + + # Metrics + acc1 = accuracy(cpu(ŷ), cpu(y)) + top1(acc1, B) + nfe(nfe_, B) + losses(loss, B) + loss1(celoss, B) + loss2(skiploss, B) + residual(norm(resi), B) + + # Print Progress + if i % args["print-freq"] == 0 || i == length(train_loader) + should_log() && print_meter(progress, i) + end + i == length(train_loader) - 1 && invoke_gc() # Needed since the last batch size is different + + t = time() + end + + return ( + ps, + st, + optimiser_state, + ( + batch_time.sum, + data_time.sum, + forward_pass_time.sum, + backward_pass_time.sum, + loss1.sum, + loss2.sum, + losses.sum, + nfe.sum, + top1.sum, + residual.sum, + top1.count, + ), + ) +end + +# Main Function +function get_base_experiment_name(args) + return "data-CIFAR10_type-$(args["model-type"])_size-$(args["model-size"])_discrete-$(args["discrete"])_jfb-$(args["jfb"])" +end + +function get_loggable_stats(stats) + v = [stats...] + is_distributed() && MPI.Reduce!(v, +, 0, MPI.COMM_WORLD) + return v[1:end-1] ./ v[end] +end + +function convert_config_to_loggable(expt_config::NamedTuple) + config = Dict() + for (k, v) in pairs(expt_config) + config[k] = isprimitivetype(typeof(v)) ? v : string(v) + end + return config +end + +function main(args) + best_acc1 = 0 + + # Seeding + rng = Random.default_rng() + Random.seed!(rng, args["seed"]) + + # Model Construction + expt_config = get_experiment_config(args) + loggable_config = convert_config_to_loggable(expt_config) + should_log() && println("$(now()) => creating model") + model, ps, st = create_model(expt_config, args) + + should_log() && println("$(now()) => setting up dataloaders") + train_loader, test_loader = get_dataloaders(expt_config) + + # Optimizer and Scheduler + should_log() && println("$(now()) => creating optimiser") + optimiser, scheduler = construct_optimiser(expt_config) + optimiser_state = Optimisers.setup(optimiser, ps) + if is_distributed() + optimiser_state = FluxMPI.synchronize!(optimiser_state) + should_log() && println("$(now()) ==> synced optimiser state across all ranks") + end + + if args["resume"] != "" + if isfile(args["resume"]) + checkpoint = deserialize(args["resume"]) + args["start-epoch"] = checkpoint["epoch"] + optimiser_state = gpu(checkpoint["optimiser_state"]) + ps = gpu(checkpoint["model_parameters"]) + st = gpu(checkpoint["model_states"]) + should_log() && println("$(now()) => loaded checkpoint `$(args["resume"])` (epoch $(args["start-epoch"]))") + else + should_log() && println("$(now()) => no checkpoint found at `$(args["resume"])`. Starting from scratch.") + end + end + + loss_function = get_loss_function(args) + + if args["evaluate"] + validate(test_loader, model, ps, st, loss_function, args) + return nothing + end + + invoke_gc() + + expt_name = get_base_experiment_name(args) + store_in = args["expt-subdir"] == "" ? string(now()) : args["expt-subdir"] + + ckpt_dir = joinpath(args["checkpoint-dir"], expt_name, store_in) + log_path = joinpath(args["log-dir"], expt_name, store_in, "results.csv") + + should_log() && println("$(now()) => checkpoint directory `$(ckpt_dir)`") + + logging_header = ["Epoch", "Train/Batch Time", "Train/Data Time", "Train/Forward Pass Time", "Train/Backward Pass Time", "Train/Cross Entropy Loss", "Train/Skip Loss", "Train/Net Loss", "Train/NFE", "Train/Accuracy", "Train/Residual", "Test/Batch Time", "Test/Data Time", "Test/Cross Entropy Loss", "Test/Skip Loss", "Test/Net Loss", "Test/NFE", "Test/Accuracy", "Test/Residual"] + csv_logger = CSVLogger(log_path, logging_header) + wandb_logger = WandbLogger(project="deep_equilibrium_models", + name=store_in, + config=loggable_config) + + values_to_loggable_dict(args...) = Dict(zip(logging_header, args)) + + should_log() && println("$(now()) => logging results to `$(log_path)`") + + should_log() && serialize(joinpath(dirname(log_path), "setup.jls"), Dict("config" => expt_config, "args" => args)) + + st = hasproperty(expt_config, :pretrain_epochs) && getproperty(expt_config, :pretrain_epochs) > 0 ? Lux.update_state(st, :fixed_depth, Val(getproperty(expt_config, :num_layers))) : st + + wskip_sched = ParameterSchedulers.Exp(args["w-skip"], 0.92f0) + + for epoch in args["start-epoch"]:(expt_config.nepochs) + # Train for 1 epoch + ps, st, optimiser_state, train_stats = train_one_epoch( + train_loader, model, ps, st, optimiser_state, epoch, loss_function, wskip_sched(epoch), args + ) + train_stats = get_loggable_stats(train_stats) + + should_log() && println() + + # Some Housekeeping + invoke_gc() + + # Evaluate on validation set + val_stats = validate(test_loader, model, ps, st, loss_function, args) + val_stats = get_loggable_stats(val_stats) + + should_log() && println() + + csv_logger(epoch, train_stats..., val_stats...) + Wandb.log(wandb_logger, values_to_loggable_dict(epoch, train_stats..., val_stats...)) + should_log() && println("$(now()) => logged intermediated results to csv file\n") + + # ParameterSchedulers + eta_new = ParameterSchedulers.next!(scheduler) + optimiser_state = update_lr(optimiser_state, eta_new) + if hasproperty(expt_config, :pretrain_epochs) && getproperty(expt_config, :pretrain_epochs) == epoch + should_log() && println("$(now()) => pretraining completed\n") + st = Lux.update_state(st, :fixed_depth, Val(0)) + end + + # Some Housekeeping + invoke_gc() + + # Remember Best Accuracy and Save Checkpoint + is_best = val_stats[1] > best_acc1 + best_acc1 = max(val_stats[1], best_acc1) + + save_state = Dict( + "epoch" => epoch, + "config" => loggable_config, + "accuracy" => accuracy, + "model_states" => cpu(st), + "model_parameters" => cpu(ps), + "optimiser_state" => cpu(optimiser_state), + ) + save_checkpoint(save_state, is_best, joinpath(ckpt_dir, "checkpoint.jls")) + end +end + +main(parse_commandline_arguments()) diff --git a/examples/cifar10/options.jl b/examples/cifar10/options.jl new file mode 100644 index 00000000..d5b7fe2a --- /dev/null +++ b/examples/cifar10/options.jl @@ -0,0 +1,77 @@ +using ArgParse + + +# Parse Training Arguments +function parse_commandline_arguments() + parse_settings = ArgParseSettings("FastDEQ CIFAR-10 Training") + + @add_arg_table! parse_settings begin + "--model-size" + default = "TINY" + range_tester = x -> x ∈ ("TINY", "LARGE") + help = "model size: `TINY` or `LARGE`" + "--model-type" + default = "VANILLA" + range_tester = x -> x ∈ ("VANILLA", "SKIP", "SKIPV2") + help = "model type: `VANILLA`, `SKIP` or `SKIPV2`" + "--eval-batchsize" + help = "batch size for evaluation (per process)" + arg_type = Int + default = 32 + "--train-batchsize" + help = "batch size for training (per process)" + arg_type = Int + default = 32 + "--discrete" + help = "use discrete DEQ" + action = :store_true + "--jfb" + help = "enable jacobian-free-backpropagation" + action = :store_true + "--abstol" + default = 0.25f0 + arg_type = Float32 + help = "absolute tolerance for termination" + "--reltol" + default = 0.25f0 + arg_type = Float32 + help = "relative tolerance for termination" + "--w-skip" + default = 1.0f0 + arg_type = Float32 + help = "weight for skip DEQ loss" + "--start-epoch" + help = "manual epoch number (useful on restarts)" + arg_type = Int + default = 1 + "--print-freq" + help = "print frequency" + arg_type = Int + default = 100 + "--resume" + help = "resume from checkpoint" + arg_type = String + default = "" + "--evaluate" + help = "evaluate model on validation set" + action = :store_true + "--seed" + help = "seed for initializing training. " + arg_type = Int + default = 0 + "--checkpoint-dir" + help = "directory to save checkpoints" + arg_type = String + default = "checkpoints/" + "--log-dir" + help = "directory to save logs" + arg_type = String + default = "logs/" + "--expt-subdir" + help = "subdirectory name" + arg_type = String + default = "" + end + + return parse_args(parse_settings) +end \ No newline at end of file diff --git a/examples/imagenet/main.jl b/examples/imagenet/main.jl new file mode 100644 index 00000000..e69de29b diff --git a/examples/src/FastDEQExperiments.jl b/examples/src/FastDEQExperiments.jl new file mode 100644 index 00000000..2230a668 --- /dev/null +++ b/examples/src/FastDEQExperiments.jl @@ -0,0 +1,45 @@ +module FastDEQExperiments + +using CUDA +using Dates +using FastBroadcast +using FastDEQ +using FluxMPI +using Format +using Formatting +using Functors +using Lux +using MPI +using NNlib +using OneHotArrays +using Optimisers +using OrdinaryDiffEq +using ParameterSchedulers +using Random +using Setfield +using Statistics +using Zygote + +import DataLoaders: LearnBase +import MLDataPattern +import MLUtils + +# logging utilities +include("logging.jl") +# get_model_config +include("config.jl") +# get_model +include("models.jl") +# random utilities +include("utils.jl") + +# Exports +export AverageMeter, CSVLogger, ProgressMeter, print_meter + +export get_experiment_configuration + +export construct_optimiser, get_model + +export accuracy, invoke_gc, is_distributed, logitcrossentropy, mae, mse, relieve_gc_pressure, should_log, update_lr + +end \ No newline at end of file diff --git a/examples/src/config.jl b/examples/src/config.jl new file mode 100644 index 00000000..c8a4097f --- /dev/null +++ b/examples/src/config.jl @@ -0,0 +1,212 @@ +function compute_feature_scales(config) + image_size = config.image_size + image_size_downsampled = image_size + for _ in 1:(config.downsample_times) + image_size_downsampled = image_size_downsampled .÷ 2 + end + scales = [(image_size_downsampled..., config.num_channels[1])] + for i in 2:(config.num_branches) + push!(scales, ((scales[end][1:2] .÷ 2)..., config.num_channels[i])) + end + return Tuple(scales) +end + +function get_default_experiment_configuration(::Val{:CIFAR10}, ::Val{:TINY}) + return ( + num_layers=10, + num_classes=10, + dropout_rate=0.25f0, + group_count=8, + weight_norm=true, + downsample_times=0, + expansion_factor=5, + post_gn_affine=false, + image_size=(32, 32), + num_modules=1, + num_branches=2, + block_type=:basic, + big_kernels=(0, 0), + head_channels=(8, 16), + num_blocks=(1, 1), + num_channels=(24, 24), + fuse_method=:sum, + final_channelsize=200, + fwd_maxiters=18, + bwd_maxiters=20, + continuous=true, + stop_mode=:rel_norm, + nepochs=50, + jfb=false, + augment=false, + model_type=:VANILLA, + abstol=5.0f-2, + reltol=5.0f-2, + ode_solver=VCABM3(), + pretrain_epochs=5, + lr_scheduler=:COSINE, + optimiser=:ADAM, + eta=0.001f0 * scaling_factor(), + ) +end + +function get_default_experiment_configuration(::Val{:CIFAR10}, ::Val{:LARGE}) + return ( + num_layers=10, + num_classes=10, + dropout_rate=0.3f0, + group_count=8, + weight_norm=true, + downsample_times=0, + expansion_factor=5, + post_gn_affine=false, + image_size=(32, 32), + num_modules=1, + num_branches=4, + block_type=:basic, + big_kernels=(0, 0, 0, 0), + head_channels=(14, 28, 56, 112), + num_blocks=(1, 1, 1, 1), + num_channels=(32, 64, 128, 256), + fuse_method=:sum, + final_channelsize=1680, + fwd_maxiters=18, + bwd_maxiters=20, + continuous=true, + stop_mode=:rel_norm, + nepochs=220, + jfb=false, + augment=true, + model_type=:VANILLA, + abstol=5.0f-2, + reltol=5.0f-2, + ode_solver=VCABM3(), + pretrain_epochs=8, + lr_scheduler=:COSINE, + optimiser=:ADAM, + eta=0.001f0 * scaling_factor(), + ) +end + +function get_default_experiment_configuration(::Val{:IMAGENET}, ::Val{:SMALL}) + return ( + num_layers=4, + num_classes=1000, + dropout_rate=0.0f0, + group_count=8, + weight_norm=true, + downsample_times=2, + expansion_factor=5, + post_gn_affine=true, + image_size=(224, 224), + num_modules=1, + num_branches=4, + block_type=:basic, + big_kernels=(0, 0, 0, 0), + head_channels=(24, 48, 96, 192), + num_blocks=(1, 1, 1, 1), + num_channels=(32, 64, 128, 256), + fuse_method=:sum, + final_channelsize=2048, + fwd_maxiters=27, + bwd_maxiters=28, + continuous=true, + stop_mode=:rel_norm, + nepochs=100, + jfb=false, + model_type=:VANILLA, + abstol=5.0f-2, + reltol=5.0f-2, + ode_solver=VCABM3(), + pretrain_epochs=18, + lr_scheduler=:COSINE, + optimiser=:SGD, + eta=0.05f0 * scaling_factor(), + weight_decay=0.00005f0, + momentum=0.9f0, + nesterov=true, + ) +end + +function get_default_experiment_configuration(::Val{:IMAGENET}, ::Val{:LARGE}) + return ( + num_layers=4, + num_classes=1000, + dropout_rate=0.0f0, + group_count=8, + weight_norm=true, + downsample_times=2, + expansion_factor=5, + post_gn_affine=true, + image_size=(224, 224), + num_modules=1, + num_branches=4, + block_type=:basic, + big_kernels=(0, 0, 0, 0), + head_channels=(32, 64, 128, 256), + num_blocks=(1, 1, 1, 1), + num_channels=(80, 160, 320, 640), + fuse_method=:sum, + final_channelsize=2048, + fwd_maxiters=27, + bwd_maxiters=28, + continuous=true, + stop_mode=:rel_norm, + nepochs=100, + jfb=false, + model_type=:VANILLA, + abstol=5.0f-2, + reltol=5.0f-2, + ode_solver=VCABM3(), + pretrain_epochs=18, + lr_scheduler=:COSINE, + optimiser=:SGD, + eta=0.05f0 * scaling_factor(), + weight_decay=0.00005f0, + momentum=0.9f0, + nesterov=true, + ) +end + +function get_default_experiment_configuration(::Val{:IMAGENET}, ::Val{:XL}) + return ( + num_layers=4, + num_classes=1000, + dropout_rate=0.0f0, + group_count=8, + weight_norm=true, + downsample_times=2, + expansion_factor=5, + post_gn_affine=true, + image_size=(224, 224), + num_modules=1, + num_branches=4, + block_type=:basic, + big_kernels=(0, 0, 0, 0), + head_channels=(32, 64, 128, 256), + num_blocks=(1, 1, 1, 1), + num_channels=(88, 176, 352, 704), + fuse_method=:sum, + final_channelsize=2048, + fwd_maxiters=27, + bwd_maxiters=28, + continuous=true, + stop_mode=:rel_norm, + nepochs=100, + jfb=false, + model_type=:VANILLA, + abstol=5.0f-2, + reltol=5.0f-2, + ode_solver=VCABM3(), + pretrain_epochs=18, + lr_scheduler=:COSINE, + optimiser=:SGD, + eta=0.05f0 * scaling_factor(), + weight_decay=0.00005f0, + momentum=0.9f0, + nesterov=true, + ) +end + +function get_experiment_configuration(dataset::Val, model_size::Val; kwargs...) + return merge(get_default_experiment_configuration(dataset, model_size), kwargs) +end diff --git a/examples/src/logging.jl b/examples/src/logging.jl new file mode 100644 index 00000000..d18abf00 --- /dev/null +++ b/examples/src/logging.jl @@ -0,0 +1,67 @@ +Base.@kwdef mutable struct AverageMeter + fmtstr + val::Float64 = 0.0 + sum::Float64 = 0.0 + count::Int = 0 + average::Float64 = 0 +end + +function AverageMeter(name::String, fmt::String) + fmtstr = Formatting.FormatExpr("$name {1:$fmt} ({2:$fmt})") + return AverageMeter(; fmtstr=fmtstr) +end + +function (meter::AverageMeter)(val, n::Int) + meter.val = val + meter.sum += val * n + meter.count += n + meter.average = meter.sum / meter.count + return meter.average +end + +print_meter(meter::AverageMeter) = Formatting.printfmt(meter.fmtstr, meter.val, meter.average) + +struct ProgressMeter{N} + batch_fmtstr + meters::NTuple{N,AverageMeter} +end + +function ProgressMeter(num_batches::Int, meters::NTuple{N}, prefix::String="") where {N} + fmt = "%" * string(length(string(num_batches))) * "d" + prefix = prefix != "" ? endswith(prefix, " ") ? prefix : prefix * " " : "" + batch_fmtstr = Formatting.generate_formatter("$prefix[$fmt/" * sprintf1(fmt, num_batches) * "]") + return ProgressMeter{N}(batch_fmtstr, meters) +end + +function print_meter(meter::ProgressMeter, batch::Int) + base_str = meter.batch_fmtstr(batch) + print(base_str) + foreach(x -> (print("\t"); print_meter(x)), meter.meters[1:end]) + return println() +end + +struct CSVLogger{N} + filename + fio +end + +function CSVLogger(filename, header) + should_log() && !isdir(dirname(filename)) && mkpath(dirname(filename)) + fio = should_log() ? open(filename, "w") : nothing + N = length(header) + should_log() && println(fio, join(header, ",")) + return CSVLogger{N}(filename, fio) +end + +function (csv::CSVLogger)(args...) + if should_log() + println(csv.fio, join(args, ",")) + flush(csv.fio) + end +end + +function Base.close(csv::CSVLogger) + if should_log() + close(csv.fio) + end +end diff --git a/examples/src/models.jl b/examples/src/models.jl new file mode 100644 index 00000000..356c3d9d --- /dev/null +++ b/examples/src/models.jl @@ -0,0 +1,477 @@ +# Building Blocks +## Helpful Functional Wrappers +function conv1x1(mapping, activation=identity; stride::Int=1, bias=false, kwargs...) + return Conv( + (1, 1), mapping, activation; stride=stride, pad=0, bias=bias, init_weight=NormalInitializer(), kwargs... + ) +end + +function conv3x3(mapping, activation=identity; stride::Int=1, bias=false, kwargs...) + return Conv( + (3, 3), mapping, activation; stride=stride, pad=1, bias=bias, init_weight=NormalInitializer(), kwargs... + ) +end + +function conv5x5(mapping, activation=identity; stride::Int=1, bias=false, kwargs...) + return Conv( + (5, 5), mapping, activation; stride=stride, pad=2, bias=bias, init_weight=NormalInitializer(), kwargs... + ) +end + +addrelu(x, y) = @. relu(x + y) + +reassociate(x::NTuple{2,<:AbstractArray}, y) = (x[1], (x[2], y)) + +addtuple(y) = y[1] .+ y[2] + +## Downsample Module +function downsample_module(mapping, level_diff, activation; group_count=8) + in_channels, out_channels = mapping + + function intermediate_mapping(i) + if in_channels * (2^level_diff) == out_channels + return (in_channels * (2^(i - 1))) => (in_channels * (2^i)) + else + return i == level_diff ? in_channels => out_channels : in_channels => in_channels + end + end + + layers = Lux.AbstractExplicitLayer[] + for i in 1:level_diff + inchs, outchs = intermediate_mapping(i) + push!(layers, conv3x3(inchs => outchs; stride=2)) + # push!(layers, GroupNorm(outchs, group_count, activation; affine=true, track_stats=false)) + push!(layers, BatchNorm(outchs, activation; affine=true, track_stats=false)) + end + return Chain(layers...) +end + +## Upsample Module +function upsample_module(mapping, level_diff, activation; upsample_mode::Symbol=:nearest, group_count=8) + in_channels, out_channels = mapping + + function intermediate_mapping(i) + if out_channels * (2^level_diff) == in_channels + (in_channels ÷ (2^(i - 1))) => (in_channels ÷ (2^i)) + else + i == level_diff ? in_channels => out_channels : in_channels => in_channels + end + end + + layers = Lux.AbstractExplicitLayer[] + for i in 1:level_diff + inchs, outchs = intermediate_mapping(i) + push!(layers, conv1x1(inchs => outchs)) + # push!(layers, GroupNorm(outchs, group_count, activation; affine=true, track_stats=false)) + push!(layers, BatchNorm(outchs, activation; affine=true, track_stats=false)) + push!(layers, Upsample(upsample_mode; scale=2)) + end + return Chain(layers...) +end + +## Residual Block +struct ResidualBlock{C1,C2,Dr,Do,N1,N2,N3} <: + Lux.AbstractExplicitContainerLayer{(:conv1, :conv2, :dropout, :downsample, :norm1, :norm2, :norm3)} + conv1::C1 + conv2::C2 + dropout::Dr + downsample::Do + norm1::N1 + norm2::N2 + norm3::N3 +end + +function ResidualBlock( + mapping; + deq_expand::Int=5, + num_gn_groups::Int=4, + downsample=NoOpLayer(), + n_big_kernels::Int=0, + dropout_rate::Real=0.0f0, + gn_affine::Bool=true, + weight_norm::Bool=true, + gn_track_stats::Bool=false, +) + inplanes, outplanes = mapping + inner_planes = outplanes * deq_expand + conv1 = (n_big_kernels >= 1 ? conv5x5 : conv3x3)(inplanes => inner_planes; bias=false) + conv2 = (n_big_kernels >= 2 ? conv5x5 : conv3x3)(inner_planes => outplanes; bias=false) + + conv1, conv2 = if weight_norm + WeightNorm(conv1, (:weight,), (4,)), WeightNorm(conv2, (:weight,), (4,)) + else + conv1, conv2 + end + + # norm1 = GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) + # norm2 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) + # norm3 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) + norm1 = BatchNorm(inner_planes, relu; affine=gn_affine, track_stats=gn_track_stats) + norm2 = BatchNorm(outplanes; affine=gn_affine, track_stats=gn_track_stats) + norm3 = BatchNorm(outplanes; affine=gn_affine, track_stats=gn_track_stats) + + dropout = VariationalHiddenDropout(dropout_rate) + + return ResidualBlock(conv1, conv2, dropout, downsample, norm1, norm2, norm3) +end + +function (rb::ResidualBlock)((x, y)::Tuple{<:AbstractArray,<:AbstractArray}, ps, st) + x, st_conv1 = rb.conv1(x, ps.conv1, st.conv1) + x, st_norm1 = rb.norm1(x, ps.norm1, st.norm1) + x, st_conv2 = rb.conv2(x, ps.conv2, st.conv2) + + x_do, st_downsample = rb.downsample(x, ps.downsample, st.downsample) + x_dr, st_dropout = rb.dropout(x, ps.dropout, st.dropout) + + y_ = x_dr .+ y + y_, st_norm2 = rb.norm2(y_, ps.norm2, st.norm2) + + y__ = relu.(y_ .+ x_do) + y__, st_norm3 = rb.norm3(y__, ps.norm3, st.norm3) + + return ( + y__, + ( + conv1=st_conv1, + conv2=st_conv2, + dropout=st_dropout, + downsample=st_downsample, + norm1=st_norm1, + norm2=st_norm2, + norm3=st_norm3, + ), + ) +end + +function (rb::ResidualBlock)(x::AbstractArray, ps, st) + x, st_conv1 = rb.conv1(x, ps.conv1, st.conv1) + x, st_norm1 = rb.norm1(x, ps.norm1, st.norm1) + x, st_conv2 = rb.conv2(x, ps.conv2, st.conv2) + + x_do, st_downsample = rb.downsample(x, ps.downsample, st.downsample) + + x_dr, st_dropout = rb.dropout(x, ps.dropout, st.dropout) + x_dr, st_norm2 = rb.norm2(x_dr, ps.norm2, st.norm2) + + y__ = relu.(x_dr .+ x_do) + y__, st_norm3 = rb.norm3(y__, ps.norm3, st.norm3) + + return ( + y__, + ( + conv1=st_conv1, + conv2=st_conv2, + dropout=st_dropout, + downsample=st_downsample, + norm1=st_norm1, + norm2=st_norm2, + norm3=st_norm3, + ), + ) +end + +# Bottleneck Block +struct BottleneckBlock{R,C,M} <: Lux.AbstractExplicitContainerLayer{(:rescale, :conv, :mapping)} + rescale::R + conv::C + mapping::M +end + +function BottleneckBlock(mapping::Pair, expansion::Int=4; bn_track_stats::Bool=true, bn_affine::Bool=true) + rescale = if first(mapping) != last(mapping) * expansion + Chain( + conv1x1(first(mapping) => last(mapping) * expansion), + BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine), + ) + else + NoOpLayer() + end + + return BottleneckBlock( + rescale, + conv1x1(mapping), + Chain( + BatchNorm(last(mapping), relu; affine=bn_affine, track_stats=bn_track_stats), + conv3x3(last(mapping) => last(mapping)), + BatchNorm(last(mapping), relu; track_stats=bn_track_stats, affine=bn_affine), + conv1x1(last(mapping) => last(mapping) * expansion), + BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine) + ) + ) +end + +function (bn::BottleneckBlock)((x, y)::Tuple{<:AbstractArray,<:AbstractArray}, ps, st) + x_r, st_rescale = bn.rescale(x, ps.rescale, st.rescale) + x_m, st_conv1 = bn.conv(x, ps.conv, st.conv) + + x_m = y .+ x_m + x_m, st_mapping = bn.mapping(x_m, ps.mapping, st.mapping) + + return ( + relu.(x_m .+ x_r), + ( + rescale=st_rescale, + conv=st_conv1, + mapping=st_mapping, + ) + ) +end + +function (bn::BottleneckBlock)(x::AbstractArray, ps, st) + x_r, st_rescale = bn.rescale(x, ps.rescale, st.rescale) + x_m, st_conv1 = bn.conv(x, ps.conv, st.conv) + x_m, st_mapping = bn.mapping(x_m, ps.mapping, st.mapping) + + return ( + relu.(x_m .+ x_r), + ( + rescale=st_rescale, + conv=st_conv1, + mapping=st_mapping, + ) + ) +end + +# Dataset Specific Models +function get_model( + config::NamedTuple; + device=gpu, + warmup::Bool=true, # Helps reduce Zygote compile times + loss_function=nothing, +) + @assert !warmup || loss_function !== nothing + + init_channel_size = config.num_channels[1] + + downsample_layers = [ + conv3x3(3 => init_channel_size; stride=config.downsample_times >= 1 ? 2 : 1), + BatchNorm(init_channel_size, relu; affine=true, track_stats=true), + conv3x3(init_channel_size => init_channel_size; stride=config.downsample_times >= 2 ? 2 : 1), + BatchNorm(init_channel_size, relu; affine=true, track_stats=true), + ] + for _ in 3:(config.downsample_times) + append!( + downsample_layers, + [ + conv3x3(init_channel_size => init_channel_size; stride=2), + BatchNorm(init_channel_size, relu; affine=true, track_stats=true), + ], + ) + end + downsample = Chain(downsample_layers...) + + stage0 = if config.downsample_times == 0 && config.num_branches <= 2 + NoOpLayer() + else + Chain( + conv1x1(init_channel_size => init_channel_size; bias=false), + BatchNorm(init_channel_size, relu; affine=true, track_stats=true), + ) + end + + initial_layers = Chain(downsample, stage0) + + main_layers = Tuple( + ResidualBlock( + config.num_channels[i] => config.num_channels[i]; + deq_expand=config.expansion_factor, + dropout_rate=config.dropout_rate, + num_gn_groups=config.group_count, + n_big_kernels=config.big_kernels[i], + ) for i in 1:(config.num_branches) + ) + + mapping_layers = Matrix{Lux.AbstractExplicitLayer}(undef, config.num_branches, config.num_branches) + for i in 1:(config.num_branches) + for j in 1:(config.num_branches) + if i == j + mapping_layers[i, j] = NoOpLayer() + elseif i < j + mapping_layers[i, j] = downsample_module( + config.num_channels[i] => config.num_channels[j], j - i, relu; group_count=config.group_count + ) + else + mapping_layers[i, j] = upsample_module( + config.num_channels[i] => config.num_channels[j], + i - j, + relu; + group_count=config.group_count, + upsample_mode=:nearest, + ) + end + end + end + + post_fuse_layers = Tuple( + Chain( + ActivationFunction(relu), + conv1x1(config.num_channels[i] => config.num_channels[i]), + # GroupNorm(config.num_channels[i], config.group_count ÷ 2; affine=false, track_stats=false), + BatchNorm(config.num_channels[i]; affine=false, track_stats=false), + ) for i in 1:(config.num_branches) + ) + + increment_modules = Parallel( + nothing, + [BottleneckBlock(config.num_channels[i] => config.head_channels[i]) for i in 1:(config.num_branches)]..., + ) + + downsample_modules = PairwiseFusion( + config.fuse_method == :sum ? (+) : error("Only `fuse_method` = `:sum` is supported"), + [ + Chain( + conv3x3(config.head_channels[i] * 4 => config.head_channels[i + 1] * 4; stride=2, bias=true), + BatchNorm(config.head_channels[i + 1] * 4, relu; track_stats=true, affine=true), + ) for i in 1:(config.num_branches - 1) + ]..., + ) + + final_layers = Chain( + increment_modules, + downsample_modules, + conv1x1(config.head_channels[config.num_branches] * 4 => config.final_channelsize; bias=true), + BatchNorm(config.final_channelsize, relu; track_stats=true, affine=true), + GlobalMeanPool(), + FlattenLayer(), + Dense(config.final_channelsize, config.num_classes), + ) + + solver = if config.continuous + ContinuousDEQSolver( + config.ode_solver; + mode=config.stop_mode, + abstol=config.abstol, + reltol=config.reltol, + abstol_termination=config.abstol, + reltol_termination=config.reltol, + ) + else + DiscreteDEQSolver( + LimitedMemoryBroydenSolver(); + mode=config.stop_mode, + abstol_termination=config.abstol, + reltol_termination=config.reltol, + ) + end + + sensealg = DeepEquilibriumAdjoint( + config.abstol, config.reltol, config.bwd_maxiters; mode=config.jfb ? :jfb : :vanilla + ) + + deq = if config.model_type ∈ (:SKIP, :SKIPV2) + shortcut = if config.model_type == :SKIP + slayers = Lux.AbstractExplicitLayer[ResidualBlock( + config.num_channels[1] => config.num_channels[1]; weight_norm=true + )] + for i in 1:(config.num_branches - 1) + push!( + slayers, + downsample_module( + config.num_channels[1] => config.num_channels[i + 1], + i, + relu; + group_count=config.group_count, + ), + ) + end + tuple(slayers...) + else + nothing + end + MultiScaleSkipDeepEquilibriumNetwork( + main_layers, + mapping_layers, + post_fuse_layers, + shortcut, + solver, + compute_feature_scales(config); + maxiters=config.fwd_maxiters, + sensealg=sensealg, + verbose=false, + ) + elseif config.model_type == :VANILLA + MultiScaleDeepEquilibriumNetwork( + main_layers, + mapping_layers, + post_fuse_layers, + solver, + compute_feature_scales(config); + maxiters=config.fwd_maxiters, + sensealg=sensealg, + verbose=false, + ) + else + throw(ArgumentError("`model_type` must be one of `[:SKIP, :SKIPV2, :VANILLA]`")) + end + + model = DEQChain(initial_layers, deq, final_layers) + rng = Random.default_rng() + Random.seed!(rng, config.seed) + ps, st = device.(Lux.setup(rng, model)) + + ps, st = if is_distributed() + ps_ = FluxMPI.synchronize!(ps; root_rank=0) + should_log() && println("$(now()) ===> synchronized model parameters across all processes") + st_ = FluxMPI.synchronize!(st; root_rank=0) + should_log() && println("$(now()) ===> synchronized model state across all processes") + ps_, st_ + else + ps, st + end + + if warmup + should_log() && println("$(now()) ==> starting model warmup") + x__ = device(randn(Float32, config.image_size..., 3, 2)) + y__ = device(Float32.(onehotbatch([1, 2], 0:(config.num_classes - 1)))) + model(x__, ps, st) + should_log() && println("$(now()) ==> forward pass warmup completed") + + st_ = Lux.update_state(st, :fixed_depth, Val(2)) + model(x__, ps, st_) + should_log() && println("$(now()) ==> forward pass (pretraining) warmup completed") + + (l, _, _), back = pullback(p -> loss_function(x__, y__, model, p, st), ps) + back((one(l), nothing, nothing)) + should_log() && println("$(now()) ==> backward pass warmup completed") + + (l, _, _), back = pullback(p -> loss_function(x__, y__, model, p, st_), ps) + back((one(l), nothing, nothing)) + should_log() && println("$(now()) ==> backward pass (pretraining) warmup completed") + + invoke_gc() + end + + return model, ps, st +end + +# Optimisers +function construct_optimiser(config::NamedTuple) + opt = if config.optimiser == :ADAM + Optimisers.ADAM(config.eta) + elseif config.optimiser == :SGD + if config.nesterov + Optimisers.Nesterov(config.eta, config.momentum) + else + if iszero(config.momentum) + Optimisers.Descent(config.eta) + else + Optimisers.Momentum(config.eta, config.momentum) + end + end + else + throw(ArgumentError("`config.optimiser` must be either `:ADAM` or `:SGD`")) + end + if hasproperty(config, :weight_decay) && !iszero(config.weight_decay) + opt = Optimisers.OptimiserChain(opt, Optimisers.WeightDecay(config.weight_decay)) + end + + sched = if config.lr_scheduler == :COSINE + ParameterSchedulers.Stateful(ParameterSchedulers.Cos(config.eta, 1.0f-6, config.nepochs)) + elseif config.lr_scheduler == :CONSTANT + ParameterSchedulers.Stateful(ParameterSchedulers.Constant(config.eta)) + else + throw(ArgumentError("`config.lr_scheduler` must be either `:COSINE` or `:CONSTANT`")) + end + + return opt, sched +end diff --git a/examples/src/utils.jl b/examples/src/utils.jl new file mode 100644 index 00000000..e68ac402 --- /dev/null +++ b/examples/src/utils.jl @@ -0,0 +1,55 @@ +# unsafe_free OneHotArrays +CUDA.unsafe_free!(x::OneHotArray) = CUDA.unsafe_free!(x.indices) + +# Memory Management +relieve_gc_pressure(::Union{Nothing,<:AbstractArray}) = nothing +relieve_gc_pressure(x::CuArray) = CUDA.unsafe_free!(x) +relieve_gc_pressure(t::Tuple) = relieve_gc_pressure.(t) +relieve_gc_pressure(x::NamedTuple) = fmap(relieve_gc_pressure, x) + +function invoke_gc() + GC.gc(true) + CUDA.reclaim() + return nothing +end + +# Optimisers / Parameter Schedulers +function update_lr(st::ST, eta) where {ST} + if hasfield(ST, :eta) + @set! st.eta = eta + end + return st +end +update_lr(st::Optimisers.OptimiserChain, eta) = update_lr.(st.opts, eta) +function update_lr(st::Optimisers.Leaf, eta) + @set! st.rule = update_lr(st.rule, eta) +end +update_lr(st_opt::NamedTuple, eta) = fmap(l -> update_lr(l, eta), st_opt) + +# Metrics +accuracy(ŷ, y) = sum(argmax.(eachcol(ŷ)) .== onecold(y)) * 100 / size(y, ndims(y)) + +function accuracy(ŷ, y, topk::NTuple{N,<:Int}) where {N} + maxk = maximum(topk) + + pred_labels = partialsortperm.(eachcol(ŷ), (1:maxk,), rev=true) + true_labels = onecold(y) + + accuracies = Tuple(sum(map((a, b) -> sum(view(a, 1:k) .== b), pred_labels, true_labels)) for k in topk) + + return accuracies .* 100 ./ size(y, ndims(y)) +end + +# Distributed Utils +@inline is_distributed() = FluxMPI.Initialized() && total_workers() > 1 +@inline should_log() = !FluxMPI.Initialized() || local_rank() == 0 +@inline scaling_factor() = (is_distributed() ? total_workers() : 1) + +# Loss Function +@inline logitcrossentropy(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ; dims=1); dims=1)) +@inline mae(ŷ, y) = mean(abs, ŷ .- y) +@inline mse(ŷ, y) = mean(abs2, ŷ .- y) + +# DataLoaders doesn't yet work with MLUtils +MLDataPattern.nobs(x) = MLUtils.numobs(x) +MLDataPattern.getobs(d::MLUtils.ObsView, i::Int64) = MLUtils.getobs(d, i) diff --git a/src/FastDEQ.jl b/src/FastDEQ.jl index 8830f9e7..7384ecb7 100644 --- a/src/FastDEQ.jl +++ b/src/FastDEQ.jl @@ -1,48 +1,53 @@ module FastDEQ -using CUDA, DiffEqBase, DiffEqCallbacks, DiffEqSensitivity, Flux, FluxExperimental, LinearAlgebra, LinearSolve, - OrdinaryDiffEq, SciMLBase, Statistics, SteadyStateDiffEq, UnPack, Zygote - -abstract type AbstractDeepEquilibriumNetwork end - -function Base.show(io::IO, l::AbstractDeepEquilibriumNetwork) - return print(io, string(typeof(l).name.name), "() ", string(length(l.p)), " Trainable Parameters") -end - -Flux.trainable(d::AbstractDeepEquilibriumNetwork) = (d.p,) - -Base.deepcopy(op::DiffEqSensitivity.ZygotePullbackMultiplyOperator) = op +using ChainRulesCore, + ComponentArrays, + CUDA, + DiffEqBase, + DiffEqCallbacks, + DiffEqSensitivity, + Functors, + LinearAlgebra, + LinearSolve, + Lux, + MLUtils, + OrdinaryDiffEq, + SciMLBase, + Setfield, + Static, + Statistics, + SteadyStateDiffEq, + UnPack, + Zygote + +import DiffEqSensitivity: AbstractAdjointSensitivityAlgorithm +import Lux: AbstractExplicitContainerLayer, initialparameters, initialstates, parameterlength, statelength +import Random: AbstractRNG + +include("operator.jl") + +include("solvers/continuous.jl") +include("solvers/discrete.jl") +include("solvers/termination.jl") include("solve.jl") include("utils.jl") -include("solvers/broyden.jl") -include("solvers/limited_memory_broyden.jl") - -include("models/basics.jl") - +include("layers/core.jl") include("layers/jacobian_stabilization.jl") -include("layers/utils.jl") include("layers/deq.jl") -include("layers/sdeq.jl") include("layers/mdeq.jl") -include("layers/smdeq.jl") +include("layers/chain.jl") -include("models/chain.jl") - -include("losses.jl") +include("adjoint.jl") # DEQ Solvers export ContinuousDEQSolver, DiscreteDEQSolver, BroydenSolver, LimitedMemoryBroydenSolver # Utils -export NormalInitializer, SteadyStateAdjoint, get_and_clear_nfe!, compute_deq_jacobian_loss, DeepEquilibriumSolution, SupervisedLossContainer - -# Layers -export MultiParallelNet +export NormalInitializer, DeepEquilibriumAdjoint, compute_deq_jacobian_loss, DeepEquilibriumSolution -# DEQ Layers -export DeepEquilibriumNetwork, SkipDeepEquilibriumNetwork, MultiScaleDeepEquilibriumNetwork, MultiScaleSkipDeepEquilibriumNetwork -export DEQChain +export DeepEquilibriumNetwork, + SkipDeepEquilibriumNetwork, MultiScaleDeepEquilibriumNetwork, MultiScaleSkipDeepEquilibriumNetwork, DEQChain end diff --git a/src/adjoint.jl b/src/adjoint.jl new file mode 100644 index 00000000..2a1f9ac4 --- /dev/null +++ b/src/adjoint.jl @@ -0,0 +1,80 @@ +neg(x::Any) = hasmethod(-, (typeof(x),)) ? -x : x +neg(nt::NamedTuple) = fmap(neg, nt) + +@noinline function DiffEqSensitivity.SteadyStateAdjointProblem( + sol::EquilibriumSolution, sensealg::DeepEquilibriumAdjoint, g::Nothing, dg; save_idxs=nothing +) + @unpack f, p, u0 = sol.prob + + diffcache, y = DiffEqSensitivity.adjointdiffcache(g, sensealg, false, sol, dg, f; quad=false, needs_jac=false) + + _save_idxs = save_idxs === nothing ? Colon() : save_idxs + if dg !== nothing + if typeof(_save_idxs) <: Number + diffcache.dg_val[_save_idxs] = dg[_save_idxs] + elseif typeof(dg) <: Number + @. diffcache.dg_val[_save_idxs] = dg + else + @. diffcache.dg_val[_save_idxs] = dg[_save_idxs] + end + end + + if check_adjoint_mode(sensealg, Val(:vanilla)) + # Solve the Linear Problem + _val, back = Zygote.pullback(x -> f(x, p, nothing), y) + s_val = size(_val) + op = ZygotePullbackMultiplyOperator{eltype(y),typeof(back),typeof(s_val)}(back, s_val) + linear_problem = LinearProblem(op, vec(diffcache.dg_val)) + λ = solve(linear_problem, sensealg.linsolve).u + elseif check_adjoint_mode(sensealg, Val(:jfb)) + # Jacobian Free Backpropagation + λ = diffcache.dg_val + else + error("Unknown adjoint mode") + end + + # Compute the VJP + _, back = Zygote.pullback(p -> vec(f(y, p, nothing)), p) + dp = back(vec(λ))[1] + + return neg(dp) +end + +function DiffEqBase._concrete_solve_adjoint( + prob::SteadyStateProblem, alg, sensealg::DeepEquilibriumAdjoint, u0, p, args...; save_idxs=nothing, kwargs... +) + _prob = remake(prob; u0=u0, p=p) + sol = solve(_prob, alg, args...; kwargs...) + _save_idxs = save_idxs === nothing ? Colon() : save_idxs + + out = save_idxs === nothing ? sol : DiffEqBase.sensitivity_solution(sol, sol[_save_idxs]) + + function steadystatebackpass(Δ) + # Δ = dg/dx or diffcache.dg_val + # del g/del p = 0 + dp = adjoint_sensitivities(sol, alg; sensealg=sensealg, g=nothing, dg=Δ, save_idxs=save_idxs) + return ( + NoTangent(), + NoTangent(), + NoTangent(), + NoTangent(), + dp, + NoTangent(), + ntuple(_ -> NoTangent(), length(args))..., + ) + end + + return out, steadystatebackpass +end + +function DiffEqSensitivity._adjoint_sensitivities( + sol, sensealg::DeepEquilibriumAdjoint, alg, g, dg=nothing; abstol=1e-6, reltol=1e-3, kwargs... +) + return DiffEqSensitivity.SteadyStateAdjointProblem(sol, sensealg, g, dg; kwargs...) +end + +function DiffEqSensitivity._adjoint_sensitivities( + sol, sensealg::DeepEquilibriumAdjoint, alg; g=nothing, dg=nothing, abstol=1e-6, reltol=1e-3, kwargs... +) + return DiffEqSensitivity.SteadyStateAdjointProblem(sol, sensealg, g, dg; kwargs...) +end diff --git a/src/layers/chain.jl b/src/layers/chain.jl new file mode 100644 index 00000000..581d3a78 --- /dev/null +++ b/src/layers/chain.jl @@ -0,0 +1,48 @@ +""" + DEQChain(layers...) + +Sequence of layers divided into 3 chunks -- + +* `pre_deq` -- layers that are executed before DEQ is applied +* `deq` -- The Deep Equilibrium Layer +* `post_deq` -- layers that are executed after DEQ is applied + +Constraint: Must have one DEQ layer in `layers` +""" +struct DEQChain{P1,D,P2} <: AbstractExplicitContainerLayer{(:pre_deq, :deq, :post_deq)} + pre_deq::P1 + deq::D + post_deq::P2 +end + +function DEQChain(layers...) + pre_deq, post_deq, deq, encounter_deq = [], [], nothing, false + for l in layers + if l isa AbstractDeepEquilibriumNetwork || l isa AbstractSkipDeepEquilibriumNetwork + @assert !encounter_deq "Can have only 1 DEQ Layer in the Chain!!!" + deq = l + encounter_deq = true + continue + end + push!(encounter_deq ? post_deq : pre_deq, l) + end + @assert encounter_deq "No DEQ Layer in the Chain!!! Maybe you wanted to use Chain" + pre_deq = length(pre_deq) == 0 ? NoOpLayer() : Chain(pre_deq...) + post_deq = length(post_deq) == 0 ? NoOpLayer() : Chain(post_deq...) + return DEQChain(pre_deq, deq, post_deq) +end + +function get_deq_return_type( + deq::DEQChain{P1,<:Union{MultiScaleDeepEquilibriumNetwork,MultiScaleSkipDeepEquilibriumNetwork}}, ::T +) where {P1,T} + return NTuple{length(deq.deq.scales),T} +end +get_deq_return_type(::DEQChain, ::T) where {T} = T + +function (deq::DEQChain)(x, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple) + T = get_deq_return_type(deq, x) + x1, st1 = deq.pre_deq(x, ps.pre_deq, st.pre_deq) + (x2::T, deq_soln), st2 = deq.deq(x1, ps.deq, st.deq) + x3, st3 = deq.post_deq(x2, ps.post_deq, st.post_deq) + return (x3, deq_soln), (pre_deq=st1, deq=st2, post_deq=st3) +end diff --git a/src/layers/core.jl b/src/layers/core.jl new file mode 100644 index 00000000..13745da2 --- /dev/null +++ b/src/layers/core.jl @@ -0,0 +1,53 @@ +abstract type AbstractDeepEquilibriumNetwork <: AbstractExplicitContainerLayer{(:model,)} end + +function initialstates(rng::AbstractRNG, deq::AbstractDeepEquilibriumNetwork) + return (model=initialstates(rng, deq.model), fixed_depth=Val(0)) +end + +abstract type AbstractSkipDeepEquilibriumNetwork <: AbstractExplicitContainerLayer{(:model,:shortcut)} end + +function initialstates(rng::AbstractRNG, deq::AbstractSkipDeepEquilibriumNetwork) + return ( + model=initialstates(rng, deq.model), shortcut=initialstates(rng, deq.shortcut), fixed_depth=Val(0) + ) +end + +@inline check_unrolled_mode(::Val{0})::Bool = false +@inline check_unrolled_mode(::Val{d}) where {d} = (d >= 1)::Bool +@inline check_unrolled_mode(st::NamedTuple)::Bool = check_unrolled_mode(st.fixed_depth) +@inline get_unrolled_depth(::Val{d}) where {d} = d::Int +@inline get_unrolled_depth(st::NamedTuple)::Int = get_unrolled_depth(st.fixed_depth) + +ChainRulesCore.@non_differentiable check_unrolled_mode(::Any) +ChainRulesCore.@non_differentiable get_unrolled_depth(::Any) + +""" + DeepEquilibriumSolution(z_star, u₀, residual, jacobian_loss, nfe) + +Stores the solution of a DeepEquilibriumNetwork and its variants. + +## Fields + * `z_star`: Steady-State or the value reached due to maxiters + * `u₀`: Initial Condition + * `residual`: Difference of the ``z^*`` and ``f(z^*, x)`` + * `jacobian_loss`: Jacobian Stabilization Loss (see individual networks to see how it can be computed) + * `nfe`: Number of Function Evaluations +""" +struct DeepEquilibriumSolution{T,R<:AbstractFloat} + z_star::T + u₀::T + residual::T + jacobian_loss::R + nfe::Int +end + +function Base.show(io::IO, l::DeepEquilibriumSolution) + print(io, "DeepEquilibriumSolution(") + print(io, "z_star: ", l.z_star) + print(io, ", initial_condition: ", l.u₀) + print(io, ", residual: ", l.residual) + print(io, ", jacobian_loss: ", l.jacobian_loss) + print(io, ", NFE: ", l.nfe) + print(io, ")") + return nothing +end \ No newline at end of file diff --git a/src/layers/deq.jl b/src/layers/deq.jl index 931c19ef..dfaceec0 100644 --- a/src/layers/deq.jl +++ b/src/layers/deq.jl @@ -1,17 +1,14 @@ """ - DeepEquilibriumNetwork(model, solver; jacobian_regularization::Bool=false, - p=nothing, sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), - kwargs...) + DeepEquilibriumNetwork(model, solver; jacobian_regularization::Bool=false, sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), kwargs...) Deep Equilibrium Network as proposed in [baideep2019](@cite) ## Arguments -* `model`: Explicit Neural Network which takes 2 inputs +* `model`: Neural Network * `solver`: Solver for the optimization problem (See: [`ContinuousDEQSolver`](@ref) & [`DiscreteDEQSolver`](@ref)) * `jacobian_regularization`: If true, Jacobian Loss is computed and stored in the [`DeepEquilibriumSolution`](@ref) -* `p`: Optional parameters for the `model` -* `sensealg`: See [`SteadyStateAdjoint`](@ref) +* `sensealg`: See [`DeepEquilibriumAdjoint`](@ref) * `kwargs`: Additional Parameters that are directly passed to `solve` ## Example @@ -26,56 +23,183 @@ model = DeepEquilibriumNetwork( ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0) ) -model(rand(Float32, 2, 1)) +rng = Random.default_rng() +ps, st = Lux.setup(rng, model) + +model(rand(Float32, 2, 1), ps, st) ``` See also: [`SkipDeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwork`](@ref), [`MultiScaleSkipDeepEquilibriumNetwork`](@ref) """ -struct DeepEquilibriumNetwork{J,M,P,RE,A,S,K} <: AbstractDeepEquilibriumNetwork - jacobian_regularization::Bool +struct DeepEquilibriumNetwork{J,M,A,S,K} <: AbstractDeepEquilibriumNetwork model::M - p::P - re::RE solver::A - kwargs::K sensealg::S - stats::DEQTrainingStats + kwargs::K +end + +function DeepEquilibriumNetwork( + model, solver; jacobian_regularization::Bool=false, sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), kwargs... +) + return DeepEquilibriumNetwork{jacobian_regularization,typeof(model),typeof(solver),typeof(sensealg),typeof(kwargs)}( + model, solver, sensealg, kwargs + ) +end + +function (deq::DeepEquilibriumNetwork{J})( + x::AbstractArray{T}, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple +) where {J,T} + z = zero(x) + + if check_unrolled_mode(st) + # Pretraining without Fixed Point Solving + st_ = st.model + z_star = z + for _ in 1:get_unrolled_depth(st) + z_star, st_ = deq.model((z_star, x), ps, st_) + end + + residual = ignore_derivatives(z_star .- deq.model((z_star, x), ps, st.model)[1]) + st = merge(st, (model=st_,)) + + return (z_star, DeepEquilibriumSolution(z_star, z, residual, 0.0f0, get_unrolled_depth(st))), st + end - function DeepEquilibriumNetwork(jacobian_regularization, model, p, re, solver, kwargs, sensealg, stats) - _p, re = destructure_parameters(model) - p = p === nothing ? _p : convert(typeof(_p), p) + st_ = st.model - return new{jacobian_regularization,typeof(model),typeof(p),typeof(re),typeof(solver), - typeof(sensealg),typeof(kwargs)}(jacobian_regularization, model, p, re, - solver, kwargs, sensealg, stats) + function dudt(u, p, t) + u_, st_ = deq.model((u, x), p, st_) + return u_ .- u end + + prob = SteadyStateProblem(ODEFunction{false}(dudt), z, ps) + sol = solve(prob, deq.solver; sensealg=deq.sensealg, deq.kwargs...) + z_star, st_ = deq.model((sol.u, x), ps, st.model) + + jac_loss = (J ? compute_deq_jacobian_loss(deq.model, ps, st.model, z_star, x) : T(0)) + residual = ignore_derivatives(z_star .- deq.model((z_star, x), ps, st.model)[1]) + + st = merge(st, (model=st_,)) + + return (z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, sol.destats.nf + 1 + J)), st end -Flux.@functor DeepEquilibriumNetwork -function Base.show(io::IO, l::DeepEquilibriumNetwork{J}) where {J} - return print(io, "DeepEquilibriumNetwork(jacobian_regularization = $J) ", - string(length(l.p)), " Trainable Parameters") +""" + SkipDeepEquilibriumNetwork(model, shortcut, solver; jacobian_regularization::Bool=false, sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), kwargs...) + +Skip Deep Equilibrium Network as proposed in [pal2022mixing](@cite) + +## Arguments + +* `model`: Neural Network +* `shortcut`: Shortcut for the network (pass `nothing` for SkipDEQV2) +* `solver`: Solver for the optimization problem (See: [`ContinuousDEQSolver`](@ref) & [`DiscreteDEQSolver`](@ref)) +* `jacobian_regularization`: If true, Jacobian Loss is computed and stored in the [`DeepEquilibriumSolution`](@ref) +* `sensealg`: See [`DeepEquilibriumAdjoint`](@ref) +* `kwargs`: Additional Parameters that are directly passed to `solve` + +## Example + +```julia +# SkipDEQ +model = SkipDeepEquilibriumNetwork( + Parallel( + +, + Dense(2, 2; bias=false), + Dense(2, 2; bias=false) + ), + Dense(2, 2), + ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0) +) + +rng = Random.default_rng() +ps, st = Lux.setup(rng, model) + +model(rand(Float32, 2, 1), ps, st) + +# SkipDEQV2 +model = SkipDeepEquilibriumNetwork( + Parallel( + +, + Dense(2, 2; bias=false), + Dense(2, 2; bias=false) + ), + nothing, + ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0) +) + +rng = Random.default_rng() +ps, st = Lux.setup(rng, model) + +model(rand(Float32, 2, 1), ps, st) +``` + +See also: [`DeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwork`](@ref), [`MultiScaleSkipDeepEquilibriumNetwork`](@ref) +""" +struct SkipDeepEquilibriumNetwork{J,M,Sh,A,S,K} <: AbstractSkipDeepEquilibriumNetwork + model::M + shortcut::Sh + solver::A + sensealg::S + kwargs::K end -function DeepEquilibriumNetwork(model, solver; jacobian_regularization::Bool=false, - p=nothing, sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), kwargs...) - return DeepEquilibriumNetwork(jacobian_regularization, model, p, nothing, solver, - kwargs, sensealg, DEQTrainingStats(0)) +function SkipDeepEquilibriumNetwork( + model, + shortcut, + solver; + jacobian_regularization::Bool=false, + sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), + kwargs..., +) + return SkipDeepEquilibriumNetwork{ + jacobian_regularization,typeof(model),typeof(shortcut),typeof(solver),typeof(sensealg),typeof(kwargs) + }( + model, shortcut, solver, sensealg, kwargs + ) end -function (deq::DeepEquilibriumNetwork)(x::AbstractArray{T}) where {T} - z = zero(x) - Zygote.@ignore deq.re(deq.p)(z, x) +function (deq::SkipDeepEquilibriumNetwork{J,M,S})( + x::AbstractArray{T}, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple +) where {J,M,S,T} + z, st = if S == Nothing + z__, st__ = deq.model((zero(x), x), ps.model, st.model) + z__, merge(st, (model=st__,)) + else + z__, st__ = deq.shortcut(x, ps.shortcut, st.shortcut) + z__, merge(st, (shortcut=st__,)) + end + + if check_unrolled_mode(st) + # Pretraining without Fixed Point Solving + st_ = st.model + z_star = z + for _ in 1:get_unrolled_depth(st) + z_star, st_ = deq.model((z_star, x), ps.model, st_) + end + + residual = ignore_derivatives(z_star .- deq.model((z_star, x), ps.model, st.model)[1]) + st = merge(st, (model=st_,)) - current_nfe = deq.stats.nfe + return (z_star, DeepEquilibriumSolution(z_star, z, residual, 0.0f0, get_unrolled_depth(st))), st + end + + st_ = st.model + + function dudt(u, p, t) + u_, st_ = deq.model((u, x), p, st_) + return u_ .- u + end - z_star = solve_steady_state_problem(deq.re, deq.p, x, z, deq.sensealg, deq.solver; dudt=nothing, - update_nfe=() -> (deq.stats.nfe += 1), deq.kwargs...) + prob = SteadyStateProblem(ODEFunction{false}(dudt), z, ps.model) + sol = solve(prob, deq.solver; sensealg=deq.sensealg, deq.kwargs...) + z_star, st_ = deq.model((sol.u, x), ps.model, st.model) - jac_loss = (deq.jacobian_regularization ? compute_deq_jacobian_loss(deq.re, deq.p, z_star, x) : T(0))::T + jac_loss = (J ? compute_deq_jacobian_loss(deq.model, ps.model, st.model, z_star, x) : T(0)) + residual = ignore_derivatives(z_star .- deq.model((z_star, x), ps.model, st.model)[1]) - residual = Zygote.@ignore z_star .- deq.re(deq.p)(z_star, x) + st = merge(st, (model=st_,)) - return z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, deq.stats.nfe - current_nfe) + return (z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, sol.destats.nf + 1 + J)), st end diff --git a/src/layers/jacobian_stabilization.jl b/src/layers/jacobian_stabilization.jl index ae6bbe89..dc2439c0 100644 --- a/src/layers/jacobian_stabilization.jl +++ b/src/layers/jacobian_stabilization.jl @@ -1,32 +1,8 @@ -gaussian_like(p::Array) = randn(eltype(p), size(p)) -gaussian_like(p::CuArray) = CUDA.randn(eltype(p), size(p)) - -Zygote.@nograd gaussian_like - -""" - compute_deq_jacobian_loss(re, p, z, x) - -Computes Jacobian Stabilization Loss ([bai2021stabilizing](@cite)). - -## Arguments - -* `re`: Constructs the model given the parameters `p`. -* `p`: Parameters of the model. -* `z`: Steady State. -* `x`: Input to the model. - -## Current Known Failure Modes - -1. Conv layers error out due to ForwardDiff on GPUs -2. If the model internally uses destructure/restructure eg. `WeightNorm` Layer, then this loss function will error out in the backward pass. -""" -function compute_deq_jacobian_loss(re, p::AbstractVector{T}, z::A, x::A) where {T,A<:AbstractArray} - d = length(z) - v = gaussian_like(z) - model = re(p) - - _, back = Zygote.pullback(model, z, x) - vjp_z, vjp_x = back(v) - # NOTE: This weird sum(zero, ...) ensures that we get zeros instead of nothings - return sum(abs2, vjp_z) / d + sum(zero, vjp_x) +# Doesn't work as of now +function compute_deq_jacobian_loss( + model, ps::ComponentArray, st::NamedTuple, z::AbstractArray, x::AbstractArray +) + l, back = Zygote.pullback(u -> model((u, x), ps, st)[1], z) + vjp_z = back(gaussian_like(l))[1] + return sum(abs2, vjp_z) / length(z) end diff --git a/src/layers/mdeq.jl b/src/layers/mdeq.jl index d90227ab..f6f9081a 100644 --- a/src/layers/mdeq.jl +++ b/src/layers/mdeq.jl @@ -1,19 +1,25 @@ +@generated function evaluate_unrolled_mdeq(model, z_star::NTuple{N}, x, ps, st, ::Val{depth}) where {N,depth} + calls = [] + for _ in 1:depth + push!(calls, :((z_star, st) = model(((z_star[1], x), z_star[2:($N)]...), ps, st))) + end + push!(calls, :(return z_star, st)) + return Expr(:block, calls...) +end + """ - MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix, solver; - post_fuse_layers::Union{Tuple,Nothing}=nothing, p=nothing, - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), kwargs...) + MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix, post_fuse_layer::Union{Nothing,Tuple}, solver, scales; sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), kwargs...) Multiscale Deep Equilibrium Network as proposed in [baimultiscale2020](@cite) ## Arguments -* `main_layers`: Tuple of Explicit Neural Networks. The first network needs to take 2 inputs, the other ones only take 1 input -* `mapping_layers`: Matrix of Explicit Neural Networks. The ``(i, j)^{th}`` network takes the output of ``i^{th}`` `main_layer` - and passes it to the ``j^{th}`` `main_layer` +* `main_layers`: Tuple of Neural Networks. The first network needs to take a tuple of 2 arrays, the other ones only take 1 input +* `mapping_layers`: Matrix of Neural Networks. The ``(i, j)^{th}`` network takes the output of ``i^{th}`` `main_layer` and passes it to the ``j^{th}`` `main_layer` +* `post_fuse_layer`: Tuple of Neural Networks. Each of the scales are passed through this layer * `solver`: Solver for the optimization problem (See: [`ContinuousDEQSolver`](@ref) & [`DiscreteDEQSolver`](@ref)) -* `post_fuse_layers`: Tuple of Explicit Neural Networks. Applied after the `mapping_layers` (Default: `nothing`) -* `p`: Optional parameters for the `model` -* `sensealg`: See [`SteadyStateAdjoint`](@ref) +* `scales`: Output scales +* `sensealg`: See [`DeepEquilibriumAdjoint`](@ref) * `kwargs`: Additional Parameters that are directly passed to `solve` ## Example @@ -21,111 +27,289 @@ Multiscale Deep Equilibrium Network as proposed in [baimultiscale2020](@cite) ```julia model = MultiScaleDeepEquilibriumNetwork( ( - Parallel(+, Dense(4, 4, tanh_fast), Dense(4, 4, tanh_fast)), - Dense(3, 3, tanh_fast), Dense(2, 2, tanh_fast), - Dense(1, 1, tanh_fast) + Parallel(+, Dense(4, 4, tanh), Dense(4, 4, tanh)), + Dense(3, 3, tanh), + Dense(2, 2, tanh), + Dense(1, 1, tanh) ), [ - NoOpLayer() Dense(4, 3, tanh_fast) Dense(4, 2, tanh_fast) Dense(4, 1, tanh_fast); - Dense(3, 4, tanh_fast) NoOpLayer() Dense(3, 2, tanh_fast) Dense(3, 1, tanh_fast); - Dense(2, 4, tanh_fast) Dense(2, 3, tanh_fast) NoOpLayer() Dense(2, 1, tanh_fast); - Dense(1, 4, tanh_fast) Dense(1, 3, tanh_fast) Dense(1, 2, tanh_fast) NoOpLayer() + NoOpLayer() Dense(4, 3, tanh) Dense(4, 2, tanh) Dense(4, 1, tanh); + Dense(3, 4, tanh) NoOpLayer() Dense(3, 2, tanh) Dense(3, 1, tanh); + Dense(2, 4, tanh) Dense(2, 3, tanh) NoOpLayer() Dense(2, 1, tanh); + Dense(1, 4, tanh) Dense(1, 3, tanh) Dense(1, 2, tanh) NoOpLayer() ], + nothing, ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0), + ((4,), (3,), (2,), (1,)), ) -model(rand(Float32, 4, 1)) +rng = Random.default_rng() +ps, st = Lux.setup(rng, model) +x = rand(rng, Float32, 4, 1) + +model(x, ps, st) ``` See also: [`DeepEquilibriumNetwork`](@ref), [`SkipDeepEquilibriumNetwork`](@ref), [`MultiScaleSkipDeepEquilibriumNetwork`](@ref) """ -struct MultiScaleDeepEquilibriumNetwork{N,M1<:Parallel,M2<:Union{Chain,FChain},RE1,RE2,P,A,K,S} <: - AbstractDeepEquilibriumNetwork - main_layers::M1 - mapping_layers::M2 - main_layers_re::RE1 - mapping_layers_re::RE2 - p::P - ordered_split_idxs::NTuple{N,Int} +struct MultiScaleDeepEquilibriumNetwork{N,Sc,M,A,S,K} <: AbstractDeepEquilibriumNetwork + model::M solver::A - kwargs::K sensealg::S - stats::DEQTrainingStats + scales::Sc + kwargs::K +end + +function initialstates(rng::AbstractRNG, deq::MultiScaleDeepEquilibriumNetwork) + return ( + model=initialstates(rng, deq.model), + split_idxs=static(Tuple(vcat(0, cumsum(prod.(deq.scales))...))), + fixed_depth=Val(0), + initial_condition=zeros(Float32, 1, 1), + ) +end + +function MultiScaleDeepEquilibriumNetwork( + main_layers::Tuple, + mapping_layers::Matrix, + post_fuse_layer::Union{Nothing,Tuple}, + solver, + scales::NTuple{N,NTuple{L,Int64}}; + sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), + kwargs..., +) where {N,L} + l1 = Parallel(nothing, main_layers...) + l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...) + model = post_fuse_layer === nothing ? Chain(l1, l2) : Chain(l1, l2, Parallel(nothing, post_fuse_layer...)) + scales = static(scales) + return MultiScaleDeepEquilibriumNetwork{ + N,typeof(scales),typeof(model),typeof(solver),typeof(sensealg),typeof(kwargs) + }( + model, solver, sensealg, scales, kwargs + ) +end + +@generated function get_initial_condition_mdeq(::S, x::AbstractArray{T,N}, st::NamedTuple{fields}) where {S,T,N,fields} + scales = known(S) + sz = sum(prod.(scales)) + calls = [] + if :initial_condition ∈ fields + push!(calls, :(u0 = st[:initial_condition])) + push!(calls, :(($sz, size(x, $N)) == size(u0) && return u0, st)) + end + push!(calls, :(u0 = fill!(similar(x, $(sz), size(x, N)), $(T(0))))) + push!(calls, :(st = merge(st, (initial_condition=u0,))::typeof(st))) + push!(calls, :(return u0, st)) + return Expr(:block, calls...) +end - function MultiScaleDeepEquilibriumNetwork(main_layers::Parallel, mapping_layers::Union{Chain,FChain}, re1, re2, - p, ordered_split_idxs, solver::A, kwargs::K, sensealg::S, stats) where {A,K,S} - @assert length(mapping_layers) == 2 - @assert mapping_layers[1] isa MultiParallelNet +ChainRulesCore.@non_differentiable get_initial_condition_mdeq(::Any...) - p_main_layers, re_main_layers = destructure_parameters(main_layers) - p_mapping_layers, re_mapping_layers = destructure_parameters(mapping_layers) +function (deq::MultiScaleDeepEquilibriumNetwork{N})( + x::AbstractArray{T}, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple +) where {N,T} + z, st = get_initial_condition_mdeq(deq.scales, x, st) - ordered_split_idxs = tuple(cumsum([0, length(p_main_layers), length(p_mapping_layers)])...) + if check_unrolled_mode(st) + z_star = split_and_reshape(z, st.split_idxs, deq.scales) + z_star, st_ = evaluate_unrolled_mdeq(deq.model, z_star, x, ps, st.model, st.fixed_depth) - p = p === nothing ? vcat(p_main_layers, p_mapping_layers) : convert(typeof(p_main_layers), p) + residual = ignore_derivatives( + vcat(flatten.(z_star)...) .- + vcat(flatten.(evaluate_unrolled_mdeq(deq.model, z_star, x, ps, st_, Val(1))[1])...), + ) + st__ = merge(st, (model=st_,)) - return new{length(ordered_split_idxs), - typeof.((main_layers, mapping_layers, re_main_layers, re_mapping_layers, p))..., - A,K,S}(main_layers, mapping_layers, re_main_layers, re_mapping_layers, p, ordered_split_idxs, - solver, kwargs, sensealg, stats) + return ( + (z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, get_unrolled_depth(st))), + st__, + ) end -end -Flux.@functor MultiScaleDeepEquilibriumNetwork + st_ = st.model -function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix, solver; - post_fuse_layers::Union{Tuple,Nothing}=nothing, p=nothing, - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), kwargs...) - mapping_layers = if post_fuse_layers === nothing - @assert size(mapping_layers, 1) == size(mapping_layers, 2) == length(main_layers) - FChain(MultiParallelNet(Parallel.(+, map(x -> tuple(x...), eachcol(mapping_layers)))...), NoOpLayer()) - else - @assert size(mapping_layers, 1) == size(mapping_layers, 2) == length(main_layers) == length(post_fuse_layers) - FChain(MultiParallelNet(Parallel.(+, map(x -> tuple(x...), eachcol(mapping_layers)))...), - Parallel(flatten_merge, post_fuse_layers...)) + function dudt_(u, p, t) + u_split = split_and_reshape(u, st.split_idxs, deq.scales) + u_, st_ = deq.model(((u_split[1], x), u_split[2:N]...), p, st_) + return u_, st_ end - main_layers = Parallel(flatten_merge, main_layers...) + dudt(u, p, t) = vcat(flatten.(dudt_(u, p, t)[1])...) .- u + + prob = SteadyStateProblem(ODEFunction{false}(dudt), z, ps) + sol = solve(prob, deq.solver; sensealg=deq.sensealg, deq.kwargs...) + z_star, st_ = dudt_(sol.u, ps, nothing) + + residual = ignore_derivatives(dudt(sol.u, ps, nothing)) + + st__ = merge(st, (model=st_,)) + + return ((z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, sol.destats.nf + 1)), st__) +end + +""" + MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix, post_fuse_layer::Union{Nothing,Tuple}, shortcut_layers::Union{Nothing,Tuple}, solver, scales; sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), kwargs...) + +Multiscale Deep Equilibrium Network as proposed in [baimultiscale2020](@cite) combined with Skip Deep Equilibrium Network as proposed in [pal2022mixing](@cite) + +## Arguments + +* `main_layers`: Tuple of Neural Networks. The first network needs to take a tuple of 2 arrays, the other ones only take 1 input +* `mapping_layers`: Matrix of Neural Networks. The ``(i, j)^{th}`` network takes the output of ``i^{th}`` `main_layer` and passes it to the ``j^{th}`` `main_layer` +* `post_fuse_layer`: Tuple of Neural Networks. Each of the scales are passed through this layer +* `shortcut_layers`: Shortcut for the network (pass `nothing` for SkipDEQV2) +* `solver`: Solver for the optimization problem (See: [`ContinuousDEQSolver`](@ref) & [`DiscreteDEQSolver`](@ref)) +* `scales`: Output scales +* `sensealg`: See [`DeepEquilibriumAdjoint`](@ref) +* `kwargs`: Additional Parameters that are directly passed to `solve` + +## Example + +```julia +# MSkipDEQ +model = MultiScaleSkipDeepEquilibriumNetwork( + ( + Parallel(+, Dense(4, 4, tanh), Dense(4, 4, tanh)), + Dense(3, 3, tanh), + Dense(2, 2, tanh), + Dense(1, 1, tanh), + ), + [ + NoOpLayer() Dense(4, 3, tanh) Dense(4, 2, tanh) Dense(4, 1, tanh) + Dense(3, 4, tanh) NoOpLayer() Dense(3, 2, tanh) Dense(3, 1, tanh) + Dense(2, 4, tanh) Dense(2, 3, tanh) NoOpLayer() Dense(2, 1, tanh) + Dense(1, 4, tanh) Dense(1, 3, tanh) Dense(1, 2, tanh) NoOpLayer() + ], + nothing, + (Dense(4, 4, tanh), Dense(4, 3, tanh), Dense(4, 2, tanh), Dense(4, 1, tanh)), + ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0), + ((4,), (3,), (2,), (1,)); + sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), +) + +rng = Random.default_rng() +ps, st = Lux.setup(rng, model) +x = rand(rng, Float32, 4, 2) + +model(x, ps, st) + +# MSkipDEQV2 +model = MultiScaleSkipDeepEquilibriumNetwork( + ( + Parallel(+, Dense(4, 4, tanh), Dense(4, 4, tanh)), + Dense(3, 3, tanh), + Dense(2, 2, tanh), + Dense(1, 1, tanh), + ), + [ + NoOpLayer() Dense(4, 3, tanh) Dense(4, 2, tanh) Dense(4, 1, tanh) + Dense(3, 4, tanh) NoOpLayer() Dense(3, 2, tanh) Dense(3, 1, tanh) + Dense(2, 4, tanh) Dense(2, 3, tanh) NoOpLayer() Dense(2, 1, tanh) + Dense(1, 4, tanh) Dense(1, 3, tanh) Dense(1, 2, tanh) NoOpLayer() + ], + nothing, + nothing, + ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0), + ((4,), (3,), (2,), (1,)); + sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), + ) - return MultiScaleDeepEquilibriumNetwork(main_layers, mapping_layers, nothing, nothing, p, - nothing, solver, kwargs, sensealg, DEQTrainingStats(0)) +rng = Random.default_rng() +ps, st = Lux.setup(rng, model) +x = rand(rng, Float32, 4, 2) + +model(x, ps, st) +``` + +See also: [`DeepEquilibriumNetwork`](@ref), [`SkipDeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwork`](@ref) +""" +struct MultiScaleSkipDeepEquilibriumNetwork{N,Sc,M,Sh,A,S,K} <: AbstractSkipDeepEquilibriumNetwork + model::M + shortcut::Sh + solver::A + sensealg::S + scales::Sc + kwargs::K end -function (mdeq::MultiScaleDeepEquilibriumNetwork)(x::AbstractArray{T}) where {T} - current_nfe = mdeq.stats.nfe +function initialstates(rng::AbstractRNG, deq::MultiScaleSkipDeepEquilibriumNetwork) + return ( + model=initialstates(rng, deq.model), + shortcut=initialstates(rng, deq.shortcut), + split_idxs=static(Tuple(vcat(0, cumsum(prod.(deq.scales))...))), + fixed_depth=Val(0), + initial_condition=zeros(Float32, 1, 1), + ) +end - z = zero(x) - initial_conditions = Zygote.@ignore map(l -> l(z), map(l -> l.layers[1], mdeq.mapping_layers[1].layers)) - u_sizes = Zygote.@ignore size.(initial_conditions) - u_split_idxs = Zygote.@ignore vcat(0, cumsum(length.(initial_conditions) .÷ size(x, ndims(x)))...) - u0 = Zygote.@ignore vcat(Flux.flatten.(initial_conditions)...) +function MultiScaleSkipDeepEquilibriumNetwork( + main_layers::Tuple, + mapping_layers::Matrix, + post_fuse_layer::Union{Nothing,Tuple}, + shortcut_layers::Union{Nothing,Tuple}, + solver, + scales; + sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), + kwargs..., +) + l1 = Parallel(nothing, main_layers...) + l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...) + model = post_fuse_layer === nothing ? Chain(l1, l2) : Chain(l1, l2, Parallel(nothing, post_fuse_layer...)) + shortcut = shortcut_layers === nothing ? nothing : Parallel(nothing, shortcut_layers...) + scales = static(scales) + return MultiScaleSkipDeepEquilibriumNetwork{ + length(scales),typeof(scales),typeof(model),typeof(shortcut),typeof(solver),typeof(sensealg),typeof(kwargs) + }( + model, shortcut, solver, sensealg, scales, kwargs + ) +end - N = length(u_sizes) - update_is_variational_hidden_dropout_mask_reset_allowed(false) +function (deq::MultiScaleSkipDeepEquilibriumNetwork{N,Sc,M,Sh})( + x::AbstractArray{T}, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple +) where {N,Sc,M,Sh,T} + z, st = if Sh == Nothing + u0, st_ = get_initial_condition_mdeq(deq.scales, x, st) + u0_ = split_and_reshape(u0, st.split_idxs, deq.scales) + z0, st__ = deq.model(((u0_[1], x), u0_[2:N]...), ps.model, st_.model) + (vcat(flatten.(z0)...), merge(st_, (model=st__,))) + else + z0, st_ = deq.shortcut(x, ps.shortcut, st.shortcut) + (vcat(flatten.(z0)...), merge(st, (shortcut=st_,))) + end - function dudt_(u, _p) - mdeq.stats.nfe += 1 + if check_unrolled_mode(st) + z_star = split_and_reshape(z, st.split_idxs, deq.scales) + z_star, st_ = evaluate_unrolled_mdeq(deq.model, z_star, x, ps.model, st.model, st.fixed_depth) - uₛ = split_array_by_indices(u, u_split_idxs) - p1, p2 = split_array_by_indices(_p, mdeq.ordered_split_idxs) + residual = ignore_derivatives( + vcat(flatten.(z_star)...) .- + vcat(flatten.(evaluate_unrolled_mdeq(deq.model, z_star, x, ps.model, st_, Val(1))[1])...), + ) + st__ = merge(st, (model=st_,)) - u_reshaped = ntuple(i -> reshape(uₛ[i], u_sizes[i]), N) + return ( + (z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, get_unrolled_depth(st))), + st__, + ) + end - main_layers_output = mdeq.main_layers_re(p1)((u_reshaped[1], x), u_reshaped[2:end]...) + st_ = st.model - return mdeq.mapping_layers_re(p2)(main_layers_output) + function dudt_(u, p, t) + u_split = split_and_reshape(u, st.split_idxs, deq.scales) + u_, st_ = deq.model(((u_split[1], x), u_split[2:N]...), p, st_) + return u_, st_ end - dudt(u, _p, t) = vcat(Flux.flatten.(dudt_(u, _p))...) .- u + dudt(u, p, t) = vcat(flatten.(dudt_(u, p, t)[1])...) .- u - ssprob = SteadyStateProblem(dudt, u0, mdeq.p) - res = solve(ssprob, mdeq.solver; u0=u0, sensealg=mdeq.sensealg, mdeq.kwargs...).u + prob = SteadyStateProblem(ODEFunction{false}(dudt), z, ps.model) + sol = solve(prob, deq.solver; sensealg=deq.sensealg, deq.kwargs...) + z_star, st_ = dudt_(sol.u, ps.model, nothing) - x_ = dudt_(res, mdeq.p) + residual = ignore_derivatives(dudt(sol.u, ps.model, nothing)) - residual = Zygote.@ignore Tuple(map((iu) -> reshape(iu[2], u_sizes[iu[1]]), - enumerate(split_array_by_indices(dudt(res, mdeq.p, nothing), u_split_idxs)))) - update_is_variational_hidden_dropout_mask_reset_allowed(true) + st__ = merge(st, (model=st_,)) - return x_, DeepEquilibriumSolution(x_, initial_conditions, residual, T(0), mdeq.stats.nfe - current_nfe) + return ((z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, sol.destats.nf + 1)), st) end diff --git a/src/layers/sdeq.jl b/src/layers/sdeq.jl deleted file mode 100644 index 06728a89..00000000 --- a/src/layers/sdeq.jl +++ /dev/null @@ -1,130 +0,0 @@ -""" - SkipDeepEquilibriumNetwork(model, shortcut, solver; p=nothing, jacobian_regularization::Bool=false, - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), kwargs...) - SkipDeepEquilibriumNetwork(model, solver; p=nothing, jacobian_regularization::Bool=false, - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), kwargs...) - -Skip Deep Equilibrium Network as proposed in [pal2022mixing](@cite) - -## Arguments - -* `model`: Explicit Neural Network which takes 2 inputs -* `shortcut`: Shortcut for the network (If not given, then we create SkipDEQV2) -* `solver`: Solver for the optimization problem (See: [`ContinuousDEQSolver`](@ref) & [`DiscreteDEQSolver`](@ref)) -* `jacobian_regularization`: If true, Jacobian Loss is computed and stored in the [`DeepEquilibriumSolution`](@ref) -* `p`: Optional parameters for the `model` -* `sensealg`: See [`SteadyStateAdjoint`](@ref) -* `kwargs`: Additional Parameters that are directly passed to `solve` - -## Example - -```julia -# SkipDEQ -model = SkipDeepEquilibriumNetwork( - Parallel( - +, - Dense(2, 2; bias=false), - Dense(2, 2; bias=false) - ), - Dense(2, 2), - ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0) -) - -model(rand(Float32, 2, 1)) - -# SkipDEQV2 -model = SkipDeepEquilibriumNetwork( - Parallel( - +, - Dense(2, 2; bias=false), - Dense(2, 2; bias=false) - ), - ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0) -) - -model(rand(Float32, 2, 1)) -``` - -See also: [`DeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwork`](@ref), [`MultiScaleSkipDeepEquilibriumNetwork`](@ref) -""" -struct SkipDeepEquilibriumNetwork{M,S,J,P,RE1,RE2,A,Se,K} <: AbstractDeepEquilibriumNetwork - jacobian_regularization::Bool - model::M - shortcut::S - p::P - re1::RE1 - re2::RE2 - split_idx::Int - solver::A - kwargs::K - sensealg::Se - stats::DEQTrainingStats - - function SkipDeepEquilibriumNetwork(jacobian_regularization, model, shortcut, p, re1, re2, - split_idx, solver, kwargs, sensealg, stats) - p1, re1 = destructure_parameters(model) - split_idx = length(p1) - p2, re2 = shortcut === nothing ? ((eltype(p1))[], nothing) : destructure_parameters(shortcut) - - p = p === nothing ? vcat(p1, p2) : eltype(p1).(p) - - return new{typeof(model),typeof(shortcut),jacobian_regularization,typeof(p),typeof(re1), - typeof(re2),typeof(solver),typeof(sensealg),typeof(kwargs)}(jacobian_regularization, model, shortcut, p, - re1, re2, split_idx, solver, kwargs, - sensealg, stats) - end -end - -Flux.@functor SkipDeepEquilibriumNetwork - -function Base.show(io::IO, l::SkipDeepEquilibriumNetwork{M,S,J}) where {M,S,J} - shortcut_ps = l.split_idx == length(l.p) ? 0 : length(l.p) - l.split_idx - return print(io, "SkipDeepEquilibriumNetwork(jacobian_regularization = $J, ", - "shortcut_parameter_count = $shortcut_ps) ", string(length(l.p)), " Trainable Parameters") -end - -function SkipDeepEquilibriumNetwork(model, shortcut, solver; p=nothing, jacobian_regularization::Bool=false, - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), kwargs...) - return SkipDeepEquilibriumNetwork(jacobian_regularization, model, shortcut, p, nothing, - nothing, 0, solver, kwargs, sensealg, DEQTrainingStats(0)) -end - -function SkipDeepEquilibriumNetwork(model, solver; p=nothing, jacobian_regularization::Bool=false, - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), kwargs...) - return SkipDeepEquilibriumNetwork(jacobian_regularization, model, nothing, p, nothing, - nothing, 0, solver, kwargs, sensealg, DEQTrainingStats(0)) -end - -function (deq::SkipDeepEquilibriumNetwork)(x::AbstractArray{T}) where {T} - p1, p2 = deq.p[1:(deq.split_idx)], deq.p[(deq.split_idx + 1):end] - z = deq.re2(p2)(x)::typeof(x) - - current_nfe = deq.stats.nfe - - # Dummy call to ensure that mask is generated - Zygote.@ignore _ = deq.re1(p1)(z, x) - - z_star = solve_steady_state_problem(deq.re1, p1, x, z, deq.sensealg, deq.solver; dudt=nothing, - update_nfe=() -> (deq.stats.nfe += 1), deq.kwargs...) - - jac_loss = (deq.jacobian_regularization ? compute_deq_jacobian_loss(deq.re1, p1, z_star, x) : T(0)) ::T - - residual = Zygote.@ignore z_star .- deq.re1(p1)(z_star, x) - - return z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, deq.stats.nfe - current_nfe) -end - -function (deq::SkipDeepEquilibriumNetwork{M,Nothing})(x::AbstractArray{T}) where {M,T} - z = deq.re1(deq.p)(zero(x), x)::typeof(x) - - current_nfe = deq.stats.nfe - - z_star = solve_steady_state_problem(deq.re1, deq.p, x, z, deq.sensealg, deq.solver; dudt=nothing, - update_nfe=() -> (deq.stats.nfe += 1), deq.kwargs...) - - jac_loss = (deq.jacobian_regularization ? compute_deq_jacobian_loss(deq.re1, deq.p, z_star, x) : T(0)) ::T - - residual = Zygote.@ignore z_star .- deq.re1(deq.p)(z_star, x) - - return z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, deq.stats.nfe - current_nfe) -end diff --git a/src/layers/smdeq.jl b/src/layers/smdeq.jl deleted file mode 100644 index 7aa47a77..00000000 --- a/src/layers/smdeq.jl +++ /dev/null @@ -1,235 +0,0 @@ -""" - MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix, shortcut_layers::Tuple, - solver; post_fuse_layers::Union{Tuple,Nothing}=nothing, p=nothing, - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), kwargs...) - MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix, solver; - post_fuse_layers::Union{Tuple,Nothing}=nothing, p=nothing, - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), kwargs...) - -Multiscale Deep Equilibrium Network as proposed in [baimultiscale2020](@cite) - -## Arguments - -* `main_layers`: Tuple of Explicit Neural Networks. The first network needs to take 2 inputs, the other ones only take 1 input -* `mapping_layers`: Matrix of Explicit Neural Networks. The ``(i, j)^{th}`` network takes the output of ``i^{th}`` `main_layer` - and passes it to the ``j^{th}`` `main_layer` -* `shortcut_layers`: Shortcuts for the network (If not given, then we create SkipDEQV2) -* `solver`: Solver for the optimization problem (See: [`ContinuousDEQSolver`](@ref) & [`DiscreteDEQSolver`](@ref)) -* `post_fuse_layers`: Tuple of Explicit Neural Networks. Applied after the `mapping_layers` (Default: `nothing`) -* `p`: Optional parameters for the `model` -* `sensealg`: See [`SteadyStateAdjoint`](@ref) -* `kwargs`: Additional Parameters that are directly passed to `solve` - -## Example - -```julia -# MSkipDEQ -model = MultiScaleSkipDeepEquilibriumNetwork( - ( - Parallel(+, Dense(4, 4, tanh_fast), Dense(4, 4, tanh_fast)), - Dense(3, 3, tanh_fast), Dense(2, 2, tanh_fast), - Dense(1, 1, tanh_fast) - ), - [ - NoOpLayer() Dense(4, 3, tanh_fast) Dense(4, 2, tanh_fast) Dense(4, 1, tanh_fast); - Dense(3, 4, tanh_fast) NoOpLayer() Dense(3, 2, tanh_fast) Dense(3, 1, tanh_fast); - Dense(2, 4, tanh_fast) Dense(2, 3, tanh_fast) NoOpLayer() Dense(2, 1, tanh_fast); - Dense(1, 4, tanh_fast) Dense(1, 3, tanh_fast) Dense(1, 2, tanh_fast) NoOpLayer() - ], - ( - Dense(4, 4, tanh_fast), - Dense(4, 3, tanh_fast), - Dense(4, 2, tanh_fast), - Dense(4, 1, tanh_fast) - ), - ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0), -) - -model(rand(Float32, 4, 1)) - - -# MSkipDEQV2 -model = MultiScaleSkipDeepEquilibriumNetwork( - ( - Parallel(+, Dense(4, 4, tanh_fast), Dense(4, 4, tanh_fast)), - Dense(3, 3, tanh_fast), Dense(2, 2, tanh_fast), - Dense(1, 1, tanh_fast) - ), - [ - NoOpLayer() Dense(4, 3, tanh_fast) Dense(4, 2, tanh_fast) Dense(4, 1, tanh_fast); - Dense(3, 4, tanh_fast) NoOpLayer() Dense(3, 2, tanh_fast) Dense(3, 1, tanh_fast); - Dense(2, 4, tanh_fast) Dense(2, 3, tanh_fast) NoOpLayer() Dense(2, 1, tanh_fast); - Dense(1, 4, tanh_fast) Dense(1, 3, tanh_fast) Dense(1, 2, tanh_fast) NoOpLayer() - ], - ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0), -) - -model(rand(Float32, 4, 1)) -``` - -See also: [`DeepEquilibriumNetwork`](@ref), [`SkipDeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwork`](@ref) -""" -struct MultiScaleSkipDeepEquilibriumNetwork{M3<:Union{Nothing,Parallel},N,M1<:Parallel,M2<:Union{Chain,FChain},RE1, - RE2,RE3,P,A,K,S} <: AbstractDeepEquilibriumNetwork - main_layers::M1 - mapping_layers::M2 - shortcut_layers::M3 - main_layers_re::RE1 - mapping_layers_re::RE2 - shortcut_layers_re::RE3 - p::P - ordered_split_idxs::NTuple{N,Int} - solver::A - kwargs::K - sensealg::S - stats::DEQTrainingStats - - function MultiScaleSkipDeepEquilibriumNetwork(main_layers::Parallel, - mapping_layers::Union{Chain,FChain}, - shortcut_layers::Union{Nothing,Parallel}, re1, re2, re3, p, - ordered_split_idxs, solver::A, kwargs::K, sensealg::S, - stats) where {A,K,S} - @assert length(mapping_layers) == 2 - @assert mapping_layers[1] isa MultiParallelNet - - p_main_layers, re_main_layers = destructure_parameters(main_layers) - p_mapping_layers, re_mapping_layers = destructure_parameters(mapping_layers) - p_shortcut_layers, re_shortcut_layers = shortcut_layers === nothing ? ([], nothing) : - destructure_parameters(shortcut_layers) - - ordered_split_idxs = tuple(cumsum([0, length(p_main_layers), length(p_mapping_layers), - length(p_shortcut_layers)])...) - - p = p === nothing ? vcat(p_main_layers, p_mapping_layers, p_shortcut_layers) : convert(typeof(p_main_layers), p) - - return new{typeof(shortcut_layers),length(ordered_split_idxs), - typeof.((main_layers, mapping_layers, re_main_layers, re_mapping_layers, re_shortcut_layers, p))..., - A,K,S}(main_layers, mapping_layers, shortcut_layers, re_main_layers, - re_mapping_layers, re_shortcut_layers, p, ordered_split_idxs, solver, kwargs, sensealg, stats) - end -end - -Flux.@functor MultiScaleSkipDeepEquilibriumNetwork - -function MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix, shortcut_layers::Tuple, - solver; post_fuse_layers::Union{Tuple,Nothing}=nothing, p=nothing, - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), kwargs...) - mapping_layers = if post_fuse_layers === nothing - @assert size(mapping_layers, 1) == size(mapping_layers, 2) == length(main_layers) == length(shortcut_layers) - Chain(MultiParallelNet(Parallel.(+, map(x -> tuple(x...), eachcol(mapping_layers)))...), NoOpLayer()) - else - @assert size(mapping_layers, 1) == - size(mapping_layers, 2) == - length(main_layers) == - length(post_fuse_layers) == - length(shortcut_layers) - Chain(MultiParallelNet(Parallel.(+, map(x -> tuple(x...), eachcol(mapping_layers)))...), - Parallel(flatten_merge, post_fuse_layers...)) - end - - main_layers = Parallel(flatten_merge, main_layers...) - shortcut_layers = Parallel(flatten_merge, shortcut_layers...) - - return MultiScaleSkipDeepEquilibriumNetwork(main_layers, mapping_layers, shortcut_layers, - nothing, nothing, nothing, p, nothing, solver, kwargs, sensealg, - DEQTrainingStats(0)) -end - -function MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix, solver; - post_fuse_layers::Union{Tuple,Nothing}=nothing, p=nothing, - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), kwargs...) - mapping_layers = if post_fuse_layers === nothing - @assert size(mapping_layers, 1) == size(mapping_layers, 2) == length(main_layers) - Chain(MultiParallelNet(Parallel.(+, map(x -> tuple(x...), eachcol(mapping_layers)))...), NoOpLayer()) - else - @assert size(mapping_layers, 1) == size(mapping_layers, 2) == length(main_layers) == length(post_fuse_layers) - Chain(MultiParallelNet(Parallel.(+, map(x -> tuple(x...), eachcol(mapping_layers)))...), - Parallel(flatten_merge, post_fuse_layers...)) - end - - main_layers = Parallel(flatten_merge, main_layers...) - - return MultiScaleSkipDeepEquilibriumNetwork(main_layers, mapping_layers, nothing, nothing, - nothing, nothing, p, nothing, solver, kwargs, sensealg, - DEQTrainingStats(0)) -end - -function (mdeq::MultiScaleSkipDeepEquilibriumNetwork)(x::AbstractArray{T}) where {T} - current_nfe = mdeq.stats.nfe - - p1, p2, p3 = split_array_by_indices(mdeq.p, mdeq.ordered_split_idxs) - initial_conditions = mdeq.shortcut_layers_re(p3)(x) - u_sizes = size.(initial_conditions) - u_split_idxs = vcat(0, cumsum(length.(initial_conditions) .÷ size(x, ndims(x)))...) - u0 = Zygote.@ignore vcat(Flux.flatten.(initial_conditions)...) - - N = length(u_sizes) - update_is_variational_hidden_dropout_mask_reset_allowed(false) - - function dudt_(u, _p) - mdeq.stats.nfe += 1 - - uₛ = split_array_by_indices(u, u_split_idxs) - p1, p2, _ = split_array_by_indices(_p, mdeq.ordered_split_idxs) - - u_reshaped = ntuple(i -> reshape(uₛ[i], u_sizes[i]), N) - - main_layers_output = mdeq.main_layers_re(p1)((u_reshaped[1], x), u_reshaped[2:end]...) - - return mdeq.mapping_layers_re(p2)(main_layers_output) - end - - dudt(u, _p, t) = vcat(Flux.flatten.(dudt_(u, _p))...) .- u - - ssprob = SteadyStateProblem(dudt, u0, mdeq.p) - res = solve(ssprob, mdeq.solver; u0=u0, sensealg=mdeq.sensealg, mdeq.kwargs...).u - x_ = dudt_(res, mdeq.p) - residual = Zygote.@ignore Tuple(map((iu) -> reshape(iu[2], u_sizes[iu[1]]), - enumerate(split_array_by_indices(dudt(res, mdeq.p, nothing), u_split_idxs)))) - - update_is_variational_hidden_dropout_mask_reset_allowed(true) - - return x_, DeepEquilibriumSolution(x_, initial_conditions, residual, T(0), mdeq.stats.nfe - current_nfe) -end - -function (mdeq::MultiScaleSkipDeepEquilibriumNetwork{Nothing})(x::AbstractArray{T}) where {T} - current_nfe = mdeq.stats.nfe - - p1, p2 = split_array_by_indices(mdeq.p, mdeq.ordered_split_idxs) - - _initial_conditions = Zygote.@ignore [l(x) for l in map(l -> l.layers[1], mdeq.mapping_layers[1].layers)] - _initial_conditions = mdeq.mapping_layers_re(p2)((x, zero.(_initial_conditions[2:end])...)) - initial_conditions = mdeq.main_layers_re(p1)((zero(_initial_conditions[1]), _initial_conditions[1]), - _initial_conditions[2:end]...) - u_sizes = size.(initial_conditions) - u_split_idxs = vcat(0, cumsum(length.(initial_conditions) .÷ size(x, ndims(x)))...) - u0 = vcat(Flux.flatten.(initial_conditions)...) - - N = length(u_sizes) - update_is_variational_hidden_dropout_mask_reset_allowed(false) - - function dudt_(u, _p) - mdeq.stats.nfe += 1 - - uₛ = split_array_by_indices(u, u_split_idxs) - p1, p2, _ = split_array_by_indices(_p, mdeq.ordered_split_idxs) - - u_reshaped = ntuple(i -> reshape(uₛ[i], u_sizes[i]), N) - - main_layers_output = mdeq.main_layers_re(p1)((u_reshaped[1], x), u_reshaped[2:end]...) - - return mdeq.mapping_layers_re(p2)(main_layers_output) - end - - dudt(u, _p, t) = vcat(Flux.flatten.(dudt_(u, _p))...) .- u - - ssprob = SteadyStateProblem(dudt, u0, mdeq.p) - res = solve(ssprob, mdeq.solver; u0=u0, sensealg=mdeq.sensealg, mdeq.kwargs...).u - x_ = dudt_(res, mdeq.p) - residual = Zygote.@ignore Tuple(map((iu) -> reshape(iu[2], u_sizes[iu[1]]), - enumerate(split_array_by_indices(dudt(res, mdeq.p, nothing), u_split_idxs)))) - - update_is_variational_hidden_dropout_mask_reset_allowed(true) - - return x_, DeepEquilibriumSolution(x_, initial_conditions, residual, T(0), mdeq.stats.nfe - current_nfe) -end diff --git a/src/layers/utils.jl b/src/layers/utils.jl deleted file mode 100644 index 3be36acf..00000000 --- a/src/layers/utils.jl +++ /dev/null @@ -1,76 +0,0 @@ -""" - DeepEquilibriumSolution(z_star, u₀, residual, jacobian_loss, nfe) - -Stores the solution of a DeepEquilibriumNetwork and its variants. - -## Fields - -* `z_star`: Steady-State or the value reached due to maxiters -* `u₀`: Initial Condition -* `residual`: Difference of the ``z^*`` and ``f(z^*, x)`` -* `jacobian_loss`: Jacobian Stabilization Loss (see individual networks to see how it can be computed) -* `nfe`: Number of Function Evaluations -""" -struct DeepEquilibriumSolution{T,R<:AbstractFloat} - z_star::T - u₀::T - residual::T - jacobian_loss::R - nfe::Int -end - -function Base.show(io::IO, l::DeepEquilibriumSolution) - println(io, "DeepEquilibriumSolution(") - println(io, "\tz_star: ", l.z_star) - println(io, "\tinitial_condition: ", l.u₀) - println(io, "\tresidual: ", l.residual) - println(io, "\tjacobian_loss: ", l.jacobian_loss) - println(io, "\tNFE: ", l.nfe) - print(io, ")") - return nothing -end - - -function solve_steady_state_problem(re, p, x, u0, sensealg, args...; dudt=nothing, update_nfe=() -> (), kwargs...) - # Solving the equation f(u) - u = du = 0 - update_is_variational_hidden_dropout_mask_reset_allowed(false) - - dudt_ = if dudt === nothing - function (u, _p, t) - update_nfe() - return re(_p)(u, x) .- u - end - else - dudt - end - - ssprob = SteadyStateProblem(dudt_, u0, p) - sol = solve(ssprob, args...; u0=u0, sensealg=sensealg, kwargs...) - - z = re(p)(sol.u, x)::typeof(x) - update_nfe() - - update_is_variational_hidden_dropout_mask_reset_allowed(true) - - return z -end - -function solve_depth_k_neural_network(re, p, x, u0, depth) - update_is_variational_hidden_dropout_mask_reset_allowed(false) - model = re(p) - for _ in 1:depth - u0 = model(u0, x) - end - update_is_variational_hidden_dropout_mask_reset_allowed(true) - return u0 -end - - -flatten(x::AbstractArray{T,N}) where {T,N} = reshape(x, :, size(x, N)) - -Zygote.@adjoint function flatten(x::AbstractArray{T,N}) where {T,N} - s = size(x) - res = reshape(x, :, s[N]) - flatten_sensitivity(Δ) = (reshape(Δ, s),) - return res, flatten_sensitivity -end diff --git a/src/losses.jl b/src/losses.jl deleted file mode 100644 index a7b81277..00000000 --- a/src/losses.jl +++ /dev/null @@ -1,43 +0,0 @@ -""" - SupervisedLossContainer(loss_function) - SupervisedLossContainer(loss_function, λ, λⱼ) - -A container class for supervised loss functions. -""" -Base.@kwdef struct SupervisedLossContainer{L,T} - loss_function::L - λ::T = 0.0f0 - λⱼ::T = 0.0f0 -end - -function (lc::SupervisedLossContainer)(soln::DeepEquilibriumSolution) - return lc.λ * mean(abs, soln.u₀ .- soln.z_star) + lc.λⱼ * soln.jacobian_loss -end - -function (lc::SupervisedLossContainer)(soln::DeepEquilibriumSolution{T}) where {T<:Tuple} - return lc.λ * mapreduce((x, y) -> mean(abs, x .- y), +, soln.u₀, soln.z_star) + - lc.λⱼ * soln.jacobian_loss -end - -function (lc::SupervisedLossContainer)(model::Union{DeepEquilibriumNetwork,SkipDeepEquilibriumNetwork,DEQChain}, x, y; - kwargs...) - ŷ, soln = model(x; kwargs...) - return lc.loss_function(ŷ, y) + lc(soln) -end - -function (lc::SupervisedLossContainer)(model::Union{MultiScaleDeepEquilibriumNetwork, - MultiScaleSkipDeepEquilibriumNetwork}, x, ys::Tuple; kwargs...) - yŝ, soln = model(x; kwargs...) - return mapreduce(lc.loss_function, +, ys, yŝ) + lc(soln) -end - -function (lc::SupervisedLossContainer)(model::Union{MultiScaleDeepEquilibriumNetwork, - MultiScaleSkipDeepEquilibriumNetwork}, x, y; kwargs...) - yŝ, soln = model(x; kwargs...) - return sum(Base.Fix2(lc.loss_function, y), yŝ) + lc(soln) -end - -# Default fallback -function (lc::SupervisedLossContainer)(model, x, y; kwargs...) - return lc.loss_function(model(x; kwargs...), y) -end diff --git a/src/models/basics.jl b/src/models/basics.jl deleted file mode 100644 index 7701f7c8..00000000 --- a/src/models/basics.jl +++ /dev/null @@ -1,46 +0,0 @@ -""" - MultiParallelNet(layers...) - MultiParallelNet(layers::Tuple) - MultiParallelNet(layers::Vector) - -Creates a MultiParallelNet mostly used for MultiScale Models. It takes a list of inputs -and passes all of them through each `layer` and returns a tuple of outputs. - -## Example - -``` -Model := MultiParallelNet(L1, L2, L3) - -Model(X1, X2) := (Model.L1(X1, X2), Model.L2(X1, X2), Model.L3(X1, X2)) -``` -""" -struct MultiParallelNet{L} - layers::L - - function MultiParallelNet(args...) - layers = tuple(args...) - return new{typeof(layers)}(layers) - end - - MultiParallelNet(layers::Tuple) = new{typeof(layers)}(layers) - - MultiParallelNet(layers::Vector) = MultiParallelNet(layers...) -end - -Flux.@functor MultiParallelNet - -function (mpn::MultiParallelNet)(x::Union{Tuple,Vector}) - buf = Zygote.Buffer(Vector{Any}(undef, length(mpn.layers))) - for (i, l) in enumerate(mpn.layers) - buf[i] = l(x...) - end - return Tuple(copy(buf)) -end - -function (mpn::MultiParallelNet)(args...) - buf = Zygote.Buffer(Vector{Any}(undef, length(mpn.layers))) - for (i, l) in enumerate(mpn.layers) - buf[i] = l(args...) - end - return Tuple(copy(buf)) -end diff --git a/src/models/chain.jl b/src/models/chain.jl deleted file mode 100644 index 1fe93c92..00000000 --- a/src/models/chain.jl +++ /dev/null @@ -1,61 +0,0 @@ -# Default to nothing happening -reset_mask!(x) = nothing - -""" - DEQChain(pre_deq, deq, post_deq) - DEQChain(layers...) - -A Sequential Model containing a DEQ. - -!!! note - The Model should contain exactly 1 `AbstractDEQ` Layer -""" -struct DEQChain{P1,D<:AbstractDeepEquilibriumNetwork,P2} - pre_deq::P1 - deq::D - post_deq::P2 -end - -function DEQChain(layers...) - pre_deq, post_deq, deq, encounter_deq = [], [], nothing, false - for l in layers - if typeof(l) <: AbstractDeepEquilibriumNetwork - @assert !encounter_deq "Can have only 1 DEQ Layer in the Chain!!!" - deq = l - encounter_deq = true - continue - end - push!(encounter_deq ? post_deq : pre_deq, l) - end - @assert encounter_deq "No DEQ Layer in the Chain!!! Maybe you wanted to use Chain" - pre_deq = length(pre_deq) == 0 ? NoOpLayer() : (length(pre_deq) == 1 ? pre_deq[1] : FChain(pre_deq...)) - post_deq = length(post_deq) == 0 ? NoOpLayer() : (length(post_deq) == 1 ? post_deq[1] : FChain(post_deq...)) - return DEQChain(pre_deq, deq, post_deq) -end - -Flux.@functor DEQChain - -function (deq::DEQChain)(x; kwargs...) - x1 = deq.pre_deq(x) - x2, deq_soln = deq.deq(x1; kwargs...) - x3 = deq.post_deq(x2) - return x3, deq_soln -end - -function get_and_clear_nfe!(model::DEQChain) - nfe = model.deq.stats.nfe - model.deq.stats.nfe = 0 - return nfe -end - -function Base.show(io::IO, model::DEQChain) - l1 = length(destructure_parameters(model)[1]) - println(io, "DEQChain(") - print(io, "\t") - show(io, model.pre_deq) - print(io, "\n\t") - show(io, model.deq) - print(io, "\n\t") - show(io, model.post_deq) - return print(io, "\n) $l1 Trainable Parameters") -end diff --git a/src/operator.jl b/src/operator.jl new file mode 100644 index 00000000..3e98099c --- /dev/null +++ b/src/operator.jl @@ -0,0 +1,25 @@ +struct ZygotePullbackMultiplyOperator{T,F,S} + f::F + s::S +end + +Base.deepcopy(op::ZygotePullbackMultiplyOperator) = op + +Base.size(z::ZygotePullbackMultiplyOperator) = (prod(z.s), prod(z.s)) +Base.size(z::ZygotePullbackMultiplyOperator, ::Int64) = prod(z.s) + +Base.eltype(::ZygotePullbackMultiplyOperator{T}) where {T} = T + +function LinearAlgebra.mul!( + du::AbstractVector, + L::ZygotePullbackMultiplyOperator, + x::AbstractVector, +) + du .= vec(L * x) +end + +function Base.:*(L::ZygotePullbackMultiplyOperator, x::AbstractVector) + return L.f(reshape(x, L.s))[1] +end + +SciMLBase.isinplace(z::ZygotePullbackMultiplyOperator, ::Int64) = false diff --git a/src/solve.jl b/src/solve.jl index fafc26ca..056ac41b 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1,245 +1,10 @@ -""" - ContinuousDEQSolver(alg=VCABM4(); mode::Symbol=:rel_deq_default, abstol=1e-8, reltol=1e-8, tspan=Inf) - -Solver for Continuous DEQ Problem ([pal2022mixing](@cite)). Similar to `DynamicSS` but provides more flexibility needed -for solving DEQ problems. - -## Arguments - -* `alg`: Algorithm to solve the ODEProblem. (Default: `VCABM4()`) -* `mode`: Termination Mode of the solver. See below for a description of the various termination conditions (Default: `:rel_deq_default`) -* `abstol`: Absolute tolerance for termination. (Default: `1e-8`) -* `reltol`: Relative tolerance for termination. (Default: `1e-8`) -* `tspan`: Time span. Users should not change this value, instead control termination through `maxiters` in `solve` (Default: `Inf`) - -## Termination Modes - -#### Termination on Absolute Tolerance - -* `:abs`: Terminates if ``all \\left( | \\frac{\\partial u}{\\partial t} | \\leq abstol \\right)`` -* `:abs_norm`: Terminates if ``\\| \\frac{\\partial u}{\\partial t} \\| \\leq abstol`` -* `:abs_deq_default`: Essentially `abs_norm` + terminate if there has been no improvement for the last 30 steps + terminate if the solution blows up (diverges) -* `:abs_deq_best`: Same as `:abs_deq_default` but uses the best solution found so far, i.e. deviates only if the solution has not converged - -#### Termination on Relative Tolerance - -* `:rel`: Terminates if ``all \\left(| \\frac{\\partial u}{\\partial t} | \\leq reltol \\times | u | \\right)`` -* `:rel_norm`: Terminates if ``\\| \\frac{\\partial u}{\\partial t} \\| \\leq reltol \\times \\| \\frac{\\partial u}{\\partial t} + u \\|`` -* `:rel_deq_default`: Essentially `rel_norm` + terminate if there has been no improvement for the last 30 steps + terminate if the solution blows up (diverges) -* `:rel_deq_best`: Same as `:rel_deq_default` but uses the best solution found so far, i.e. deviates only if the solution has not converged - -#### Termination using both Absolute and Relative Tolerances - -* `:norm`: Terminates if ``\\| \\frac{\\partial u}{\\partial t} \\| \\leq reltol \\times \\| \\frac{\\partial u}{\\partial t} + u \\|`` & - ``\\| \\frac{\\partial u}{\\partial t} \\| \\leq abstol`` -* `fallback`: Check if all values of the derivative is close to zero wrt both relative and absolute tolerance. This is usable for small problems - but doesn't scale well for neural networks, and should be avoided unless absolutely necessary - -See also: [`DiscreteDEQSolver`](@ref) - -!!! note - This will be upstreamed to DiffEqSensitivity in the later releases of the package -""" -struct ContinuousDEQSolver{M,A,AT,RT,TS} <: SteadyStateDiffEq.SteadyStateDiffEqAlgorithm - alg::A - abstol::AT - reltol::RT - tspan::TS -end - -function ContinuousDEQSolver(alg=VCABM4(); mode::Symbol=:rel_deq_default, abstol=1e-8, reltol=1e-8, tspan=Inf) - return ContinuousDEQSolver{Val(mode),typeof(alg),typeof(abstol),typeof(reltol),typeof(tspan)}(alg, abstol, reltol, tspan) -end - -function terminate_condition_reltol(integrator, abstol, reltol, min_t) - return all(abs.(DiffEqBase.get_du(integrator)) .<= reltol .* abs.(integrator.u)) -end - -function terminate_condition_reltol_norm(integrator, abstol, reltol, min_t) - du = DiffEqBase.get_du(integrator) - return norm(du) <= reltol * norm(du .+ integrator.u) -end - -function terminate_condition_abstol(integrator, abstol, reltol, min_t) - return all(abs.(DiffEqBase.get_du(integrator)) .<= abstol) -end - -function terminate_condition_abstol_norm(integrator, abstol, reltol, min_t) - return norm(DiffEqBase.get_du(integrator)) <= abstol -end - -function terminate_condition(integrator, abstol, reltol, min_t) - return all((abs.(DiffEqBase.get_du(integrator)) .<= reltol .* abs.(integrator.u)) .& - (abs.(DiffEqBase.get_du(integrator)) .<= abstol)) -end - -function terminate_condition_norm(integrator, abstol, reltol, min_t) - du = DiffEqBase.get_du(integrator) - du_norm = norm(du) - return (du_norm <= reltol * norm(du .+ integrator.u)) && (du_norm <= abstol) -end - -get_terminate_condition(::ContinuousDEQSolver{Val(:abs)}, args...; kwargs...) = terminate_condition_abstol -get_terminate_condition(::ContinuousDEQSolver{Val(:abs_norm)}, args...; kwargs...) = terminate_condition_abstol_norm -get_terminate_condition(::ContinuousDEQSolver{Val(:rel)}, args...; kwargs...) = terminate_condition_reltol -get_terminate_condition(::ContinuousDEQSolver{Val(:rel_norm)}, args...; kwargs...) = terminate_condition_reltol_norm -get_terminate_condition(::ContinuousDEQSolver{Val(:norm)}, args...; kwargs...) = terminate_condition_norm -get_terminate_condition(::ContinuousDEQSolver, args...; kwargs...) = terminate_condition - -# Termination conditions used in the original DEQ Paper -function get_terminate_condition(::ContinuousDEQSolver{Val(:abs_deq_default),A,T}, args...; kwargs...) where {A,T} - nstep = 0 - protective_threshold = T(1e6) - objective_values = T[] - function terminate_condition_closure(integrator, abstol, reltol, min_t) - du = DiffEqBase.get_du(integrator) - objective = norm(du) - # Main termination condition - objective <= abstol && return true - - # Terminate if there has been no improvement for the last 30 steps - nstep += 1 - push!(objective_values, objective) - - objective <= 3 * abstol && - nstep >= 30 && - maximum(objective_values[(end - nstep):end]) < 1.3 * minimum(objective_values[(end - nstep):end]) && - return true - - # Protective break - objective >= objective_values[1] * protective_threshold * length(du) && return true - - return false - end - return terminate_condition_closure -end - -function get_terminate_condition(::ContinuousDEQSolver{Val(:rel_deq_default),A,T}, args...; kwargs...) where {A,T} - nstep = 0 - protective_threshold = T(1e3) - objective_values = T[] - function terminate_condition_closure(integrator, abstol, reltol, min_t) - du = DiffEqBase.get_du(integrator) - u = integrator.u - objective = norm(du) / (norm(du .+ u) + eps(T)) - # Main termination condition - objective <= reltol && return true - - # Terminate if there has been no improvement for the last 30 steps - nstep += 1 - push!(objective_values, objective) - - objective <= 3 * reltol && - nstep >= 30 && - maximum(objective_values[(end - nstep + 1):end]) < 1.3 * minimum(objective_values[(end - nstep + 1):end]) && - return true - - # Protective break - objective >= objective_values[1] * protective_threshold * length(du) && return true - - return false - end - return terminate_condition_closure -end - -function get_terminate_condition(::ContinuousDEQSolver{Val(:rel_deq_best),A,T}, terminate_stats::Dict, args...; - kwargs...) where {A,T} - nstep = 0 - protective_threshold = T(1e3) - objective_values = T[] - - terminate_stats[:best_objective_value] = T(Inf) - terminate_stats[:best_objective_value_iteration] = 0 - - function terminate_condition_closure(integrator, abstol, reltol, min_t) - du = DiffEqBase.get_du(integrator) - u = integrator.u - objective = norm(du) / (norm(du .+ u) + eps(T)) - - if objective < terminate_stats[:best_objective_value] - terminate_stats[:best_objective_value] = objective - terminate_stats[:best_objective_value_iteration] = nstep + 1 - end - - # Main termination condition - objective <= reltol && return true - - # Terminate if there has been no improvement for the last 30 steps - nstep += 1 - push!(objective_values, objective) - - objective <= 3 * reltol && - nstep >= 30 && - maximum(objective_values[(end - nstep + 1):end]) < 1.3 * minimum(objective_values[(end - nstep + 1):end]) && - return true - - # Protective break - objective >= objective_values[1] * protective_threshold * length(du) && return true - - return false - end - - return terminate_condition_closure -end - -function get_terminate_condition(::ContinuousDEQSolver{Val(:abs_deq_best),A,T}, terminate_stats::Dict, args...; - kwargs...) where {A,T} - nstep = 0 - protective_threshold = T(1e3) - objective_values = T[] - - terminate_stats[:best_objective_value] = T(Inf) - terminate_stats[:best_objective_value_iteration] = 0 - - function terminate_condition_closure(integrator, abstol, reltol, min_t) - du = DiffEqBase.get_du(integrator) - objective = norm(du) - - if objective < terminate_stats[:best_objective_value] - terminate_stats[:best_objective_value] = objective - terminate_stats[:best_objective_value_iteration] = nstep + 1 - end - - # Main termination condition - objective <= reltol && return true - - # Terminate if there has been no improvement for the last 30 steps - nstep += 1 - push!(objective_values, objective) - - objective <= 3 * reltol && - nstep >= 30 && - maximum(objective_values[(end - nstep + 1):end]) < 1.3 * minimum(objective_values[(end - nstep + 1):end]) && - return true - - # Protective break - objective >= objective_values[1] * protective_threshold * length(du) && return true - - return false - end - - return terminate_condition_closure -end - -has_converged(du, u, alg::ContinuousDEQSolver) = all(abs.(du) .<= alg.abstol .& abs.(du) .<= alg.reltol .* abs.(u)) -has_converged(du, u, alg::ContinuousDEQSolver{Val(:norm)}) = norm(du) <= alg.abstol && norm(du) <= alg.reltol * norm(du .+ u) -has_converged(du, u, alg::ContinuousDEQSolver{Val(:rel)}) = all(abs.(du) .<= alg.reltol .* abs.(u)) -has_converged(du, u, alg::ContinuousDEQSolver{Val(:rel_norm)}) = norm(du) <= alg.reltol * norm(du .+ u) -has_converged(du, u, alg::ContinuousDEQSolver{Val(:rel_deq_default)}) = norm(du) <= alg.reltol * norm(du .+ u) -has_converged(du, u, alg::ContinuousDEQSolver{Val(:rel_deq_best)}) = norm(du) <= alg.reltol * norm(du .+ u) -has_converged(du, u, alg::ContinuousDEQSolver{Val(:abs)}) = all(abs.(du) .<= alg.abstol) -has_converged(du, u, alg::ContinuousDEQSolver{Val(:abs_norm)}) = norm(du) <= alg.abstol -has_converged(du, u, alg::ContinuousDEQSolver{Val(:abs_deq_default)}) = norm(du) <= alg.abstol -has_converged(du, u, alg::ContinuousDEQSolver{Val(:abs_deq_best)}) = norm(du) <= alg.abstol - -struct EquilibriumSolution{T,N,uType,R,P,A,TEnd} <: SciMLBase.AbstractNonlinearSolution{T,N} +struct EquilibriumSolution{T,N,uType,P,A,D} <: SciMLBase.AbstractNonlinearSolution{T,N} u::uType - resid::R + resid::uType prob::P alg::A retcode::Symbol - t::TEnd - λₜ::T + destats::D end function transform_solution(soln::EquilibriumSolution) @@ -247,101 +12,74 @@ function transform_solution(soln::EquilibriumSolution) return DiffEqBase.build_solution(soln.prob, soln.alg, soln.u, soln.resid; retcode=soln.retcode) end -function DiffEqBase.__solve(prob::DiffEqBase.AbstractSteadyStateProblem, alg::ContinuousDEQSolver, args...; - regularize_endpoint=false, kwargs...) +function DiffEqBase.__solve( + prob::DiffEqBase.AbstractSteadyStateProblem{uType}, alg::ContinuousDEQSolver, args...; kwargs... +) where {uType} tspan = alg.tspan isa Tuple ? alg.tspan : convert.(real(eltype(prob.u0)), (zero(alg.tspan), alg.tspan)) _prob = ODEProblem(prob.f, prob.u0, tspan, prob.p) - terminate_stats = Dict{Symbol,Any}(:best_objective_value => real(eltype(prob.u0))(Inf), - :best_objective_value_iteration => nothing) - - sol = solve(_prob, alg.alg, args...; kwargs..., - callback=TerminateSteadyState(alg.abstol, alg.reltol, get_terminate_condition(alg, terminate_stats))) - - u, t = terminate_stats[:best_objective_value_iteration] === nothing ? (sol.u[end], sol.t[end]) : - (sol.u[terminate_stats[:best_objective_value_iteration] + 1], - sol.t[terminate_stats[:best_objective_value_iteration] + 1]) + terminate_stats = Dict{Symbol,Any}( + :best_objective_value => real(eltype(prob.u0))(Inf), :best_objective_value_iteration => nothing + ) + + sol = solve( + _prob, + alg.alg, + args...; + kwargs..., + callback=TerminateSteadyState( + alg.abstol_termination, alg.reltol_termination, get_terminate_condition(alg, terminate_stats) + ), + ) + + u, t = if terminate_stats[:best_objective_value_iteration] === nothing + (sol.u[end], sol.t[end]) + else + ( + sol.u[terminate_stats[:best_objective_value_iteration] + 1], + sol.t[terminate_stats[:best_objective_value_iteration] + 1], + ) + end + # Dont count towards NFE since this is mostly a check for convergence du = prob.f(u, prob.p, t) - T = eltype(eltype(u)) - N = ndims(u) retcode = (sol.retcode == :Terminated && has_converged(du, u, alg) ? :Success : :Failure) - _t = regularize_endpoint isa Bool ? (regularize_endpoint ? t : nothing) : t - regularize_endpoint = regularize_endpoint isa Bool ? (regularize_endpoint ? T(1e-5) : T(0)) : T(regularize_endpoint) - return EquilibriumSolution{T,N,typeof(u),typeof(du),typeof(prob),typeof(alg),typeof(_t)}(u, du, prob, alg, retcode, - _t, regularize_endpoint) + return EquilibriumSolution{eltype(uType),ndims(uType),uType,typeof(prob),typeof(alg),typeof(sol.destats)}( + u, du, prob, alg, retcode, sol.destats + ) end -function clear_zero(x::T) where T - ϵ = eps(T) - if -ϵ <= x < 0 - return -ϵ - elseif 0 <= x < ϵ - return ϵ - end - return x -end - -@noinline function DiffEqSensitivity.SteadyStateAdjointProblem(sol::EquilibriumSolution, - sensealg::DiffEqSensitivity.SteadyStateAdjoint, g, dg; - save_idxs=nothing) - @unpack f, p, u0 = sol.prob - - discrete = false +function DiffEqBase.__solve( + prob::DiffEqBase.AbstractSteadyStateProblem{uType}, alg::DiscreteDEQSolver, args...; maxiters=10, kwargs... +) where {uType} + terminate_stats = Dict{Symbol,Any}( + :best_objective_value => real(eltype(prob.u0))(Inf), :best_objective_value_iteration => nothing + ) - p === DiffEqBase.NullParameters() && - error("Your model does not have parameters, and thus it is impossible to calculate the derivative of the solution with respect to the parameters. Your model must have parameters to use parameter sensitivity calculations!") + us, stats = nlsolve( + alg.alg, + u -> prob.f(u, prob.p, nothing), + prob.u0; + maxiters=maxiters, + terminate_condition=get_terminate_condition(alg, terminate_stats) + ) - sense = DiffEqSensitivity.SteadyStateAdjointSensitivityFunction(g, sensealg, discrete, sol, dg, f.colorvec, false) - @unpack diffcache, y, sol, λ, vjp, linsolve = sense - - _save_idxs = save_idxs === nothing ? Colon() : save_idxs - if dg !== nothing - if g !== nothing - dg(vec(diffcache.dg_val), y, p, nothing, nothing) - else - if typeof(_save_idxs) <: Number - diffcache.dg_val[_save_idxs] = dg[_save_idxs] - elseif typeof(dg) <: Number - @. diffcache.dg_val[_save_idxs] = dg - else - @. diffcache.dg_val[_save_idxs] = dg[_save_idxs] - end - end + u = if terminate_stats[:best_objective_value_iteration] === nothing + us[end] else - if g !== nothing - DiffEqSensitivity.gradient!(vec(diffcache.dg_val), diffcache.g, y, sensealg, diffcache.g_grad_config) - end + us[terminate_stats[:best_objective_value_iteration] + 1] end - _val, back = Zygote.pullback(x -> f(x, p, nothing), y) - s_val = size(_val) - op = DiffEqSensitivity.ZygotePullbackMultiplyOperator{eltype(y),typeof(back),typeof(s_val)}(back, s_val) - - b = vec(diffcache.dg_val) - # println("Original Mean: $(mean(b)) & Residual Mean: $(mean(sol.resid)) Norm: $(norm(sol.resid))") - if sol.t !== nothing - @. b = (b + clamp(sol.λₜ ./ norm(sol.resid), -Inf, mean(b))) / 2 - end - # println("Updated mean: $(mean(b))") - linear_problem = LinearProblem(op, b) + # Dont count towards NFE since this is mostly a check for convergence + du = prob.f(u, prob.p, nothing) - copyto!(vec(λ), solve(linear_problem, linsolve).u) - _, back = Zygote.pullback(p -> vec(f(y, p, nothing)), p) - vjp .= -vec(back(λ)[1]) + retcode = has_converged(du, u, alg) ? :Success : :Failure - if g !== nothing - # compute del g/del p - dg_dp_val = zero(p) - dg_dp = DiffEqSensitivity.ParamGradientWrapper(g, nothing, y) - dg_dp_config = DiffEqSensitivity.build_grad_config(sensealg, dg_dp, p, p) - DiffEqSensitivity.gradient!(dg_dp_val, dg_dp, p, sensealg, dg_dp_config) + destats = (nf=stats.nf,) - @. dg_dp_val = dg_dp_val + vjp - return dg_dp_val - else - return vjp - end + return EquilibriumSolution{eltype(uType),ndims(uType),uType,typeof(prob),typeof(alg),typeof(destats)}( + u, du, prob, alg, retcode, destats + ) end diff --git a/src/solvers/broyden.jl b/src/solvers/broyden.jl deleted file mode 100644 index c845142c..00000000 --- a/src/solvers/broyden.jl +++ /dev/null @@ -1,156 +0,0 @@ -# Broyden -## NOTE: For the time being it is better to use `LimitedMemoryBroydenSolver` -struct BroydenCache{J,F,X} - Jinv::J - fx::F - Δfx::F - fx_old::F - x::X - Δx::X - x_old::X -end - -function BroydenCache(x) - fx, Δfx, fx_old = copy(x), copy(x), copy(x) - x, Δx, x_old = copy(x), copy(x), copy(x) - Jinv = _init_identity_matrix(x) - return BroydenCache(Jinv, fx, Δfx, fx_old, x, Δx, x_old) -end - -BroydenCache(vec_length::Int, device) = BroydenCache(device(zeros(vec_length))) - -""" - BroydenSolver(; T=Float32, device, original_dims::Tuple{Int,Int}, batch_size, maxiters::Int=50, ϵ::Real=1e-6, - abstol::Union{Real,Nothing}=nothing, reltol::Union{Real,Nothing}=nothing) - -Broyden Solver ([broyden1965class](@cite)) for solving Discrete DEQs. It is recommended to use [`LimitedMemoryBroydenSolver`](@ref) for better performance. - -## Arguments - -* `T`: The type of the elements of the vectors. (Default: `Float32`) -* `device`: The device to use. Pass `gpu` to use the GPU else pass `cpu`. -* `original_dims`: Dimensions to reshape the arrays into (excluding the batch dimension). -* `batch_size`: The batch size of the problem. Your inputs can have a different batch size, but having - them match allows us to efficiently cache internal statistics without reallocation. -* `maxiters`: Maximum number of iterations to run. -* `ϵ`: Tolerance for convergence. -* `abstol`: Absolute tolerance. -* `reltol`: Relative tolerance. (This value is ignored by `BroydenSolver` at the moment) - -See also: [`LimitedMemoryBroydenSolver`](@ref) -""" -struct BroydenSolver{C<:BroydenCache,T<:Real} - cache::C - maxiters::Int - batch_size::Int - ϵ::T -end - -function BroydenSolver(; T=Float32, device, original_dims::Tuple{Int,Int}, batch_size, maxiters::Int=50, ϵ::Real=1e-6, - abstol::Union{Real,Nothing}=nothing, reltol::Union{Real,Nothing}=nothing) - ϵ = abstol !== nothing ? abstol : ϵ - - if reltol !== nothing - @warn "reltol is set to $reltol, but `BroydenSolver` ignores this value" maxlog=1 - end - - x = device(zeros(T, prod(original_dims) * batch_size)) - cache = BroydenCache(x) - - return BroydenSolver(cache, maxiters, batch_size, T(ϵ)) -end - -function (broyden::BroydenSolver{C,T})(f!, x_::AbstractVector{T}) where {C,T} - @unpack Jinv, fx, Δfx, fx_old, x, Δx, x_old = broyden.cache - if size(x) != size(x_) - # This might happen when the last batch with insufficient batch_size - # is passed. - @unpack Jinv, fx, Δfx, fx_old, x, Δx, x_old = BroydenCache(x_) - end - x .= x_ - - f!(fx, x) - _init_identity_matrix!(Jinv) - - maybe_stuck = false - max_resets = 3 - resets = 0 - - for i in 1:(broyden.maxiters) - x_old .= x - fx_old .= fx - - p = -Jinv * fx_old - - ρ, σ₂ = T(0.9), T(0.001) - - x .= x_old .+ p - f!(fx, x) - - if norm(fx, 2) ≤ ρ * norm(fx_old, 2) - σ₂ * norm(p, 2)^2 - α = T(1) - else - α = _approximate_norm_descent(f!, fx, x, p) - x .= x_old .+ α * p - f!(fx, x) - end - - Δx .= x .- x_old - Δfx .= fx .- fx_old - - maybe_stuck = all(abs.(Δx) .<= eps(T)) || all(abs.(Δfx) .<= eps(T)) - if maybe_stuck - Jinv = _init_identity_matrix(x) - resets += 1 - maybe_stuck = (resets ≤ max_resets) && maybe_stuck - else - ΔxJinv = Δx' * Jinv - Jinv .+= ((Δx .- Jinv * Δfx) ./ (ΔxJinv * Δfx)) * ΔxJinv - end - - maybe_stuck = false - - # Convergence Check - norm(Δfx, 2) ≤ broyden.ϵ && return x - end - - return x -end - -# https://doi.org/10.1080/10556780008805782 -# FIXME: We are dropping some robustness tests for now. -function _approximate_norm_descent(f!, fx::AbstractArray{T,N}, x::AbstractArray{T,N}, p; λ₀=T(1), β=T(0.5), σ₁=T(0.001), - η=T(0.1), max_iter=50) where {T,N} - λ₂, λ₁ = λ₀, λ₀ - - f!(fx, x) - fx_norm = norm(fx, 2) - - # TODO: Test NaN/Finite - # f!(fx, x .- λ₂ .* p) - # fxλp_norm = norm(fx, 2) - # TODO: nan backtrack - - j = 0 - - f!(fx, x .+ λ₂ .* p) - converged = _test_approximate_norm_descent_convergence(f!, fx, x, fx_norm, p, σ₁, λ₂, η) - - while j < max_iter && !converged - j += 1 - λ₁, λ₂ = λ₂, β * λ₂ - converged = _test_approximate_norm_descent_convergence(f!, fx, x, fx_norm, p, σ₁, λ₂, η) - end - - return λ₂ -end - -function _test_approximate_norm_descent_convergence(f!, fx, x, fx_norm, p, σ₁, λ₂, η) - f!(fx, x .+ λ₂ .* p) - n1 = norm(fx, 2) - - f!(fx, x) - n2 = norm(fx, 2) - - return n1 ≤ fx_norm - σ₁ * norm(λ₂ .* p, 2) .^ 2 + η * n2 -end diff --git a/src/solvers/continuous.jl b/src/solvers/continuous.jl new file mode 100644 index 00000000..1803b8e7 --- /dev/null +++ b/src/solvers/continuous.jl @@ -0,0 +1,40 @@ +""" + ContinuousDEQSolver(alg=VCABM3(); mode::Symbol=:rel_deq_default, abstol=1f-8, reltol=1f-8, abstol_termination=1f-8, reltol_termination=1f-8, tspan=Inf32) + +Solver for Continuous DEQ Problem ([pal2022mixing](@cite)). Similar to `DynamicSS` but provides more flexibility needed +for solving DEQ problems. + +## Arguments + +* `alg`: Algorithm to solve the ODEProblem. (Default: `VCABM3()`) +* `mode`: Termination Mode of the solver. See below for a description of the various termination conditions (Default: `:rel_deq_default`) +* `abstol`: Absolute tolerance for time stepping. (Default: `1f-8`) +* `reltol`: Relative tolerance for time stepping. (Default: `1f-8`) +* `abstol_termination`: Absolute tolerance for termination. (Default: `1f-8`) +* `reltol_termination`: Relative tolerance for termination. (Default: `1f-8`) +* `tspan`: Time span. Users should not change this value, instead control termination through `maxiters` in `solve` (Default: `Inf32`) + +See also: [`DiscreteDEQSolver`](@ref) +""" +struct ContinuousDEQSolver{M,A,T,TS} <: SteadyStateDiffEq.SteadyStateDiffEqAlgorithm + alg::A + abstol::T + reltol::T + abstol_termination::T + reltol_termination::T + tspan::TS +end + +function ContinuousDEQSolver( + alg=VCABM3(); + mode::Symbol=:rel_deq_default, + abstol::T=1.0f-8, + reltol::T=1.0f-8, + abstol_termination::T=1.0f-8, + reltol_termination::T=1.0f-8, + tspan=Inf32, +) where {T<:Number} + return ContinuousDEQSolver{Val(mode),typeof(alg),T,typeof(tspan)}( + alg, abstol, reltol, abstol_termination, reltol_termination, tspan + ) +end diff --git a/src/solvers/discrete.jl b/src/solvers/discrete.jl new file mode 100644 index 00000000..05cd9b50 --- /dev/null +++ b/src/solvers/discrete.jl @@ -0,0 +1,33 @@ +# Wrapper for Discrete DEQs +""" + DiscreteDEQSolver(alg=LimitedMemoryBroydenSolver(); mode::Symbol=:rel_deq_default, abstol_termination::T=1.0f-8, reltol_termination::T=1.0f-8) + +Solver for Discrete DEQ Problem ([baideep2019](@cite)). Similar to `SSrootfind` but provides more flexibility needed + for solving DEQ problems. + +## Arguments + +* `alg`: Algorithm to solve the Nonlinear Problem (Default: [`LimitedMemoryBroydenSolver`](@ref)) +* `mode`: Termination Mode of the solver. See below for a description of the various termination conditions (Default: `:rel_deq_default`) +* `abstol_termination`: Absolute tolerance for termination. (Default: `1f-8`) +* `reltol_termination`: Relative tolerance for termination. (Default: `1f-8`) + +See also: [`ContinuousDEQSolver`](@ref) +""" +struct DiscreteDEQSolver{M,A,T} <: SteadyStateDiffEq.SteadyStateDiffEqAlgorithm + alg::A + abstol_termination::T + reltol_termination::T +end + +function DiscreteDEQSolver( + alg=LimitedMemoryBroydenSolver(); + mode::Symbol=:rel_deq_default, + abstol_termination::T=1.0f-8, + reltol_termination::T=1.0f-8 +) where {T<:Number} + return DiscreteDEQSolver{Val(mode),typeof(alg),T}(alg, abstol_termination, reltol_termination) +end + +include("discrete/broyden.jl") +include("discrete/limited_memory_broyden.jl") diff --git a/src/solvers/discrete/broyden.jl b/src/solvers/discrete/broyden.jl new file mode 100644 index 00000000..41e5f204 --- /dev/null +++ b/src/solvers/discrete/broyden.jl @@ -0,0 +1,120 @@ +""" + BroydenSolver(; T=Float32, device, original_dims::Tuple{Int,Int}, batch_size, maxiters::Int=50, ϵ::Real=1e-6, + abstol::Union{Real,Nothing}=nothing, reltol::Union{Real,Nothing}=nothing) + +Broyden Solver ([broyden1965class](@cite)) for solving Discrete DEQs. It is recommended to use [`LimitedMemoryBroydenSolver`](@ref) for better performance. + +## Arguments + +* `T`: The type of the elements of the vectors. (Default: `Float32`) +* `device`: The device to use. Pass `gpu` to use the GPU else pass `cpu`. +* `original_dims`: Dimensions to reshape the arrays into (excluding the batch dimension). +* `batch_size`: The batch size of the problem. Your inputs can have a different batch size, but having + them match allows us to efficiently cache internal statistics without reallocation. +* `maxiters`: Maximum number of iterations to run. +* `ϵ`: Tolerance for convergence. +* `abstol`: Absolute tolerance. +* `reltol`: Relative tolerance. (This value is ignored by `BroydenSolver` at the moment) + +See also: [`LimitedMemoryBroydenSolver`](@ref) +""" +struct BroydenSolver end + +function nlsolve( + b::BroydenSolver, f::Function, y::AbstractArray{T}; terminate_condition, maxiters::Int=10 +) where {T} + res, stats = nlsolve( + b, + u -> vec(f(reshape(u, size(y)))), + vec(y); + terminate_condition, + maxiters + ) + return reshape(res, size(y)), stats +end + +function nlsolve( + ::BroydenSolver, f::Function, y::AbstractVector{T}; terminate_condition, maxiters::Int=10 +) where {T} + x = copy(y) + x_old = copy(y) + Δx = copy(y) + fx_old = f(y) + Δfx = copy(fx_old) + Jinv = _init_identity_matrix(y) + p = similar(fx_old, (size(Jinv, 1),)) + ρ, σ₂ = T(0.9), T(0.001) + + # Store the trajectory + xs = [x] + + maybe_stuck, max_resets, resets, nsteps, nf = false, 3, 0, 1, 1 + + while nsteps <= maxiters + mul!(p, Jinv, fx_old) + p .*= -1 + + @. x = x_old + p + fx = f(x) + nf += 1 + + if norm(fx, 2) ≤ ρ * norm(fx_old, 2) - σ₂ * norm(p, 2)^2 + α = T(1) + else + α, _stats = _approximate_norm_descent(f, x, p) + @. x = x_old + α * p + fx = f(x) + nf += 1 + _stats.nf + end + + @. Δx = x - x_old + @. Δfx = fx - fx_old + + maybe_stuck = all(abs.(Δx) .<= eps(T)) || all(abs.(Δfx) .<= eps(T)) + if maybe_stuck + Jinv = _init_identity_matrix(x) + resets += 1 + maybe_stuck = (resets ≤ max_resets) && maybe_stuck + else + ΔxJinv = Δx' * Jinv + Jinv .+= ((Δx .- Jinv * Δfx) ./ (ΔxJinv * Δfx)) * ΔxJinv + end + + maybe_stuck = false + nsteps += 1 + copyto!(fx_old, fx) + copyto!(x_old, x) + + push!(xs, x) + + # Convergence Check + terminate_condition(fx, x) && break + end + + return xs, (nf=nf,) +end + +function _approximate_norm_descent(f::Function, x::AbstractArray{T,N}, p; λ₀=T(1), β=T(0.5), σ₁=T(0.001), + η=T(0.1), max_iter=50) where {T,N} + λ₂, λ₁ = λ₀, λ₀ + + fx = f(x) + fx_norm = norm(fx, 2) + j = 1 + fx = f(x .+ λ₂ .* p) + converged = false + + while j <= max_iter && !converged + j += 1 + λ₁, λ₂ = λ₂, β * λ₂ + converged = _test_approximate_norm_descent_convergence(f, x, fx_norm, p, σ₁, λ₂, η) + end + + return λ₂, (nf=2(j + 1),) +end + +function _test_approximate_norm_descent_convergence(f, x, fx_norm, p, σ₁, λ₂, η) + n1 = norm(f(x .+ λ₂ .* p), 2) + n2 = norm(f(x), 2) + return n1 ≤ fx_norm - σ₁ * norm(λ₂ .* p, 2) .^ 2 + η * n2 +end diff --git a/src/solvers/discrete/limited_memory_broyden.jl b/src/solvers/discrete/limited_memory_broyden.jl new file mode 100644 index 00000000..3c84402d --- /dev/null +++ b/src/solvers/discrete/limited_memory_broyden.jl @@ -0,0 +1,103 @@ +# Limited Memory Broyden +""" + LimitedMemoryBroydenSolver(; T=Float32, device, original_dims::Tuple{Int,Int}, batch_size, maxiters::Int=50, + ϵ::Real=1e-6, criteria::Symbol=:reltol, abstol::Union{Real,Nothing}=nothing, + reltol::Union{Real,Nothing}=nothing + +Limited Memory Broyden Solver ([baimultiscale2020](@cite)) for solving Discrete DEQs. + +## Arguments + +* `T`: The type of the elements of the vectors. (Default: `Float32`) +* `device`: The device to use. Pass `gpu` to use the GPU else pass `cpu`. +* `original_dims`: Dimensions to reshape the arrays into (excluding the batch dimension). +* `batch_size`: The batch size of the problem. Your inputs can have a different batch size, but having + them match allows us to efficiently cache internal statistics without reallocation. +* `maxiters`: Maximum number of iterations to run. +* `ϵ`: Tolerance for convergence. +* `criteria`: The criteria to use for convergence. Can be `:reltol` or `:abstol`. +* `abstol`: Absolute tolerance. +* `reltol`: Relative tolerance. + +See also: [`BroydenSolver`](@ref) +""" +struct LimitedMemoryBroydenSolver end + +@inbounds @views function nlsolve( + ::LimitedMemoryBroydenSolver, f::Function, y::AbstractMatrix{T}; terminate_condition, maxiters::Int=10 +) where {T} + LBFGS_threshold = min(maxiters, 27) + + total_hsize, batch_size = size(y) + + # Initialize the cache + x₀ = copy(y) + fx₀ = f(x₀) + x₁ = copy(y) + Δx = copy(x₀) + Δfx = copy(x₀) + Us = fill!(similar(y, (LBFGS_threshold, total_hsize, batch_size)), T(0)) + VTs = fill!(similar(y, (total_hsize, LBFGS_threshold, batch_size)), T(0)) + + # Store the trajectory + xs = [x₀] + + # Counters + nstep = 1 + + # Main Algorithm + update = fx₀ + + while nstep <= maxiters + # Update + @. x₁ = x₀ + update + fx₁ = f(x₁) + @. Δx = x₁ - x₀ + @. Δfx = fx₁ - fx₀ + + push!(xs, x₁) + + # Convergence Check + terminate_condition(fx₁, x₁) && break + + # Compute the update + part_Us = Us[1:min(LBFGS_threshold, nstep), :, :] + part_VTs = VTs[:, 1:min(LBFGS_threshold, nstep), :] + + vT = rmatvec(part_Us, part_VTs, Δx) # D x C x N + mvec = matvec(part_Us, part_VTs, Δfx) + vTΔfx = sum(vT .* Δfx; dims=(1, 2)) + @. Δx = (Δx - mvec) / (vTΔfx + eps(T)) # D x C x N + + VTs[:, mod1(nstep, LBFGS_threshold), :] .= vT + Us[mod1(nstep, LBFGS_threshold), :, :] .= Δx + + update = + -matvec( + Us[1:min(LBFGS_threshold, nstep + 1), :, :], VTs[:, 1:min(LBFGS_threshold, nstep + 1), :], fx₁ + ) + copyto!(x₀, x₁) + copyto!(fx₀, fx₁) + + # Increment Counter + nstep += 1 + end + + return xs, (nf=nstep + 1,) +end + +@inbounds @views function matvec( + part_Us::AbstractArray{E,3}, part_VTs::AbstractArray{E,3}, x::AbstractArray{E,2} +) where {E} + # part_Us -> (T x D x N) | part_VTs -> (D x T x N) | x -> (D x N) + xTU = sum(unsqueeze(x; dims=1) .* part_Us; dims=2) # T x 1 x N + return -x .+ dropdims(sum(permutedims(xTU, (2, 1, 3)) .* part_VTs; dims=2); dims=2) +end + +@inbounds @views function rmatvec( + part_Us::AbstractArray{E,3}, part_VTs::AbstractArray{E,3}, x::AbstractArray{E,2} +) where {E} + # part_Us -> (T x D x N) | part_VTs -> (D x T x N) | x -> (D x N) + VTx = sum(part_VTs .* unsqueeze(x; dims=2); dims=1) # 1 x T x N + return -x .+ dropdims(sum(part_Us .* permutedims(VTx, (2, 1, 3)); dims=1); dims=1) +end diff --git a/src/solvers/limited_memory_broyden.jl b/src/solvers/limited_memory_broyden.jl deleted file mode 100644 index b0324f72..00000000 --- a/src/solvers/limited_memory_broyden.jl +++ /dev/null @@ -1,169 +0,0 @@ -# Limited Memory Broyden -struct LimitedMemoryBroydenCache{uT,vT,F,X} - Us::uT - VTs::vT - fx_::F - x::X -end - -""" - LimitedMemoryBroydenSolver(; T=Float32, device, original_dims::Tuple{Int,Int}, batch_size, maxiters::Int=50, - ϵ::Real=1e-6, criteria::Symbol=:reltol, abstol::Union{Real,Nothing}=nothing, - reltol::Union{Real,Nothing}=nothing - -Limited Memory Broyden Solver ([baimultiscale2020](@cite)) for solving Discrete DEQs. - -## Arguments - -* `T`: The type of the elements of the vectors. (Default: `Float32`) -* `device`: The device to use. Pass `gpu` to use the GPU else pass `cpu`. -* `original_dims`: Dimensions to reshape the arrays into (excluding the batch dimension). -* `batch_size`: The batch size of the problem. Your inputs can have a different batch size, but having - them match allows us to efficiently cache internal statistics without reallocation. -* `maxiters`: Maximum number of iterations to run. -* `ϵ`: Tolerance for convergence. -* `criteria`: The criteria to use for convergence. Can be `:reltol` or `:abstol`. -* `abstol`: Absolute tolerance. -* `reltol`: Relative tolerance. - -See also: [`BroydenSolver`](@ref) -""" -struct LimitedMemoryBroydenSolver{C<:LimitedMemoryBroydenCache,RT<:Union{AbstractFloat,Nothing}, - AT<:Union{AbstractFloat,Nothing}} - cache::C - original_dims::Tuple{Int,Int} - maxiters::Int - batch_size::Int - criteria::Symbol - reltol::RT - abstol::AT -end - -function LimitedMemoryBroydenSolver(; T=Float32, device, original_dims::Tuple{Int,Int}, batch_size, maxiters::Int=50, - ϵ::Real=1e-6, criteria::Symbol=:reltol, abstol::Union{Real,Nothing}=nothing, - reltol::Union{Real,Nothing}=nothing) - @assert criteria ∈ (:abstol, :reltol) - - abstol = abstol !== nothing ? T(abstol) : T(ϵ) - reltol = reltol !== nothing ? T(reltol) : T(ϵ) - - LBFGS_threshold = min(maxiters, 27) - - x = device(zeros(T, original_dims..., batch_size)) - fx = device(zeros(T, original_dims..., batch_size)) - - total_hsize, n_elem, batch_size = size(x) - - # L x 2D x C x N - Us = fill!(similar(x, (LBFGS_threshold, total_hsize, n_elem, batch_size)), T(0)) - # 2D x C x L x N - VTs = fill!(similar(x, (total_hsize, n_elem, LBFGS_threshold, batch_size)), T(0)) - - cache = LimitedMemoryBroydenCache(Us, VTs, vec(fx), x) - - return LimitedMemoryBroydenSolver(cache, original_dims, maxiters, batch_size, criteria, reltol, abstol) -end - -function line_search(update, x₀, f₀, f, nstep::Int=0, on::Bool=false) - # TODO: Implement a line search algorithm - x_est = x₀ .+ update - f₀_new = f(x_est) - return (x_est, f₀_new, x_est .- x₀, f₀_new .- f₀, 0) -end - -function (lbroyden::LimitedMemoryBroydenSolver{C,T})(f!, x_::AbstractVector{T}) where {C,T} - @unpack cache, original_dims, batch_size, maxiters, criteria, reltol, abstol = lbroyden - ϵ = getfield(lbroyden, criteria) - - nfeatures = prod(original_dims) - if nfeatures * batch_size != length(x_) - # Maybe the last batch is smaller than the others - cache = LimitedMemoryBroydenSolver(; T=T, device=x_ isa CuArray ? gpu : cpu, original_dims=original_dims, - batch_size=length(x_) ÷ nfeatures, maxiters=maxiters, ϵ=ϵ).cache - end - - @unpack Us, VTs, fx_, x = cache - x .= reshape(x_, size(x)) - LBFGS_threshold = size(Us, 1) - fill!(Us, T(0)) - fill!(VTs, T(0)) - - # Counters - nstep = 1 - tnstep = 1 - - # Initialize - total_hsize, n_elem, batch_size = actual_size = size(x) - - # Modify the functions - f(x) = (f!(fx_, vec(x)); return reshape(fx_, actual_size)) - fx = f(x) - - update = fx - new_objective = norm(fx) - objective_values = [new_objective] - - protect_threshold = (criteria == :abstol ? T(1e6) : T(1e3)) * n_elem - initial_objective = new_objective - lowest_objective = new_objective - lowest_xest = x - - @inbounds while nstep < maxiters - x, fx, Δx, Δfx, ite = line_search(update, x, fx, f, nstep, false) - nstep += 1 - tnstep += (ite + 1) - - new_objective = criteria == :abstol ? norm(fx) : (norm(fx) / (norm(fx .+ x) + eps(T))) - push!(objective_values, new_objective) - - if new_objective < lowest_objective - lowest_objective = new_objective - lowest_xest = x - end - new_objective < ϵ && break - - new_objective < 3ϵ && - nstep >= 30 && - maximum(objective_values[(end - nstep + 1):end]) < 1.3 * minimum(objective_values[(end - nstep + 1):end]) && - break - - # Prevent Divergence - (new_objective > initial_objective * protect_threshold) && break - - @views part_Us = Us[1:min(LBFGS_threshold, nstep), :, :, :] - @views part_VTs = VTs[:, :, 1:min(LBFGS_threshold, nstep), :] - - vT = rmatvec(part_Us, part_VTs, Δx) # 2D x C x N - u = (Δx .- matvec(part_Us, part_VTs, Δfx)) ./ sum(vT .* Δfx; dims=(1, 2)) # 2D x C x N - vT[.!isfinite.(vT)] .= T(0) - u[.!isfinite.(u)] .= T(0) - - @views VTs[:, :, mod1(nstep, LBFGS_threshold), :] .= vT - @views Us[mod1(nstep, LBFGS_threshold), :, :, :] .= u - - @views update = -matvec(Us[1:min(LBFGS_threshold, nstep + 1), :, :, :], - VTs[:, :, 1:min(LBFGS_threshold, nstep + 1), :], fx) - end - - return vec(lowest_xest) -end - -function matvec(part_Us::AbstractArray{E,4}, part_VTs::AbstractArray{E,4}, x::AbstractArray{E,3}) where {E} - # part_Us -> (T x D x C x N) - # part_VTs -> (D x C x T x N) - # x -> (D x C x N) - length(part_Us) == 0 && return -x - T, D, C, N = size(part_Us) - xTU = sum(reshape(x, (1, D, C, N)) .* part_Us; dims=(2, 3)) # T x 1 x 1 x N - return -x .+ reshape(sum(permutedims(xTU, (2, 3, 1, 4)) .* part_VTs; dims=3), (D, C, N)) -end - -function rmatvec(part_Us::AbstractArray{E,4}, part_VTs::AbstractArray{E,4}, x::AbstractArray{E,3}) where {E} - # part_Us -> (T x D x C x N) - # part_VTs -> (D x C x T x N) - # x -> (D x C x N) - length(part_Us) == 0 && return -x - T, D, C, N = size(part_Us) - VTx = sum(part_VTs .* reshape(x, (D, C, 1, N)); dims=(1, 2)) # 1 x 1 x T x N - return -x .+ reshape(sum(part_Us .* permutedims(VTx, (3, 1, 2, 4)); dims=1), (D, C, N)) -end diff --git a/src/solvers/termination.jl b/src/solvers/termination.jl new file mode 100644 index 00000000..31f21cc2 --- /dev/null +++ b/src/solvers/termination.jl @@ -0,0 +1,138 @@ +get_mode(::Val{mode}) where {mode} = mode + +function get_terminate_condition(alg::ContinuousDEQSolver{M,A,T}, args...; kwargs...) where {M,A,T} + mode = get_mode(M) + if mode ∈ (:abs_deq_default, :rel_deq_default, :abs_deq_best, :rel_deq_best) + nstep, protective_threshold, objective_values = 0, T(1e3), T[] + + if mode ∈ (:rel_deq_best, :abs_deq_best) + @assert length(args) == 1 + + args[1][:best_objective_value] = T(Inf) + args[1][:best_objective_value_iteration] = 0 + end + + function terminate_condition_closure_1(integrator, abstol, reltol, min_t) + du, u = DiffEqBase.get_du(integrator), integrator.u + objective = norm(du) / (mode ∈ (:abs_deq_default, :abs_deq_best) ? 1 : (norm(du .+ u) + eps(T))) + criteria = mode ∈ (:abs_deq_default, :abs_deq_best) ? abstol : reltol + + if mode ∈ (:rel_deq_best, :abs_deq_best) + if objective < args[1][:best_objective_value] + args[1][:best_objective_value] = objective + args[1][:best_objective_value_iteration] = nstep + 1 + end + end + + # Main Termination Criteria + objective <= criteria && return true + + # Terminate if there has been no improvement for the last 30 steps + nstep += 1 + push!(objective_values, objective) + + objective <= 3 * criteria && + nstep >= 30 && + maximum(objective_values[max(1, length(objective_values) - nstep):end]) < + 1.3 * minimum(objective_values[max(1, length(objective_values) - nstep):end]) && + return true + + # Protective break + objective >= objective_values[1] * protective_threshold * length(du) && return true + + return false + end + return terminate_condition_closure_1 + else + function terminate_condition_closure_2(integrator, abstol, reltol, min_t) + return has_converged(DiffEqBase.get_du(integrator), integrator.u, M, abstol, reltol) + end + return terminate_condition_closure_2 + end +end + +function get_terminate_condition(alg::DiscreteDEQSolver{M,A,T}, args...; kwargs...) where {M,A,T} + mode = get_mode(M) + if mode ∈ (:abs_deq_default, :rel_deq_default, :abs_deq_best, :rel_deq_best) + nstep, protective_threshold, objective_values = 0, T(1e3), T[] + + if mode ∈ (:rel_deq_best, :abs_deq_best) + @assert length(args) == 1 + + args[1][:best_objective_value] = T(Inf) + args[1][:best_objective_value_iteration] = 0 + end + + function terminate_condition_closure_1(du, u) + objective = norm(du) / (mode ∈ (:abs_deq_default, :abs_deq_best) ? 1 : (norm(du .+ u) + eps(T))) + criteria = mode ∈ (:abs_deq_default, :abs_deq_best) ? alg.abstol_termination : alg.reltol_termination + + if mode ∈ (:rel_deq_best, :abs_deq_best) + if objective < args[1][:best_objective_value] + args[1][:best_objective_value] = objective + args[1][:best_objective_value_iteration] = nstep + 1 + end + end + + # Main Termination Criteria + objective <= criteria && return true + + # Terminate if there has been no improvement for the last 30 steps + nstep += 1 + push!(objective_values, objective) + + objective <= 3 * criteria && + nstep >= 30 && + maximum(objective_values[max(1, length(objective_values) - nstep):end]) < + 1.3 * minimum(objective_values[max(1, length(objective_values) - nstep):end]) && + return true + + # Protective break + objective >= objective_values[1] * protective_threshold * length(du) && return true + + return false + end + return terminate_condition_closure_1 + else + function terminate_condition_closure_2(du, u) + return has_converged(du, u, M, alg.abstol_termination, alg.reltol_termination) + end + return terminate_condition_closure_2 + end +end + +# Convergence Criterions +@inline function has_converged( + du, + u, + alg::Union{ContinuousDEQSolver{M},DiscreteDEQSolver{M}}, + abstol=alg.abstol_termination, + reltol=alg.reltol_termination, +) where {M} + return has_converged(du, u, M, abstol, reltol) +end + +@inline @inbounds function has_converged(du, u, M, abstol, reltol) + mode = get_mode(M) + if mode == :norm + return norm(du) <= abstol && norm(du) <= reltol * norm(du + u) + elseif mode == :rel + return all(abs.(du) .<= reltol .* abs.(u)) + elseif mode == :rel_norm + return norm(du) <= reltol * norm(du + u) + elseif mode == :rel_deq_default + return norm(du) <= reltol * norm(du + u) + elseif mode == :rel_deq_best + return norm(du) <= reltol * norm(du + u) + elseif mode == :abs + return all(abs.(du) .<= abstol) + elseif mode == :abs_norm + return norm(du) <= abstol + elseif mode == :abs_deq_default + return norm(du) <= abstol + elseif mode == :abs_deq_best + return norm(du) <= abstol + else + return all(abs.(du) .<= abstol .& abs.(du) .<= reltol .* abs.(u)) + end +end diff --git a/src/utils.jl b/src/utils.jl index cbcbf80e..70bacf40 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,24 +1,10 @@ # General DEQ Utils -mutable struct DEQTrainingStats - nfe::Int -end - -""" - get_and_clear_nfe!(model::AbstractDeepEquilibriumNetwork) - -Return the number of function evaluations (NFE) and clear the counter. """ -function get_and_clear_nfe!(model::AbstractDeepEquilibriumNetwork) - nfe = model.stats.nfe - model.stats.nfe = 0 - return nfe -end - -""" - SteadyStateAdjoint(reltol, abstol, maxiters; autojacvec=ZygoteVJP(), - linsolve=KrylovJL_GMRES(; rtol=reltol, atol=abstol, itmax=maxiters)) + DeepEquilibriumAdjoint(reltol, abstol, maxiters; autojacvec=ZygoteVJP(), + linsolve=KrylovJL_GMRES(; rtol=reltol, atol=abstol, itmax=maxiters), + mode=:vanilla) -Creates SteadyStateAdjoint ([johnson2012notes](@cite)) with sensible defaults. +Creates DeepEquilibriumAdjoint ([johnson2012notes](@cite)) with sensible defaults. ## Arguments @@ -27,10 +13,32 @@ Creates SteadyStateAdjoint ([johnson2012notes](@cite)) with sensible defaults. * `maxiters`: Maximum number of iterations. * `autojacvec`: Which backend to use for VJP. * `linsolve`: Linear Solver from [LinearSolve.jl](https://github.com/SciML/LinearSolve.jl). +* `mode`: Adjoint mode. Currently only `:vanilla` & `:jfb` are supported. """ -function DiffEqSensitivity.SteadyStateAdjoint(reltol, abstol, maxiters; autojacvec=ZygoteVJP(), - linsolve=KrylovJL_GMRES(; rtol=reltol, atol=abstol, itmax=maxiters)) - return SteadyStateAdjoint(; autodiff=true, autojacvec=autojacvec, linsolve=linsolve) +struct DeepEquilibriumAdjoint{CS,AD,FDT,M,VJP,LS} <: AbstractAdjointSensitivityAlgorithm{CS,AD,FDT} + autojacvec::VJP + linsolve::LS +end + +@inline check_adjoint_mode(::DeepEquilibriumAdjoint{CS,AD,FDT,M}, ::Val{M}) where {CS,AD,FDT,M} = true +@inline check_adjoint_mode(::DeepEquilibriumAdjoint, ::Val) = false + +Base.@pure function DeepEquilibriumAdjoint( + reltol, + abstol, + maxiters; + autojacvec=ZygoteVJP(), + linsolve=KrylovJL_GMRES(; rtol=reltol, atol=abstol, itmax=maxiters), + autodiff=true, + chunk_size=0, + diff_type=Val{:central}, + mode::Symbol=:vanilla, +) + return DeepEquilibriumAdjoint{ + chunk_size,autodiff,diff_type,mode,typeof(autojacvec),typeof(linsolve) + }( + autojacvec, linsolve + ) end # Initialization @@ -40,67 +48,21 @@ end Initializes the weights of the network with a normal distribution. For DEQs the training is stable if we use this as the Initialization """ -function NormalInitializer(μ = 0.0f0, σ² = 0.01f0) - return (dims...) -> randn(dims...) .* σ² .+ μ -end - -# Wrapper for Discrete DEQs -""" - DiscreteDEQSolver(solver=LimitedMemoryBroydenSolver; abstol=1e-8, reltol=1e-8, kwargs...) - -Solver for Discrete DEQ Problem ([baideep2019](@cite)). A wrapper around `SSrootfind` to mimic the [`ContinuousDEQSolver`](@ref) API. - -## Arguments - -* `solver`: NonLinear Solver for the DEQ problem. (Default: [`LimitedMemoryBroydenSolver`](@ref)) -* `abstol`: Absolute tolerance for termination. (Default: `1e-8`) -* `reltol`: Relative tolerance for termination. (Default: `1e-8`) -* `kwargs`: Additional keyword arguments passed to the solver. - -!!! note - There is no `mode` kwarg for [`DiscreteDEQSolver`](@ref). Instead solvers directly define their own termination condition. - For [`BroydenSolver`](@ref) and [`LimitedMemoryBroydenSolver`](@ref), the termination conditions are `:abs_norm` & - `:rel_deq_default` respectively. - -See also: [`ContinuousDEQSolver`](@ref) -""" -function DiscreteDEQSolver(solver=LimitedMemoryBroydenSolver; abstol=1e-8, reltol=1e-8, kwargs...) - solver = solver(; kwargs..., reltol=reltol, abstol=abstol) - return SSRootfind(; nlsolve=(f, u0, abstol) -> solver(f, u0)) +function NormalInitializer(μ=0.0f0, σ²=0.01f0) + return (rng::AbstractRNG, dims...) -> randn(rng, Float32, dims...) .* σ² .+ μ end # For MultiScale DEQs -function split_array_by_indices(x::AbstractVector, idxs) - return collect((x[(i + 1):j] for (i, j) in zip(idxs[1:(end - 1)], idxs[2:end]))) -end - -function split_array_by_indices(x::AbstractMatrix, idxs) - return collect((x[(i + 1):j, :] for (i, j) in zip(idxs[1:(end - 1)], idxs[2:end]))) -end - -Zygote.@adjoint function split_array_by_indices(x, idxs) - res = split_array_by_indices(x, idxs) - function split_array_by_indices_sensitivity(Δ) - is_nothings = Δ .=== nothing - if any(is_nothings) - Δ[is_nothings] .= zero.(res[is_nothings]) - end - return (vcat(Δ...), nothing) +@generated function split_and_reshape(x::AbstractMatrix, ::T, ::S) where {T,S} + idxs, shapes = known(T), known(S) + dims = [reshape((idxs[i] + 1):idxs[i + 1], shapes[i]...) for i in 1:(length(idxs) - 1)] + varnames = [gensym("x_view") for _ in dims] + calls = [] + for (i, dim) in enumerate(dims) + push!(calls, :($(varnames[i]) = view(x, $dim, :))) end - return res, split_array_by_indices_sensitivity -end - -# Zygote Fix -function Zygote.accum(x::NTuple{N,T}, y::AbstractVector{T}) where {N,T<:AbstractArray} - return Zygote.accum.(x, y) -end - -function Zygote.accum(x::AbstractVector{T}, y::NTuple{N,T}) where {N,T<:AbstractArray} - return Zygote.accum.(x, y) -end - -function Zygote.accum(x::AbstractVector{T}, y::NTuple{N,Nothing}) where {N,T<:AbstractArray} - return Zygote.accum.(x, y) + push!(calls, :(return tuple($(Tuple(varnames)...)))) + return Expr(:block, calls...) end # General Utils @@ -111,27 +73,12 @@ end @inline function _init_identity_matrix!(x::AbstractMatrix{T}, scale::T=T(1)) where {T} x .= zero(T) - idxs = diagind(x) - @. @view(x[idxs]) = scale * true + view(x, diagind(x)) .= scale .* true return x end -@inline function _norm(x; dims=Colon()) - return sqrt.(sum(abs2, x; dims=dims)) -end +@inline _norm(x; dims=Colon()) = sqrt.(sum(abs2, x; dims=dims)) # Compute norm over all dimensions except `except_dim` -@inline function _norm(x::AbstractArray{T,N}, except_dim) where {T,N} - dims = filter(i -> i != except_dim, 1:N) - return _norm(x; dims=dims) -end - -flatten_merge(x, y) = (x..., y...) -flatten_merge(x::T, y::T) where {T<:AbstractArray} = (x, y) -flatten_merge(x::NTuple{N,T}, y::T) where {N,T<:AbstractArray} = (x..., y) -flatten_merge(x::T, y::NTuple{N,T}) where {N,T<:AbstractArray} = (x, y...) -flatten_merge(x::NTuple{N,T}, y) where {N,T<:AbstractArray} = (x, y...) -flatten_merge(x, y::NTuple{N,T}) where {N,T<:AbstractArray} = (x..., y) -function flatten_merge(x::NTuple{N,T}, y::NTuple{N,T}) where {N,T<:AbstractArray} - return (x, y) -end +@inline _norm(x::AbstractArray{T,N}, except_dim) where {T,N} = + _norm(x; dims=filter(i -> i != except_dim, 1:N)) diff --git a/test/runtests.jl b/test/runtests.jl index 41b61793..1d0f4189 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,156 +1,278 @@ -using FastDEQ -using CUDA -using Flux -using FluxExperimental -using LinearAlgebra -using Random -using Test +using CUDA, FastDEQ, Functors, LinearAlgebra, Lux, Random, Test, Zygote + +function test_gradient_isfinite(gs::NamedTuple) + gradient_is_finite = [true] + function is_gradient_finite(x) + if !isnothing(x) && !all(isfinite, x) + gradient_is_finite[1] = false + end + return x + end + fmap(is_gradient_finite, gs) + return gradient_is_finite[1] +end @testset "FastDEQ.jl" begin - mse_loss_function = SupervisedLossContainer(loss_function = Flux.Losses.mse) + seed = 0 + rng = Random.default_rng() + Random.seed!(rng, seed) @info "Testing DEQ" - Random.seed!(0) - - model = gpu(DEQChain(Dense(2, 2), - DeepEquilibriumNetwork(Parallel(+, Dense(2, 2; bias=false), Dense(2, 2; bias=false)), - ContinuousDEQSolver(;abstol=0.1f0, reltol=0.1f0)))) - x = gpu(rand(Float32, 2, 1)) - y = gpu(rand(Float32, 2, 1)) - ps = Flux.params(model) - gs = Flux.gradient(() -> mse_loss_function(model, x, y), ps) - for _p in ps - @test all(isfinite.(gs[_p])) - end + model = DEQChain( + Dense(2, 2), + DeepEquilibriumNetwork( + Parallel(+, Dense(2, 2; bias=false), Dense(2, 2; bias=false)), + ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0), + ), + ) + ps, st = gpu.(Lux.setup(rng, model)) + x = gpu(rand(rng, Float32, 2, 1)) + y = gpu(rand(rng, Float32, 2, 1)) + + gs = gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps)[1] + + @test test_gradient_isfinite(gs) + + @info "Testing DEQ without Fixed Point Iterations" + st = Lux.update_state(st, :fixed_depth, Val(5)) + + gs = gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps)[1] + + @test test_gradient_isfinite(gs) @info "Testing SkipDEQ" - Random.seed!(0) - - model = gpu(DEQChain(Dense(2, 2), - SkipDeepEquilibriumNetwork(Parallel(+, Dense(2, 2), Dense(2, 2)), - Dense(2, 2), - ContinuousDEQSolver(;abstol=0.1f0, reltol=0.1f0); - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10)))) - x = gpu(rand(Float32, 2, 1)) - y = gpu(rand(Float32, 2, 1)) - ps = Flux.params(model) - gs = Flux.gradient(() -> mse_loss_function(model, x, y), ps) - for _p in ps - @test all(isfinite.(gs[_p])) - end + Random.seed!(rng, seed) + model = DEQChain( + Dense(2, 2), + SkipDeepEquilibriumNetwork( + Parallel(+, Dense(2, 2), Dense(2, 2)), + Dense(2, 2), + ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0); + sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), + ), + ) + ps, st = gpu.(Lux.setup(rng, model)) + x = gpu(rand(rng, Float32, 2, 1)) + y = gpu(rand(rng, Float32, 2, 1)) - @info "Testing SkipDEQ V2" - Random.seed!(0) - - model = gpu(DEQChain(Dense(2, 2), - SkipDeepEquilibriumNetwork(Parallel(+, Dense(2, 2), Dense(2, 2)), - ContinuousDEQSolver(;abstol=0.1f0, reltol=0.1f0)))) - x = gpu(rand(Float32, 2, 1)) - y = gpu(rand(Float32, 2, 1)) - ps = Flux.params(model) - gs = Flux.gradient(() -> mse_loss_function(model, x, y), ps) - for _p in ps - @test all(isfinite.(gs[_p])) - end + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end[1] - @info "Testing Broyden Solver" - Random.seed!(0) - - model = gpu(DEQChain(Conv((3, 3), 1 => 1, relu; pad=1, stride=1), - SkipDeepEquilibriumNetwork(Parallel(+, Conv((3, 3), 1 => 1, relu; pad=1, stride=1), - Conv((3, 3), 1 => 1, relu; pad=1, stride=1)), - Conv((3, 3), 1 => 1, relu; pad=1, stride=1), - DiscreteDEQSolver(BroydenSolver; abstol=0.001f0, - reltol=0.001f0, device=gpu, original_dims=(8 * 8, 1), - batch_size=4); - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10)))) - x = gpu(rand(Float32, 8, 8, 1, 4)) - y = gpu(rand(Float32, 8, 8, 1, 4)) - ps = Flux.params(model) - gs = Flux.gradient(() -> mse_loss_function(model, x, y), ps) - for _p in ps - @test all(isfinite.(gs[_p])) - end + @test test_gradient_isfinite(gs) - @info "Testing L-Broyden Solver" - Random.seed!(0) - - model = gpu(DEQChain(Conv((3, 3), 1 => 1, relu; pad=1, stride=1), - SkipDeepEquilibriumNetwork(Parallel(+, Conv((3, 3), 1 => 1, relu; pad=1, stride=1), - Conv((3, 3), 1 => 1, relu; pad=1, stride=1)), - Conv((3, 3), 1 => 1, relu; pad=1, stride=1), - DiscreteDEQSolver(LimitedMemoryBroydenSolver; abstol=0.001f0, - reltol=0.001f0, device=gpu, original_dims=(8 * 8, 1), - batch_size=4); - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10)))) - x = gpu(rand(Float32, 8, 8, 1, 4)) - y = gpu(rand(Float32, 8, 8, 1, 4)) - ps = Flux.params(model) - gs = Flux.gradient(() -> mse_loss_function(model, x, y), ps) - for _p in ps - @test all(isfinite.(gs[_p])) - end + @info "Testing SkipDEQ without Fixed Point Iterations" + st = Lux.update_state(st, :fixed_depth, Val(5)) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end[1] + + @test test_gradient_isfinite(gs) + + @info "Testing SkipDEQV2" + Random.seed!(rng, seed) + model = DEQChain( + Dense(2, 2), + SkipDeepEquilibriumNetwork( + Parallel(+, Dense(2, 2), Dense(2, 2)), + nothing, + ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0); + sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), + ), + ) + ps, st = gpu.(Lux.setup(rng, model)) + x = gpu(rand(rng, Float32, 2, 1)) + y = gpu(rand(rng, Float32, 2, 1)) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end[1] + + @test test_gradient_isfinite(gs) + + @info "Testing SkipDEQV2 without Fixed Point Iterations" + st = Lux.update_state(st, :fixed_depth, Val(5)) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end[1] + + @test test_gradient_isfinite(gs) + + @info "Testing SkipDEQ with Broyden Solver" + Random.seed!(rng, seed) + model = DEQChain( + Dense(2, 2), + SkipDeepEquilibriumNetwork( + Parallel(+, Dense(2, 2), Dense(2, 2)), + Dense(2, 2), + DiscreteDEQSolver(BroydenSolver(); abstol_termination=0.1f0, reltol_termination=0.1f0); + sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), + ), + ) + ps, st = gpu.(Lux.setup(rng, model)) + x = gpu(rand(rng, Float32, 2, 1)) + y = gpu(rand(rng, Float32, 2, 1)) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end[1] + + @test test_gradient_isfinite(gs) + + @info "Testing SkipDEQ with L-Broyden Solver" + Random.seed!(rng, seed) + model = DEQChain( + Dense(2, 2), + SkipDeepEquilibriumNetwork( + Parallel(+, Dense(2, 2), Dense(2, 2)), + Dense(2, 2), + DiscreteDEQSolver(LimitedMemoryBroydenSolver(); abstol_termination=0.1f0, reltol_termination=0.1f0); + sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), + ), + ) + ps, st = gpu.(Lux.setup(rng, model)) + x = gpu(rand(rng, Float32, 2, 1)) + y = gpu(rand(rng, Float32, 2, 1)) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end[1] + + @test test_gradient_isfinite(gs) @info "Testing MultiScaleDEQ" - Random.seed!(0) - - model = gpu(MultiScaleDeepEquilibriumNetwork((Parallel(+, Dense(4, 4, tanh_fast), Dense(4, 4, tanh_fast)), - Dense(3, 3, tanh_fast), Dense(2, 2, tanh_fast), - Dense(1, 1, tanh_fast)), - [NoOpLayer() Dense(4, 3, tanh_fast) Dense(4, 2, tanh_fast) Dense(4, 1, tanh_fast); - Dense(3, 4, tanh_fast) NoOpLayer() Dense(3, 2, tanh_fast) Dense(3, 1, tanh_fast); - Dense(2, 4, tanh_fast) Dense(2, 3, tanh_fast) NoOpLayer() Dense(2, 1, tanh_fast); - Dense(1, 4, tanh_fast) Dense(1, 3, tanh_fast) Dense(1, 2, tanh_fast) NoOpLayer()], - ContinuousDEQSolver(;abstol=0.1f0, reltol=0.1f0))) - x = gpu(rand(Float32, 4, 2)) - y = tuple([gpu(rand(Float32, i, 2)) for i in 4:-1:1]...) - sol = model(x) - ps = Flux.params(model) - gs = Flux.gradient(() -> mse_loss_function(model, x, y), ps) - for _p in ps - @test all(isfinite.(gs[_p])) - end + Random.seed!(rng, seed) + model = MultiScaleDeepEquilibriumNetwork( + ( + Parallel(+, Dense(4, 4, tanh), Dense(4, 4, tanh)), + Dense(3, 3, tanh), + Dense(2, 2, tanh), + Dense(1, 1, tanh), + ), + [ + NoOpLayer() Dense(4, 3, tanh) Dense(4, 2, tanh) Dense(4, 1, tanh) + Dense(3, 4, tanh) NoOpLayer() Dense(3, 2, tanh) Dense(3, 1, tanh) + Dense(2, 4, tanh) Dense(2, 3, tanh) NoOpLayer() Dense(2, 1, tanh) + Dense(1, 4, tanh) Dense(1, 3, tanh) Dense(1, 2, tanh) NoOpLayer() + ], + nothing, + ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0), + ((4,), (3,), (2,), (1,)); + sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), + ) + + ps, st = gpu.(Lux.setup(rng, model)) + x = gpu(rand(rng, Float32, 4, 2)) + y = tuple([gpu(rand(rng, Float32, i, 2)) for i in 4:-1:1]...) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(Base.Fix1(sum, abs2), ŷ .- y) + end[1] + + @test test_gradient_isfinite(gs) + + @info "Testing MultiScaleDEQ without Fixed Point Iterations" + st = Lux.update_state(st, :fixed_depth, Val(5)) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(Base.Fix1(sum, abs2), ŷ .- y) + end[1] + + @test test_gradient_isfinite(gs) @info "Testing MultiScaleSkipDEQ" - Random.seed!(0) - - model = gpu(MultiScaleSkipDeepEquilibriumNetwork((Parallel(+, Dense(4, 4, tanh_fast), Dense(4, 4, tanh_fast)), - Dense(3, 3, tanh_fast), Dense(2, 2, tanh_fast), - Dense(1, 1, tanh_fast)), - [NoOpLayer() Dense(4, 3, tanh_fast) Dense(4, 2, tanh_fast) Dense(4, 1, tanh_fast); - Dense(3, 4, tanh_fast) NoOpLayer() Dense(3, 2, tanh_fast) Dense(3, 1, tanh_fast); - Dense(2, 4, tanh_fast) Dense(2, 3, tanh_fast) NoOpLayer() Dense(2, 1, tanh_fast); - Dense(1, 4, tanh_fast) Dense(1, 3, tanh_fast) Dense(1, 2, tanh_fast) NoOpLayer()], - (Dense(4, 4, tanh_fast), Dense(4, 3, tanh_fast), - Dense(4, 2, tanh_fast), Dense(4, 1, tanh_fast)), - ContinuousDEQSolver(;abstol=0.1f0, reltol=0.1f0))) - x = gpu(rand(Float32, 4, 2)) - y = tuple([gpu(rand(Float32, i, 2)) for i in 4:-1:1]...) - sol = model(x) - ps = Flux.params(model) - gs = Flux.gradient(() -> mse_loss_function(model, x, y), ps) - for _p in ps - @test all(isfinite.(gs[_p])) - end + Random.seed!(rng, seed) + model = MultiScaleSkipDeepEquilibriumNetwork( + ( + Parallel(+, Dense(4, 4, tanh), Dense(4, 4, tanh)), + Dense(3, 3, tanh), + Dense(2, 2, tanh), + Dense(1, 1, tanh), + ), + [ + NoOpLayer() Dense(4, 3, tanh) Dense(4, 2, tanh) Dense(4, 1, tanh) + Dense(3, 4, tanh) NoOpLayer() Dense(3, 2, tanh) Dense(3, 1, tanh) + Dense(2, 4, tanh) Dense(2, 3, tanh) NoOpLayer() Dense(2, 1, tanh) + Dense(1, 4, tanh) Dense(1, 3, tanh) Dense(1, 2, tanh) NoOpLayer() + ], + nothing, + (Dense(4, 4, tanh), Dense(4, 3, tanh), Dense(4, 2, tanh), Dense(4, 1, tanh)), + ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0), + ((4,), (3,), (2,), (1,)); + sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), + ) + + ps, st = gpu.(Lux.setup(rng, model)) + x = gpu(rand(rng, Float32, 4, 2)) + y = tuple([gpu(rand(rng, Float32, i, 2)) for i in 4:-1:1]...) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end[1] + + @test test_gradient_isfinite(gs) + + @info "Testing MultiScaleSkipDEQ without Fixed Point Iterations" + st = Lux.update_state(st, :fixed_depth, Val(5)) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end[1] + + @test test_gradient_isfinite(gs) + + @info "Testing MultiScaleSkipDEQV2" + Random.seed!(rng, seed) + model = MultiScaleSkipDeepEquilibriumNetwork( + ( + Parallel(+, Dense(4, 4, tanh), Dense(4, 4, tanh)), + Dense(3, 3, tanh), + Dense(2, 2, tanh), + Dense(1, 1, tanh), + ), + [ + NoOpLayer() Dense(4, 3, tanh) Dense(4, 2, tanh) Dense(4, 1, tanh) + Dense(3, 4, tanh) NoOpLayer() Dense(3, 2, tanh) Dense(3, 1, tanh) + Dense(2, 4, tanh) Dense(2, 3, tanh) NoOpLayer() Dense(2, 1, tanh) + Dense(1, 4, tanh) Dense(1, 3, tanh) Dense(1, 2, tanh) NoOpLayer() + ], + nothing, + nothing, + ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0), + ((4,), (3,), (2,), (1,)); + sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), + ) + + ps, st = gpu.(Lux.setup(rng, model)) + x = gpu(rand(rng, Float32, 4, 2)) + y = tuple([gpu(rand(rng, Float32, i, 2)) for i in 4:-1:1]...) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end[1] + + @test test_gradient_isfinite(gs) + + @info "Testing MultiScaleSkipDEQV2 without Fixed Point Iterations" + st = Lux.update_state(st, :fixed_depth, Val(5)) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end[1] - # CI gives mutation error though it works locally. - # @info "Testing MultiScaleSkipDEQV2" - # Random.seed!(0) - - # model = gpu(MultiScaleSkipDeepEquilibriumNetwork((Parallel(+, Dense(4, 4, tanh_fast), Dense(4, 4, tanh_fast)), - # Dense(3, 3, tanh_fast), Dense(2, 2, tanh_fast), - # Dense(1, 1, tanh_fast)), - # [NoOpLayer() Dense(4, 3, tanh_fast) Dense(4, 2, tanh_fast) Dense(4, 1, tanh_fast); - # Dense(3, 4, tanh_fast) NoOpLayer() Dense(3, 2, tanh_fast) Dense(3, 1, tanh_fast); - # Dense(2, 4, tanh_fast) Dense(2, 3, tanh_fast) NoOpLayer() Dense(2, 1, tanh_fast); - # Dense(1, 4, tanh_fast) Dense(1, 3, tanh_fast) Dense(1, 2, tanh_fast) NoOpLayer()], - # ContinuousDEQSolver(;abstol=0.1f0, reltol=0.1f0))) - # x = gpu(rand(Float32, 4, 2)) - # y = tuple([gpu(rand(Float32, i, 2)) for i in 4:-1:1]...) - # sol = model(x) - # ps = Flux.params(model) - # gs = Flux.gradient(() -> mse_loss_function(model, x, y), ps) - # for _p in ps - # @test all(isfinite.(gs[_p])) - # end + @test test_gradient_isfinite(gs) end