diff --git a/docs/pages.jl b/docs/pages.jl index 3ce04508..7bea2ba9 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -1,12 +1,11 @@ pages = [ "Home" => "index.md", - "sciml.md", + "tutorials/getting_started.md", + "Tutorials" => Any[ + "tutorials/operator_algebras.md", + "FFT Tutorial" => "tutorials/fftw.md" + ], "interface.md", "Premade Operators" => "premade_operators.md", - "Tutorials" => Any["FFT Tutorial" => "tutorials/fftw.md" - # "tutorials/linear.md", - # "tutorials/nonlin.md", - # "tutorials/ode.md", - # "tutorials/lux.md", - ] + ] diff --git a/docs/src/index.md b/docs/src/index.md index 22990460..7e171d21 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,4 +1,4 @@ -# SciMLOperators.jl: Unified operator interface for `SciML.ai` and beyond +# SciMLOperators.jl: Unified operator interface for Julia and SciML `SciMLOperators` is a package for managing linear, nonlinear, time-dependent, and parameter dependent operators acting on vectors, @@ -25,69 +25,51 @@ using Pkg Pkg.add("SciMLOperators") ``` -## Examples - -Let `M`, `D`, `F` be matrix-based, diagonal-matrix-based, and function-based -`SciMLOperators` respectively. - -```julia -N = 4 -f = (v, u, p, t) -> u .* v - -M = MatrixOperator(rand(N, N)) -D = DiagonalOperator(rand(N)) -F = FunctionOperator(f, zeros(N), zeros(N)) -``` - -Then, the following codes just work. - -```julia -L1 = 2M + 3F + LinearAlgebra.I + rand(N, N) -L2 = D * F * M' -L3 = kron(M, D, F) -L4 = M \ D -L5 = [M; D]' * [M F; F D] * [F; D] -``` - -Each `L#` can be applied to `AbstractVector`s of appropriate sizes: - -```julia -p = nothing # parameter struct -t = 0.0 # time - -u = rand(N) -v = rand(N) -w = L1(v, u, p, t) # == L1 * v - -v_kron = rand(N^3) -w_kron = L3(v_kron, u, p, t) # == L3 * v_kron -``` - -For mutating operator evaluations, call `cache_operator` to generate an -in-place cache, so the operation is nonallocating. - -```julia -α, β = rand(2) - -# allocate cache -L2 = cache_operator(L2, u) -L4 = cache_operator(L4, u) - -# allocation-free evaluation -L2(w, v, u, p, t) # == mul!(w, L2, v) -L4(w, v, u, p, t, α, β) # == mul!(w, L4, v, α, β) -``` - -The calling signature `L(v, u, p, t)`, for out-of-place evaluations, is -equivalent to `L * v`, and the in-place evaluation `L(w, v, u, p, t, args...)` -is equivalent to `LinearAlgebra.mul!(w, L, v, args...)`, where the arguments -`u, p, t` are passed to `L` to update its state. More details are provided -in the operator update section below. - -The `(v, u, p, t)` calling signature is standardized over the `SciML` -ecosystem and is flexible enough to support use cases such as time-evolution -in ODEs, as well as sensitivity computation with respect to the parameter -object `p`. +## Why `SciMLOperators`? + +Many functions, from linear solvers to differential equations, require +the use of matrix-free operators to achieve maximum performance in +many scenarios. `SciMLOperators.jl` defines the abstract interface for how +operators in the SciML ecosystem are supposed to be defined. It gives the +common set of functions and traits that solvers can rely on for properly +performing their tasks. Along with that, `SciMLOperators.jl` provides +definitions for the basic standard operators that are used as building +blocks for most tasks, simplifying the use of operators while also +demonstrating to users how such operators can be built and used in practice. + +`SciMLOperators.jl` has the design that is required to be used in +all scenarios of equation solvers. 
For example, Magnus integrators for +differential equations require defining an operator ``u' = A(t) u``, while +Munthe-Kaas methods require defining operators of the form ``u' = A(u) u``. +Thus, the operators need some form of time and state dependence, which the +solvers can update and query when they are non-constant +(`update_coefficients!`). Additionally, the operators need the ability to +act like “normal” functions for equation solvers. For example, if `A(v,u,p,t)` +has the same operation as `update_coefficients(A, u, p, t); A * v`, then `A` +can be used in any place where a differential equation definition +`(u,p,t) -> A(u, u, p, t)` is used without requiring the user or solver to do any extra +work. + +Another example is state-dependent mass matrices. `M(u,p,t)*u' = f(u,p,t)`. +When solving such an equation, the solver must understand how to "update M" +during operations, and thus the ability to update the state of `M` is a required +function in the interface. This is also required for the definition of Jacobians +`J(u,p,t)` in order to be properly used with Krylov methods inside of ODE solves +without reconstructing the matrix-free operator at each step. + +Thus while previous good efforts for matrix-free operators have existed +in the Julia ecosystem, such as +[LinearMaps.jl](https://github.com/JuliaLinearAlgebra/LinearMaps.jl), those +operator interfaces lack these aspects to actually be fully seamless +with downstream equation solvers. This necessitates the definition and use of +an extended operator interface with all of these properties, hence the +`AbstractSciMLOperator` interface. + +!!! warn + + This means that LinearMaps.jl is fundamentally lacking and is incompatible + with many of the tools in the SciML ecosystem, except for the specific cases + where the matrix-free operator is a constant! ## Features diff --git a/docs/src/interface.md b/docs/src/interface.md index d147fb60..0f093549 100644 --- a/docs/src/interface.md +++ b/docs/src/interface.md @@ -1,4 +1,4 @@ -# The `AbstractSciMLOperator` Interface +# [The `AbstractSciMLOperator` Interface](@id operator_interface) ```@docs SciMLOperators.AbstractSciMLOperator diff --git a/docs/src/premade_operators.md b/docs/src/premade_operators.md index 088b794e..c21ad464 100644 --- a/docs/src/premade_operators.md +++ b/docs/src/premade_operators.md @@ -1,4 +1,4 @@ -# Premade SciMLOperators +# [Premade SciMLOperators](@id premade_operators) ## Direct Operator Definitions diff --git a/docs/src/sciml.md b/docs/src/sciml.md deleted file mode 100644 index 261e9a84..00000000 --- a/docs/src/sciml.md +++ /dev/null @@ -1,79 +0,0 @@ -# Usage with `SciML` and beyond - -## Why `SciMLOperators`? - -Many functions, from linear solvers to differential equations, require -the use of matrix-free operators to achieve maximum performance in -many scenarios. `SciMLOperators.jl` defines the abstract interface for how -operators in the SciML ecosystem are supposed to be defined. It gives the -common set of functions and traits that solvers can rely on for properly -performing their tasks. Along with that, `SciMLOperators.jl` provides -definitions for the basic standard operators that are used as building -blocks for most tasks, simplifying the use of operators while also -demonstrating to users how such operators can be built and used in practice. - -`SciMLOperators.jl` has the design that is required to be used in -all scenarios of equation solvers. 
For example, Magnus integrators for -differential equations require defining an operator ``u' = A(t) u``, while -Munthe-Kaas methods require defining operators of the form ``u' = A(u) u``. -Thus, the operators need some form of time and state dependence, which the -solvers can update and query when they are non-constant -(`update_coefficients!`). Additionally, the operators need the ability to -act like “normal” functions for equation solvers. For example, if `A(v,u,p,t)` -has the same operation as `update_coefficients(A, u, p, t); A * v`, then `A` -can be used in any place where a differential equation definition -`(u,p,t) -> A(u, u, p, t)` is used without requiring the user or solver to do any extra -work. - -Another example is state-dependent mass matrices. `M(u,p,t)*u' = f(u,p,t)`. -When solving such an equation, the solver must understand how to "update M" -during operations, and thus the ability to update the state of `M` is a required -function in the interface. This is also required for the definition of Jacobians -`J(u,p,t)` in order to be properly used with Krylov methods inside of ODE solves -without reconstructing the matrix-free operator at each step. - -Thus while previous good efforts for matrix-free operators have existed -in the Julia ecosystem, such as -[LinearMaps.jl](https://github.com/JuliaLinearAlgebra/LinearMaps.jl), those -operator interfaces lack these aspects to actually be fully seamless -with downstream equation solvers. This necessitates the definition and use of -an extended operator interface with all of these properties, hence the -`AbstractSciMLOperator` interface. - -!!! warn - - This means that LinearMaps.jl is fundamentally lacking and is incompatible - with many of the tools in the SciML ecosystem, except for the specific cases - where the matrix-free operator is a constant! - -## Interoperability and extended Julia ecosystem - -`SciMLOperator.jl` overloads the `AbstractMatrix` interface for -`AbstractSciMLOperator`s, allowing seamless compatibility with -linear and nonlinear solvers. Further, due to the update functionality, -`AbstractSciMLOperator`s can represent an `ODEFunction` in `OrdinaryDiffEq.jl`, -and downstream packages. See tutorials for examples of usage with -`OrdinaryDiffEq.jl`, `LinearSolve.jl`, `NonlinearSolve.jl`. - -Further, the nonmutating update functionality allows gradient propagation -through `AbstractSciMLOperator`s, and is compatible with -automatic-differentiation libraries like -[`Zygote.jl`](https://github.com/SciML/DiffEqOperators.jl/tree/master). -An example of `Zygote.jl` usage with -[`Lux.jl`](https://github.com/LuxDL/Lux.jl) is also provided in the tutorials. - -Please make an issue [here](https://github.com/SciML/SciMLOperators.jl/issues) -if you come across an unexpected issue while using `SciMLOperators`. - -We provide below a list of packages that make use of `SciMLOperators`. -If you are using `SciMLOperators` in your work, feel free to create a PR -and add your package to this list. - - - [`SciML.ai`](https://sciml.ai/) ecosystem: `SciMLOperators` is compatible with, and utilized by every `SciML` package. - - - [`CalculustJL`](https://github.com/CalculustJL) packages use `SciMLOperators` to define matrix-free vector-calculus operators for solving partial differential equations. 
- - + [`CalculustCore.jl`](https://github.com/CalculustJL/CalculustCore.jl) - + [`FourierSpaces.jl`](https://github.com/CalculustJL/FourierSpaces.jl) - + [`NodalPolynomialSpaces.jl`](https://github.com/CalculustJL/NodalPolynomialSpaces.jl) - - `SparseDiffTools.jl` diff --git a/docs/src/tutorials/fftw.md b/docs/src/tutorials/fftw.md index 78c03258..596c3133 100644 --- a/docs/src/tutorials/fftw.md +++ b/docs/src/tutorials/fftw.md @@ -1,4 +1,4 @@ -# Wrap a Fourier transform with SciMLOperators +# [Wrap a Fourier transform with SciMLOperators](@id fft) In this tutorial, we will wrap a Fast Fourier Transform (FFT) in a SciMLOperator via the `FunctionOperator` interface. FFTs are commonly used algorithms for performing numerical diff --git a/docs/src/tutorials/getting_started.md b/docs/src/tutorials/getting_started.md new file mode 100644 index 00000000..0dd8481d --- /dev/null +++ b/docs/src/tutorials/getting_started.md @@ -0,0 +1,237 @@ +# Getting Started with Matrix-Free Operators in Julia + +SciMLOperators.jl is a package for defining operators for use in solvers. +One of the major use cases is to define matrix-free operators in cases where +using a matrix would be too memory expensive. In this tutorial we will walk +through the main features of SciMLOperators and get you going with matrix-free +and updating operators. + +## Simplest Operator: MatrixOperator + +Before we get into the deeper operators, let's show the simplest SciMLOperator: +`MatrixOperator`. `MatrixOperator` just turns a matrix into an `AbstractSciMLOperator`, +so it's not really a matrix-free operator but it's a starting point that is good for +understanding the interface and testing. To create a `MatrixOperator`, simply call the +constructor on a matrix: + +```@example getting_started +using SciMLOperators, LinearAlgebra +A = [-2.0 1 0 0 0 + 1 -2 1 0 0 + 0 1 -2 1 0 + 0 0 1 -2 1 + 0 0 0 1 -2] + +opA = MatrixOperator(A) +``` + +The operators can do [operations as defined in the operator interface](@ref operator_interface), for example, +matrix multiplication as the core action: + +```@example getting_started +v = [3.0,2.0,1.0,2.0,3.0] +opA*v +``` + +```@example getting_started +opA(v, nothing, nothing, nothing) # Call = opA*v +``` + +```@example getting_started +w = zeros(5) +mul!(w, opA, v) +``` + +```@example getting_started +α = 1.0; β = 1.0 +mul!(w, opA, v, α, β) # α*opA*v + β*w +``` + +and the inverse operation: + +```@example getting_started +opA \ v +``` + +```@example getting_started +ldiv!(w, lu(opA), v) +``` + +## State, Parameter, and Time-Dependent Operators + +Now let's define a `MatrixOperator` the is dependent on state, parameters, and time. +For example, let's make the operator `A .* u + dt*I` where `dt` is a parameter +and `u` is a state vector: + +```@example getting_started +A = [-2.0 1 0 0 0 + 1 -2 1 0 0 + 0 1 -2 1 0 + 0 0 1 -2 1 + 0 0 0 1 -2] + +function update_function!(B, u, p, t) + dt = p + B .= A .* u + dt*I +end + +u = Array(1:1.0:5); p = 0.1; t = 0.0 +opB = MatrixOperator(copy(A); update_func! = update_function!) 
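+
+# Note: `update_func!` mutates the wrapped matrix in place. It receives the current matrix `B`
+# together with the state `u`, the parameters `p`, and the time `t`, and it is the function that
+# `update_coefficients!(opB, u, p, t)` calls to refresh the operator before it is applied.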
+``` + +To update the operator, you would use `update_coefficients!(opB, u, p, t)`: + +```@example getting_started +update_coefficients!(opB, u, p, t) +``` + +We can use the interface to see what the current matrix is by converting to a standard matrix: + +```@example getting_started +convert(AbstractMatrix, opB) +``` + +And now applying the operator applies the updated one: + +```@example getting_started +opB*v +``` + +Or if you use the operator application, it will update and apply in one step: + +```@example getting_started +opB(v, Array(2:1.0:6), 0.5, nothing) # opB(u,p,t)*v +``` + +This is how for example, when an ODE solver asks for an operator `L(u,p,t)*u`, this is how +such an operator can be defined. Notice that the interface can be queried to understand +the traits of the operator, such as for example whether an operator is constant (does not +change w.r.t. `(u,p,t)`): + +```@example getting_started +isconstant(opA) +``` + +```@example getting_started +isconstant(opB) +``` + +## Matrix-Free Operators via FunctionOperator + +Now let's define the operators from above in a matrix-free way using `FunctionOperator`. +With `FunctionOperator`, we directly define the operator application function `opA(w,v,u,p,t)` +which means `w = opA(u,p,t)*v`. For exmaple we can do the following: + +```@example getting_started +function Afunc!(w,v,u,p,t) + w[1] = -2v[1] + v[2] + for i in 2:4 + w[i] = v[i-1] - 2v[i] + v[i+1] + end + w[5] = v[4] - 2v[5] + nothing +end + +function Afunc!(v,u,p,t) + w = zeros(5) + Afunc!(w,v,u,p,t) + w +end + +mfopA = FunctionOperator(Afunc!, zeros(5), zeros(5)) +``` + +Now `mfopA` acts just like `A*v` and thus `opA`: + +```@example getting_started +mfopA*v - opA*v +``` + +```@example getting_started +mfopA(v,u,p,t) - opA(v,u,p,t) +``` + +We can also create the state-dependent operator as well: + +```@example getting_started +function Bfunc!(w,v,u,p,t) + dt = p + w[1] = -(2*u[1]-dt)*v[1] + v[2]*u[1] + for i in 2:4 + w[i] = v[i-1]*u[i] - (2*u[i]-dt)*v[i] + v[i+1]*u[i] + end + w[5] = v[4]*u[5] - (2*u[5]-dt)*v[5] + nothing +end + +function Bfunc!(v,u,p,t) + w = zeros(5) + Bfunc!(w,v,u,p,t) + w +end + +mfopB = FunctionOperator(Bfunc!, zeros(5), zeros(5); u, p, t, isconstant=false) +``` + +```@example getting_started +opB(v, Array(2:1.0:6), 0.5, nothing) - mfopB(v, Array(2:1.0:6), 0.5, nothing) +``` + +## Operator Algebras + +While the operators are lazy operations and thus are not full matrices, you can still +do algebra on operators. This will construct a new lazy operator that will be able to +compute the same action as the composed function. For example, let's create `mfopB` +using `mfopA`. Recall that we defined this via `A .* u + dt*I`. Let's first create an +operator for `A .* u` (since right now there is not a built in operator for vector scaling, +but that would be a fantastic thing to add!): + +```@example getting_started +function Cfunc!(w,v,u,p,t) + w[1] = -2v[1] + v[2] + for i in 2:4 + w[i] = v[i-1] - 2v[i] + v[i+1] + end + w[5] = v[4] - 2v[5] + w .= w .* u + nothing +end + +function Cfunc!(v,u,p,t) + w = zeros(5) + Cfunc!(w,v,u,p,t) + w +end + +mfopC = FunctionOperator(Cfunc!, zeros(5), zeros(5)) +``` + +And now let's create the operator `mfopC + dt*I`. 
We can just directly build it:
+
+```@example getting_started
+mfopD = mfopC + 0.5*I
+```
+
+SciMLOperators.jl uses its own `IdentityOperator` and `ScalarOperator` instead of the Base
+utilities, but the final composed operator acts just like the operator that was built directly:
+
+```@example getting_started
+mfopB(v, Array(2:1.0:6), 0.5, nothing) - mfopD(v, Array(2:1.0:6), 0.5, nothing)
+```
+
+There are many cool things you can do with operator algebras, such as `kron` (Kronecker products),
+adjoints, inverses, and more. For more information, see the [operator algebras tutorial](@ref operator_algebras).
+
+## Where to go next?
+
+Great! You now know how to define state/parameter/time-dependent operators, make them matrix-free, and
+compose them with operator algebra. What's next?
+
+* Interested in more examples of building operators? See the example of [making a fast Fourier transform linear operator](@ref fft).
+* Interested in more operators ready to go? See the [Premade Operators page](@ref premade_operators) for all of the operators included with SciMLOperators. Note that there are also downstream packages that provide additional operators.
+* Want to make your own SciMLOperator? See the [AbstractSciMLOperator interface page](@ref operator_interface), which describes the full interface.
+
+How do you use SciMLOperators? Check out the following downstream pages:
+
+* [Using SciMLOperators in LinearSolve.jl for matrix-free Krylov methods](https://docs.sciml.ai/LinearSolve/stable/tutorials/linear/)
+* [Using SciMLOperators in OrdinaryDiffEq.jl for semi-linear ODE solvers](https://docs.sciml.ai/DiffEqDocs/stable/solvers/nonautonomous_linear_ode/)
\ No newline at end of file
diff --git a/docs/src/tutorials/operator_algebras.md b/docs/src/tutorials/operator_algebras.md
new file mode 100644
index 00000000..387da89e
--- /dev/null
+++ b/docs/src/tutorials/operator_algebras.md
@@ -0,0 +1,63 @@
+# [Demonstration of Operator Algebras and Kron](@id operator_algebras)
+
+Let `M`, `D`, `F` be matrix-based, diagonal-matrix-based, and function-based
+`SciMLOperators` respectively.
+
+```julia
+N = 4
+f = (v, u, p, t) -> u .* v
+
+M = MatrixOperator(rand(N, N))
+D = DiagonalOperator(rand(N))
+F = FunctionOperator(f, zeros(N), zeros(N))
+```
+
+Then, the following code just works.
+
+```julia
+L1 = 2M + 3F + LinearAlgebra.I + rand(N, N)
+L2 = D * F * M'
+L3 = kron(M, D, F)
+L4 = M \ D
+L5 = [M; D]' * [M F; F D] * [F; D]
+```
+
+Each `L#` can be applied to `AbstractVector`s of appropriate sizes:
+
+```julia
+p = nothing # parameter struct
+t = 0.0 # time
+
+u = rand(N)
+v = rand(N)
+w = L1(v, u, p, t) # == L1 * v
+
+v_kron = rand(N^3)
+w_kron = L3(v_kron, u, p, t) # == L3 * v_kron
+```
+
+For mutating operator evaluations, call `cache_operator` to generate an
+in-place cache, so the operation is nonallocating.
+
+```julia
+α, β = rand(2)
+
+# allocate cache
+L2 = cache_operator(L2, u)
+L4 = cache_operator(L4, u)
+
+# allocation-free evaluation
+L2(w, v, u, p, t) # == mul!(w, L2, v)
+L4(w, v, u, p, t, α, β) # == mul!(w, L4, v, α, β)
+```
+
+The calling signature `L(v, u, p, t)`, for out-of-place evaluations, is
+equivalent to `L * v`, and the in-place evaluation `L(w, v, u, p, t, args...)`
+is equivalent to `LinearAlgebra.mul!(w, L, v, args...)`, where the arguments
+`u, p, t` are passed to `L` to update its state. More details are provided
+in the [`AbstractSciMLOperator` interface documentation](@ref operator_interface).
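+
+For instance, using the operators and vectors defined above, the out-of-place call can be
+spelled out as an explicit update followed by a multiplication. This is only a sketch to
+illustrate the equivalence stated above:
+
+```julia
+# one-shot update-and-apply
+w1 = L1(v, u, p, t)
+
+# the same computation, written out: refresh the operator state from (u, p, t),
+# then apply the updated operator to the action vector v
+L1_updated = update_coefficients(L1, u, p, t)
+w2 = L1_updated * v   # w2 holds the same result as w1
+```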
+ +The `(v, u, p, t)` calling signature is standardized over the `SciML` +ecosystem and is flexible enough to support use cases such as time-evolution +in ODEs, as well as sensitivity computation with respect to the parameter +object `p`. \ No newline at end of file diff --git a/src/basic.jl b/src/basic.jl index 1fee5bcd..8236ed41 100644 --- a/src/basic.jl +++ b/src/basic.jl @@ -980,8 +980,9 @@ function Base.resize!(L::InvertedOperator, n::Integer) end function update_coefficients(L::InvertedOperator, u, p, t) - @reset L.L = update_coefficients(L.L, u, p, t) - + if !isconstant(L.L) + @reset L.L = update_coefficients(L.L, u, p, t) + end L end diff --git a/src/batch.jl b/src/batch.jl index b3cc6ca8..bb2efbe5 100644 --- a/src/batch.jl +++ b/src/batch.jl @@ -26,8 +26,8 @@ struct BatchedDiagonalOperator{T, D, F, F!} <: AbstractSciMLOperator{T} end function DiagonalOperator(u::AbstractArray; - update_func = DEFAULT_UPDATE_FUNC, - update_func! = DEFAULT_UPDATE_FUNC, + update_func = nothing, + update_func! = nothing, accepted_kwargs = nothing) update_func = preprocess_update_func(update_func, accepted_kwargs) update_func! = preprocess_update_func(update_func!, accepted_kwargs) @@ -48,13 +48,13 @@ function Base.conj(L::BatchedDiagonalOperator) # TODO - test this thoroughly update_func, update_func! = if isreal(L) L.update_func, L.update_func! else - uf = (L, u, p, t; kwargs...) -> conj(L.update_func(conj(L.diag), + uf = L.update_func === nothing ? nothing : (L, u, p, t; kwargs...) -> conj(L.update_func(conj(L.diag), u, p, t; kwargs...)) - uf! = (L, u, p, t; kwargs...) -> begin - L.update_func(conj!(L.diag), u, p, t; kwargs...) + uf! = L.update_func! === nothing ? nothing : (L, u, p, t; kwargs...) -> begin + L.update_func!(conj!(L.diag), u, p, t; kwargs...) conj!(L.diag) end uf, uf! @@ -83,12 +83,20 @@ end LinearAlgebra.isposdef(L::BatchedDiagonalOperator) = isposdef(Diagonal(vec(L.diag))) function update_coefficients(L::BatchedDiagonalOperator, u, p, t; kwargs...) - @reset L.diag = L.update_func(L.diag, u, p, t; kwargs...) + if !isnothingfunc(L.update_func) + return @reset L.diag = L.update_func(L.diag, u, p, t; kwargs...) + elseif !isnothingfunc(L.update_func!) + L.update_func!(L.diag, u, p, t; kwargs...) + return L + end end function update_coefficients!(L::BatchedDiagonalOperator, u, p, t; kwargs...) - L.update_func!(L.diag, u, p, t; kwargs...) - + if !isnothingfunc(L.update_func!) + L.update_func!(L.diag, u, p, t; kwargs...) + elseif !isnothingfunc(L.update_func) + L.diag = L.update_func(L.diag, u, p, t; kwargs...) + end nothing end @@ -150,12 +158,13 @@ function LinearAlgebra.ldiv!(L::BatchedDiagonalOperator, u::AbstractVecOrMat) d = vec(L.diag) D = Diagonal(d) ldiv!(D, U) - u end function (L::BatchedDiagonalOperator)(v::AbstractVecOrMat, u, p, t; kwargs...) - L = update_coefficients(L, u, p, t; kwargs...) + if !isconstant(L) + L = update_coefficients(L, u, p, t; kwargs...) + end L.diag .* v end diff --git a/src/func.jl b/src/func.jl index e852adc9..f2243912 100644 --- a/src/func.jl +++ b/src/func.jl @@ -198,7 +198,7 @@ uniform across `op`, `op_adjoint`, `op_inverse`, `op_adjoint_inverse`. * `isinplace` - `true` if the operator can be used is a mutating way with in-place allocations. This trait is inferred if no value is provided. * `outofplace` - `true` if the operator can be used is a non-mutating way with in-place allocations. This trait is inferred if no value is provided. * `has_mul5` - `true` if the operator provides a five-argument `mul!` via the signature `op(v, u, p, t, α, β; )`. 
This trait is inferred if no value is provided. -* `isconstant` - `true` if the operator is constant, and doesn't need to be updated via `update_coefficients[!]` during operator evaluation. +* `isconstant` - `true` if the operator is constant, and doesn't need to be updated via `update_coefficients[!]` during operator evaluation. Defaults to false. * `islinear` - `true` if the operator is linear. Defaults to `false`. * `isconvertible` - `true` a cheap `convert(AbstractMatrix, L.op)` method is available. Defaults to `false`. * `batch` - Boolean indicating if the input/output arrays comprise of batched column-vectors stacked in a matrix. If `true`, the input/output arrays must be `AbstractVecOrMat`s, and the length of the second dimension (the batch dimension) must be the same. The batch dimension is not involved in size computation. For example, with `batch = true`, and `size(output), size(input) = (M, K), (N, K)`, the `FunctionOperator` size is set to `(M, N)`. If `batch = false`, which is the default, the `input`/`output` arrays may of any size so long as `ndims(input) == ndims(output)`, and the `size` of `FunctionOperator` is set to `(length(input), length(output))`. diff --git a/src/matrix.jl b/src/matrix.jl index a04c65b5..4772e2c9 100644 --- a/src/matrix.jl +++ b/src/matrix.jl @@ -92,8 +92,8 @@ struct MatrixOperator{T, AT <: AbstractMatrix{T}, F, F!} <: AbstractSciMLOperato end function MatrixOperator(A; - update_func = DEFAULT_UPDATE_FUNC, - update_func! = DEFAULT_UPDATE_FUNC, + update_func = nothing, + update_func! = nothing, accepted_kwargs = nothing) update_func = preprocess_update_func(update_func, accepted_kwargs) update_func! = preprocess_update_func(update_func!, accepted_kwargs) @@ -127,12 +127,12 @@ for op in (:adjoint, @eval function Base.$op(L::MatrixOperator) isconstant(L) && return MatrixOperator($op(L.A)) - update_func = (A, u, p, t; kwargs...) -> $op(L.update_func($op(L.A), + update_func = L.update_func === nothing ? nothing : (A, u, p, t; kwargs...) -> $op(L.update_func($op(L.A), u, p, t; kwargs...)) - update_func! = (A, u, p, t; kwargs...) -> $op(L.update_func!($op(L.A), + update_func! = L.update_func! === nothing ? nothing : (A, u, p, t; kwargs...) -> $op(L.update_func!($op(L.A), u, p, t; @@ -148,12 +148,12 @@ end function Base.conj(L::MatrixOperator) isconstant(L) && return MatrixOperator(conj(L.A)) - update_func = (A, u, p, t; kwargs...) -> conj(L.update_func(conj(L.A), + update_func = L.update_func === nothing ? nothing : (A, u, p, t; kwargs...) -> conj(L.update_func(conj(L.A), u, p, t; kwargs...)) - update_func! = (A, u, p, t; kwargs...) -> begin + update_func! = L.update_func! === nothing ? nothing : (A, u, p, t; kwargs...) -> begin L.update_func!(conj!(L.A), u, p, t; kwargs...) conj!(L.A) end @@ -171,12 +171,20 @@ function isconstant(L::MatrixOperator) end function update_coefficients(L::MatrixOperator, u, p, t; kwargs...) - @reset L.A = L.update_func(L.A, u, p, t; kwargs...) + if !isnothingfunc(L.update_func) + @reset L.A = L.update_func(L.A, u, p, t; kwargs...) + elseif !isnothingfunc(L.update_func!) + L.update_func!(L.A, u, p, t; kwargs...) + end + L end function update_coefficients!(L::MatrixOperator, u, p, t; kwargs...) - L.update_func!(L.A, u, p, t; kwargs...) - + if !isnothingfunc(L.update_func!) + L.update_func!(L.A, u, p, t; kwargs...) + elseif !isnothingfunc(L.update_func) + L.A = L.update_func(L.A, u, p, t; kwargs...) 
+ end nothing end @@ -281,8 +289,8 @@ $(UPDATE_COEFFS_WARNING) """ function DiagonalOperator(diag::AbstractVector; - update_func = DEFAULT_UPDATE_FUNC, - update_func! = DEFAULT_UPDATE_FUNC, + update_func = nothing, + update_func! = nothing, accepted_kwargs = nothing) diag_update_func = update_func_isconstant(update_func) ? update_func : (A, u, p, t; kwargs...) -> update_func(A.diag, u, p, t; kwargs...) |> @@ -429,7 +437,9 @@ LinearAlgebra.ldiv!(L::InvertibleOperator, u::AbstractVecOrMat) = ldiv!(L.F, u) # Out-of-place: v is action vector, u is update vector function (L::InvertibleOperator)(v::AbstractVecOrMat, u, p, t; kwargs...) - L = update_coefficients(L, u, p, t; kwargs...) + if !isconstant(L) + L = update_coefficients(L, u, p, t; kwargs...) + end L.L * v end @@ -517,8 +527,8 @@ end function AffineOperator(A::Union{AbstractMatrix, AbstractSciMLOperator}, B::Union{AbstractMatrix, AbstractSciMLOperator}, b::AbstractArray; - update_func = DEFAULT_UPDATE_FUNC, - update_func! = DEFAULT_UPDATE_FUNC, + update_func = nothing, + update_func! = nothing, accepted_kwargs = nothing) @assert size(A, 1)==size(B, 1) "Dimension mismatch: A, B don't output vectors of same size" @@ -540,8 +550,8 @@ Represents the affine operation `w = I * v + I * b`. The update functions, documentation of `AffineOperator` for more details. """ function AddVector(b::AbstractVecOrMat; - update_func = DEFAULT_UPDATE_FUNC, - update_func! = DEFAULT_UPDATE_FUNC, + update_func = nothing, + update_func! = nothing, accepted_kwargs = nothing) N = size(b, 1) Id = IdentityOperator(N) @@ -560,8 +570,8 @@ Represents the affine operation `w = I * v + B * b`. The update functions, documentation of `AffineOperator` for more details. """ function AddVector(B, b::AbstractVecOrMat; - update_func = DEFAULT_UPDATE_FUNC, - update_func! = DEFAULT_UPDATE_FUNC, + update_func = nothing, + update_func! = nothing, accepted_kwargs = nothing) N = size(B, 1) Id = IdentityOperator(N) @@ -575,11 +585,16 @@ end function update_coefficients(L::AffineOperator, u, p, t; kwargs...) @reset L.A = update_coefficients(L.A, u, p, t; kwargs...) @reset L.B = update_coefficients(L.B, u, p, t; kwargs...) - @reset L.b = L.update_func(L.b, u, p, t; kwargs...) + if !isnothingfunc(L.update_func) + @reset L.b = L.update_func(L.b, u, p, t; kwargs...) + end + L end function update_coefficients!(L::AffineOperator, u, p, t; kwargs...) - L.update_func!(L.b, u, p, t; kwargs...) + if !isnothingfunc(L.update_func) + L.update_func!(L.b, u, p, t; kwargs...) + end for op in getops(L) update_coefficients!(op, u, p, t; kwargs...) end diff --git a/src/utils.jl b/src/utils.jl index 67b1440d..bcc5f9ec 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -14,16 +14,17 @@ arguments. Required in implementation of lazy `Base.adjoint`, struct NoKwargFilter end function preprocess_update_func(update_func, accepted_kwargs) - _update_func = (update_func === nothing) ? DEFAULT_UPDATE_FUNC : update_func _accepted_kwargs = (accepted_kwargs === nothing) ? () : accepted_kwargs # accepted_kwargs can be passed as nothing to indicate that we should not filter # (e.g. if the function already accepts all kwargs...). - return (_accepted_kwargs isa NoKwargFilter) ? _update_func : - FilterKwargs(_update_func, _accepted_kwargs) + return (_accepted_kwargs isa NoKwargFilter) ? 
update_func : + FilterKwargs(update_func, _accepted_kwargs) end + +update_func_isconstant(::Nothing) = true function update_func_isconstant(update_func) if update_func isa FilterKwargs - return update_func.f === DEFAULT_UPDATE_FUNC + return update_func.f === DEFAULT_UPDATE_FUNC || update_func.f === nothing else return update_func === DEFAULT_UPDATE_FUNC end @@ -52,6 +53,10 @@ function (f::FilterKwargs)(args...; kwargs...) filtered_kwargs = get_filtered_kwargs(kwargs, f.accepted_kwargs) f.f(args...; filtered_kwargs...) end + +isnothingfunc(f::FilterKwargs) = isnothingfunc(f.f) +isnothingfunc(f::Nothing) = true +isnothingfunc(f) = false # _unwrap_val(x) = x diff --git a/test/matrix.jl b/test/matrix.jl index 78f058c4..64170eb4 100644 --- a/test/matrix.jl +++ b/test/matrix.jl @@ -160,6 +160,42 @@ end orig_w = copy(w) L(w, v, u, p, t, α, β) @test w ≈ α * (A * v) + β * orig_w + + A = [-2.0 1 0 0 0 + 1 -2 1 0 0 + 0 1 -2 1 0 + 0 0 1 -2 1 + 0 0 0 1 -2] + v = [3.0,2.0,1.0,2.0,3.0] + opA = MatrixOperator(A) + + function update_function!(B, u, p, t) + dt = p + B .= A .* u + dt*I + end + + u = Array(1:1.0:5); p = 0.1; t = 0.0 + opB = MatrixOperator(copy(A); update_func! = update_function!) + + function Bfunc!(w,v,u,p,t) + dt = p + w[1] = -(2*u[1]-dt)*v[1] + v[2]*u[1] + for i in 2:4 + w[i] = v[i-1]*u[i] - (2*u[i]-dt)*v[i] + v[i+1]*u[i] + end + w[5] = v[4]*u[5] - (2*u[5]-dt)*v[5] + nothing + end + + function Bfunc!(v,u,p,t) + w = zeros(5) + Bfunc!(w,v,u,p,t) + w + end + + mfopB = FunctionOperator(Bfunc!, zeros(5), zeros(5); u, p, t, isconstant=false) + + @test iszero(opB(v, Array(2:1.0:6), 0.5, nothing) - mfopB(v, Array(2:1.0:6), 0.5, nothing)) end @testset "DiagonalOperator update test" begin