diff --git a/.github/workflows/Tests.yml b/.github/workflows/Tests.yml index 8cde65ff..9040aa4d 100644 --- a/.github/workflows/Tests.yml +++ b/.github/workflows/Tests.yml @@ -4,6 +4,7 @@ on: pull_request: branches: - master + - 'release-*' # glob: also run for PRs targeting any release-<x> branch paths-ignore: - 'docs/**' push: @@ -13,19 +14,26 @@ on: - 'docs/**' concurrency: + # Skip intermediate builds: always, except for the master branch. + # Cancel intermediate builds: always, except for the master branch. group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref_name != github.event.repository.default_branch || github.ref != 'refs/tags/v*' }} + cancel-in-progress: ${{ github.ref != 'refs/heads/master' }} jobs: tests: name: "Tests" strategy: + fail-fast: false matrix: version: - "1" - "lts" - "pre" + group: + - Core + - Downstream uses: "SciML/.github/.github/workflows/tests.yml@v1" with: julia-version: "${{ matrix.version }}" + group: "${{ matrix.group }}" secrets: "inherit" diff --git a/Project.toml b/Project.toml index 4e35c8cb..51ca40f9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SciMLOperators" uuid = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" authors = ["Vedant Puri "] -version = "0.3.13" +version = "0.4.0" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/docs/Project.toml b/docs/Project.toml index 65284347..b59c1b70 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -6,4 +6,4 @@ SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" [compat] Documenter = "0.27, 1" FFTW = "1.7" -SciMLOperators = "0.2, 0.3" +SciMLOperators = "0.4" diff --git a/docs/src/index.md b/docs/src/index.md index 613f5ce0..22990460 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -32,7 +32,7 @@ Let `M`, `D`, `F` be matrix-based, diagonal-matrix-based, and function-based ```julia N = 4 -f = (u, p, t) -> u .* u +f = (v, u, p, t) -> u .* v M = MatrixOperator(rand(N, N)) D = DiagonalOperator(rand(N)) @@ -56,10 +56,11 @@ p = nothing # parameter struct t = 0.0 
# time u = rand(N) -v = L1(u, p, t) # == L1 * u +v = rand(N) +w = L1(v, u, p, t) # == L1 * v -u_kron = rand(N^3) -v_kron = L3(u_kron, p, t) # == L3 * u_kron +v_kron = rand(N^3) +w_kron = L3(v_kron, u, p, t) # == L3 * v_kron ``` For mutating operator evaluations, call `cache_operator` to generate an @@ -73,21 +74,17 @@ L2 = cache_operator(L2, u) L4 = cache_operator(L4, u) # allocation-free evaluation -L2(v, u, p, t) # == mul!(v, L2, u) -L4(v, u, p, t, α, β) # == mul!(v, L4, u, α, β) +L2(w, v, u, p, t) # == mul!(w, L2, v) +L4(w, v, u, p, t, α, β) # == mul!(w, L4, v, α, β) ``` -The calling signature `L(u, p, t)`, for out-of-place evaluations, is -equivalent to `L * u`, and the in-place evaluation `L(v, u, p, t, args...)` -is equivalent to `LinearAlgebra.mul!(v, L, u, args...)`, where the arguments -`p, t` are passed to `L` to update its state. More details are provided -in the operator update section below. While overloads to `Base.*` -and `LinearAlgebra.mul!` are available, where a `SciMLOperator` behaves -like an `AbstractMatrix`, we recommend sticking with the -`L(u, p, t)`, `L(v, u, p, t)`, `L(v, u, p, t, α, β)` calling signatures -as the latter internally update the operator state. - -The `(u, p, t)` calling signature is standardized over the `SciML` +The calling signature `L(v, u, p, t)`, for out-of-place evaluations, is +equivalent to `L * v`, and the in-place evaluation `L(w, v, u, p, t, args...)` +is equivalent to `LinearAlgebra.mul!(w, L, v, args...)`, where the arguments +`u, p, t` are passed to `L` to update its state. More details are provided +in the operator update section below. + +The `(v, u, p, t)` calling signature is standardized over the `SciML` ecosystem and is flexible enough to support use cases such as time-evolution in ODEs, as well as sensitivity computation with respect to the parameter object `p`. 
diff --git a/docs/src/interface.md b/docs/src/interface.md index f1f524c5..d147fb60 100644 --- a/docs/src/interface.md +++ b/docs/src/interface.md @@ -1,152 +1,7 @@ # The `AbstractSciMLOperator` Interface -## Formal Properties of SciMLOperators - -These are the formal properties that an `AbstractSciMLOperator` should obey -for it to work in the solvers. - - 1. An `AbstractSciMLOperator` represents a linear or nonlinear operator, with input/output - being `AbstractArray`s. Specifically, a SciMLOperator, `L`, of size `(M, N)` accepts an - input argument `u` with leading length `N`, i.e. `size(u, 1) == N`, and returns an - `AbstractArray` of the same dimension with leading length `M`, i.e. `size(L * u, 1) == M`. - 2. SciMLOperators can be applied to an `AbstractArray` via overloaded `Base.*`, or - the in-place `LinearAlgebra.mul!`. Additionally, operators are allowed to be time, - or parameter dependent. The state of a SciMLOperator can be updated by calling - the mutating function `update_coefficients!(L, u, p, t)` where `p` represents - parameters, and `t`, time. Calling a SciMLOperator as `L(du, u, p, t)` or out-of-place - `L(u, p, t)` will automatically update the state of `L` before applying it to `u`. - `L(u, p, t)` is the same operation as `L(u, p, t) * u`. - 3. To support the update functionality, we have lazily implemented a comprehensive operator - algebra. That means a user can add, subtract, scale, compose and invert SciMLOperators, - and the state of the resultant operator would be updated as expected upon calling - `L(du, u, p, t)` or `L(u, p, t)` so long as an update function is provided for the - component operators. - -## Overloaded Traits - -Thanks to overloads defined for evaluation methods and traits in -`Base`, `LinearAlgebra`, the behavior of a `SciMLOperator` is -indistinguishable from an `AbstractMatrix`. These operators can be -passed to linear solver packages, and even to ordinary differential -equation solvers. 
The list of overloads to the `AbstractMatrix` -interface includes, but is not limited to, the following: - - - `Base: size, zero, one, +, -, *, /, \, ∘, inv, adjoint, transpose, convert` - - `LinearAlgebra: mul!, ldiv!, lmul!, rmul!, factorize, issymmetric, ishermitian, isposdef` - - `SparseArrays: sparse, issparse` - -## Multidimensional arrays and batching - -SciMLOperator can also be applied to `AbstractMatrix` subtypes where -operator-evaluation is done column-wise. - -```julia -K = 10 -u_mat = rand(N, K) - -v_mat = F(u_mat, p, t) # == mul!(v_mat, F, u_mat) -size(v_mat) == (N, K) # true -``` - -`L#` can also be applied to `AbstractArray`s that are not -`AbstractVecOrMat`s so long as their size in the first dimension is appropriate -for matrix-multiplication. Internally, `SciMLOperator`s reshapes an -`N`-dimensional array to an `AbstractMatrix`, and applies the operator via -matrix-multiplication. - -## Operator update - -This package can also be used to write time-dependent, and -parameter-dependent operators, whose state can be updated per -a user-defined function. -The updates can be done in-place, i.e. by mutating the object, -or out-of-place, i.e. in a non-mutating, `Zygote`-compatible way. 
- -For example, - -```julia -u = rand(N) -p = rand(N) -t = rand() - -# out-of-place update -mat_update_func = (A, u, p, t) -> t * (p * p') -sca_update_func = (a, u, p, t) -> t * sum(p) - -M = MatrixOperator(zero(N, N); update_func = mat_update_func) -α = ScalarOperator(zero(Float64); update_func = sca_update_func) - -L = α * M -L = cache_operator(L, u) - -# L is initialized with zero state -L * u == zeros(N) # true - -# update operator state with `(u, p, t)` -L = update_coefficients(L, u, p, t) -# and multiply -L * u != zeros(N) # true - -# updates state and evaluates L at (u, p, t) -L(u, p, t) != zeros(N) # true -``` - -The out-of-place evaluation function `L(u, p, t)` calls -`update_coefficients` under the hood, which recursively calls -the `update_func` for each component `SciMLOperator`. -Therefore, the out-of-place evaluation function is equivalent to -calling `update_coefficients` followed by `Base.*`. Notice that -the out-of-place evaluation does not return the updated operator. - -On the other hand, the in-place evaluation function, `L(v, u, p, t)`, -mutates `L`, and is equivalent to calling `update_coefficients!` -followed by `mul!`. The in-place update behavior works the same way, -with a few ``s appended here and there. For example, - -```julia -v = rand(N) -u = rand(N) -p = rand(N) -t = rand() - -# in-place update -_A = rand(N, N) -_d = rand(N) -mat_update_func! = (A, u, p, t) -> (copy!(A, _A); lmul!(t, A); nothing) -diag_update_func! = (diag, u, p, t) -> copy!(diag, N) - -M = MatrixOperator(zero(N, N); update_func! = mat_update_func!) -D = DiagonalOperator(zero(N); update_func! = diag_update_func!) 
- -L = D * M -L = cache_operator(L, u) - -# L is initialized with zero state -L * u == zeros(N) # true - -# update L in-place -update_coefficients!(L, u, p, t) -# and multiply -mul!(v, u, p, t) != zero(N) # true - -# updates L in-place, and evaluates at (u, p, t) -L(v, u, p, t) != zero(N) # true -``` - -The update behavior makes this package flexible enough to be used -in `OrdinaryDiffEq`. As the parameter object `p` is often reserved -for sensitivity computation via automatic-differentiation, a user may -prefer to pass in state information via other arguments. For that -reason, we allow update functions with arbitrary keyword arguments. - -```julia -mat_update_func = (A, u, p, t; scale = 0.0) -> scale * (p * p') - -M = MatrixOperator(zero(N, N); update_func = mat_update_func, - accepted_kwargs = (:state,)) - -M(u, p, t) == zeros(N) # true -M(u, p, t; scale = 1.0) != zero(N) +```@docs +SciMLOperators.AbstractSciMLOperator ``` ## Interface API Reference @@ -219,6 +74,6 @@ update_coefficients!(γ, nothing, nothing, nothing; my_special_scaling = 7.0) @show γ * [2.0] # Use operator application form -@show γ([2.0], nothing, nothing; my_special_scaling = 5.0) +@show γ([2.0], nothing, nothing, nothing; my_special_scaling = 5.0) nothing # hide ``` diff --git a/docs/src/premade_operators.md b/docs/src/premade_operators.md index 4b2664f5..088b794e 100644 --- a/docs/src/premade_operators.md +++ b/docs/src/premade_operators.md @@ -3,7 +3,7 @@ ## Direct Operator Definitions ```@docs -ScalarOperator.IdentityOperator +SciMLOperators.IdentityOperator SciMLOperators.NullOperator ScalarOperator MatrixOperator diff --git a/docs/src/sciml.md b/docs/src/sciml.md index 23f61027..261e9a84 100644 --- a/docs/src/sciml.md +++ b/docs/src/sciml.md @@ -19,11 +19,20 @@ Munthe-Kaas methods require defining operators of the form ``u' = A(u) u``. 
Thus, the operators need some form of time and state dependence, which the solvers can update and query when they are non-constant (`update_coefficients!`). Additionally, the operators need the ability to -act like “normal” functions for equation solvers. For example, if `A(u,p,t)` -has the same operation as `update_coefficients(A, u, p, t); A * u`, then `A` +act like “normal” functions for equation solvers. For example, if `A(v,u,p,t)` +has the same operation as `update_coefficients(A, u, p, t); A * v`, then `A` can be used in any place where a differential equation definition -`f(u, p, t)` is used without requiring the user or solver to do any extra -work. Thus while previous good efforts for matrix-free operators have existed +`(u,p,t) -> A(u, u, p, t)` is used without requiring the user or solver to do any extra +work. + +Another example is state-dependent mass matrices. `M(u,p,t)*u' = f(u,p,t)`. +When solving such an equation, the solver must understand how to "update M" +during operations, and thus the ability to update the state of `M` is a required +function in the interface. This is also required for the definition of Jacobians +`J(u,p,t)` in order to be properly used with Krylov methods inside of ODE solves +without reconstructing the matrix-free operator at each step. + +Thus while previous good efforts for matrix-free operators have existed in the Julia ecosystem, such as [LinearMaps.jl](https://github.com/JuliaLinearAlgebra/LinearMaps.jl), those operator interfaces lack these aspects to actually be fully seamless @@ -31,10 +40,11 @@ with downstream equation solvers. This necessitates the definition and use of an extended operator interface with all of these properties, hence the `AbstractSciMLOperator` interface. -Some packages providing similar functionality are +!!! 
warn - - [LinearMaps.jl](https://github.com/JuliaLinearAlgebra/LinearMaps.jl) - - [`DiffEqOperators.jl`](https://github.com/SciML/DiffEqOperators.jl/tree/master) (deprecated) + This means that LinearMaps.jl is fundamentally lacking and is incompatible + with many of the tools in the SciML ecosystem, except for the specific cases + where the matrix-free operator is a constant! ## Interoperability and extended Julia ecosystem diff --git a/docs/src/tutorials/fftw.md b/docs/src/tutorials/fftw.md index 13af0160..78c03258 100644 --- a/docs/src/tutorials/fftw.md +++ b/docs/src/tutorials/fftw.md @@ -16,18 +16,18 @@ L = 2π dx = L / n x = range(start = -L / 2, stop = L / 2 - dx, length = n) |> Array -u = @. sin(5x)cos(7x); -du = @. 5cos(5x)cos(7x) - 7sin(5x)sin(7x); +v = @. sin(5x)cos(7x); +w = @. 5cos(5x)cos(7x) - 7sin(5x)sin(7x); k = rfftfreq(n, 2π * n / L) |> Array m = length(k) P = plan_rfft(x) -fwd(u, p, t) = P * u -bwd(u, p, t) = P \ u +fwd(v, u, p, t) = P * v +bwd(v, u, p, t) = P \ v -fwd(du, u, p, t) = mul!(du, P, u) -bwd(du, u, p, t) = ldiv!(du, P, u) +fwd(w, v, u, p, t) = mul!(w, P, v) +bwd(w, v, u, p, t) = ldiv!(w, P, v) F = FunctionOperator(fwd, x, im * k; T = ComplexF64, op_adjoint = bwd, @@ -38,10 +38,10 @@ F = FunctionOperator(fwd, x, im * k; ik = im * DiagonalOperator(k) Dx = F \ ik * F -Dx = cache_operator(Dx, x) +Dx = cache_operator(Dx, v) -@show ≈(Dx * u, du; atol = 1e-8) -@show ≈(mul!(copy(u), Dx, u), du; atol = 1e-8) +@show ≈(Dx * v, w; atol = 1e-8) +@show ≈(mul!(copy(w), Dx, v), w; atol = 1e-8) ``` ## Explanation @@ -61,8 +61,8 @@ n = 256 dx = L / n x = range(start = -L / 2, stop = L / 2 - dx, length = n) |> Array -u = @. sin(5x)cos(7x); -du = @. 5cos(5x)cos(7x) - 7sin(5x)sin(7x); +v = @. sin(5x)cos(7x); +w = @. 5cos(5x)cos(7x) - 7sin(5x)sin(7x); ``` Now, we define the Fourier transform. Since our input is purely Real, we use the real @@ -79,15 +79,15 @@ P = plan_rfft(x) Now we are ready to define our wrapper for the FFT object. 
To `FunctionOperator`, we pass the in-place forward application of the transform, -`(du,u,p,t) -> mul!(du, transform, u)`, its inverse application, -`(du,u,p,t) -> ldiv!(du, transform, u)`, as well as input and output prototype vectors. +`(w,v,u,p,t) -> mul!(w, transform, v)`, its inverse application, +`(w,v,u,p,t) -> ldiv!(w, transform, v)`, as well as input and output prototype vectors. ```@example fft_explanation -fwd(u, p, t) = P * u -bwd(u, p, t) = P \ u +fwd(v, u, p, t) = P * v +bwd(v, u, p, t) = P \ v -fwd(du, u, p, t) = mul!(du, P, u) -bwd(du, u, p, t) = ldiv!(du, P, u) +fwd(w, v, u, p, t) = mul!(w, P, v) +bwd(w, v, u, p, t) = ldiv!(w, P, v) F = FunctionOperator(fwd, x, im * k; T = ComplexF64, op_adjoint = bwd, op_inverse = bwd, @@ -106,6 +106,6 @@ Dx = F \ ik * F Dx = cache_operator(Dx, x) -@show ≈(Dx * u, du; atol = 1e-8) -@show ≈(mul!(copy(u), Dx, u), du; atol = 1e-8) +@show ≈(Dx * v, w; atol = 1e-8) +@show ≈(mul!(copy(w), Dx, v), w; atol = 1e-8) ``` diff --git a/src/SciMLOperators.jl b/src/SciMLOperators.jl index 83dd366a..362d9f00 100644 --- a/src/SciMLOperators.jl +++ b/src/SciMLOperators.jl @@ -26,21 +26,29 @@ Subtypes of `AbstractSciMLOperator` represent linear, nonlinear, time-dependent operators acting on vectors, or matrix column-vectors. A lazy operator algebra is also defined for `AbstractSciMLOperator`s. -# Interface +## Mathematical Notation -An `AbstractSciMLOperator` can be called like a function. This behaves -like multiplication by the linear operator represented by the -`AbstractSciMLOperator`. Possible signatures are +An `AbstractSciMLOperator` ``L`` is an operator which is used to represent +the following type of equation: -- `L(v, u, p, t)` for in-place operator evaluation -- `v = L(u, p, t)` for out-of-place operator evaluation +```math +w = L(u,p,t)[v] +``` -Operator evaluation methods update its coefficients with `(u, p, t)` -information using the `update_coefficients(!)` method. 
The methods -are exported and can be called as follows: +where `L[v]` is the operator application of ``L`` on the vector ``v``. -- `update_coefficients!(L, u, p, t)` for out-of-place operator update -- `L = update_coefficients(L, u, p, t)` for in-place operator update +## Interface + +An `AbstractSciMLOperator` can be called like a function in the following ways: + +- `L(v, u, p, t)` - Out-of-place application where `v` is the action vector and `u` is the update vector +- `L(w, v, u, p, t)` - In-place application where `w` is the destination, `v` is the action vector, and `u` is the update vector +- `L(w, v, u, p, t, α, β)` - In-place application with scaling: `w = α*(L*v) + β*w` + +Operator state can be updated separately from application: + +- `update_coefficients!(L, u, p, t)` for in-place operator update +- `L = update_coefficients(L, u, p, t)` for out-of-place operator update SciMLOperators also overloads `Base.*`, `LinearAlgebra.mul!`, `LinearAlgebra.ldiv!` for operator evaluation without updating operator state. @@ -49,7 +57,133 @@ Allocation-free methods, suffixed with a `!` often need cache arrays. To precache an `AbstractSciMLOperator`, call the function `L = cache_operator(L, input_vector)`. -# Methods +## Overloaded Actions + +The behavior of a `SciMLOperator` is +indistinguishable from an `AbstractMatrix`. These operators can be +passed to linear solver packages, and even to ordinary differential +equation solvers. The list of overloads to the `AbstractMatrix` +interface includes, but is not limited to, the following: + + - `Base: size, zero, one, +, -, *, /, \\, ∘, inv, adjoint, transpose, convert` + - `LinearAlgebra: mul!, ldiv!, lmul!, rmul!, factorize, issymmetric, ishermitian, isposdef` + - `SparseArrays: sparse, issparse` + +## Multidimensional arrays and batching + +SciMLOperator can also be applied to `AbstractMatrix` subtypes where +operator-evaluation is done column-wise. 
+ +```julia +K = 10 +v_mat = rand(N, K) + +w_mat = F(v_mat, u, p, t) # == mul!(w_mat, F, v_mat) +size(w_mat) == (N, K) # true +``` + +`L` can also be applied to `AbstractArray`s that are not +`AbstractVecOrMat`s so long as their size in the first dimension is appropriate +for matrix-multiplication. Internally, a `SciMLOperator` reshapes an +`N`-dimensional array to an `AbstractMatrix`, and applies the operator via +matrix-multiplication. + +## Operator update + +This package can also be used to write state-dependent, time-dependent, and +parameter-dependent operators, whose state can be updated per +a user-defined function. +The updates can be done in-place, i.e. by mutating the object, +or out-of-place, i.e. in a non-mutating, `Zygote`-compatible way. + +For example, + +```julia +v, u = rand(N), rand(N) # action vector, update vector +p = rand(N) +t = rand() + +# out-of-place update +mat_update_func = (A, u, p, t) -> t * (p * u') +sca_update_func = (a, u, p, t) -> t * sum(p) + +M = MatrixOperator(zeros(N, N); update_func = mat_update_func) +α = ScalarOperator(zero(Float64); update_func = sca_update_func) + +L = α * M +L = cache_operator(L, v) + +# L is initialized with zero state +L * v == zeros(N) # true + +# update operator state with `(u, p, t)` +L = update_coefficients(L, u, p, t) +# and multiply +L * v != zeros(N) # true + +# updates state and evaluates L*v at (u, p, t) +L(v, u, p, t) != zeros(N) # true +``` + +The out-of-place evaluation function `L(v, u, p, t)` calls +`update_coefficients` under the hood, which recursively calls +the `update_func` for each component `SciMLOperator`. +Therefore, the out-of-place evaluation function is equivalent to +calling `update_coefficients` followed by `Base.*`. Notice that +the out-of-place evaluation does not return the updated operator. + +On the other hand, the in-place evaluation function, `L(w, v, u, p, t)`, +mutates `L`, and is equivalent to calling `update_coefficients!` +followed by `mul!`. 
The in-place update behavior works the same way, +with a few `!`s appended here and there. For example, + +```julia +w = rand(N) +v = rand(N) +u = rand(N) +p = rand(N) +t = rand() + +# in-place update +_A = rand(N, N) +_d = rand(N) +mat_update_func! = (A, u, p, t) -> (copy!(A, _A); lmul!(t, A); nothing) +diag_update_func! = (diag, u, p, t) -> copy!(diag, _d) + +M = MatrixOperator(zeros(N, N); update_func! = mat_update_func!) +D = DiagonalOperator(zeros(N); update_func! = diag_update_func!) + +L = D * M +L = cache_operator(L, v) + +# L is initialized with zero state +L * v == zeros(N) # true + +# update L in-place +update_coefficients!(L, u, p, t) +# and multiply +mul!(w, L, v) != zeros(N) # true + +# updates L in-place, and evaluates w=L*v at (u, p, t) +L(w, v, u, p, t) != zeros(N) # true +``` + +The update behavior makes this package flexible enough to be used +in `OrdinaryDiffEq`. As the parameter object `p` is often reserved +for sensitivity computation via automatic-differentiation, a user may +prefer to pass in state information via other arguments. For that +reason, we allow update functions with arbitrary keyword arguments. 
+ +```julia +mat_update_func = (A, u, p, t; scale = 0.0) -> scale * (p * u') + +M = MatrixOperator(zeros(N, N); update_func = mat_update_func, + accepted_kwargs = (:scale,)) + +M(v, u, p, t) == zeros(N) # true +M(v, u, p, t; scale = 1.0) != zeros(N) +``` + """ abstract type AbstractSciMLOperator{T} end diff --git a/src/basic.jl b/src/basic.jl index 666cbe94..1fee5bcd 100644 --- a/src/basic.jl +++ b/src/basic.jl @@ -1,7 +1,7 @@ """ $(TYPEDEF) -Operator representing the identity function `id(u) = u` +Operator representing the identity function `id(v) = v` """ struct IdentityOperator <: AbstractSciMLOperator{Bool} len::Int @@ -38,35 +38,56 @@ has_ldiv!(::IdentityOperator) = true # operator application for op in (:*, :\) - @eval function Base.$op(ii::IdentityOperator, u::AbstractVecOrMat) - @assert size(u, 1) == ii.len - copy(u) + @eval function Base.$op(ii::IdentityOperator, v::AbstractVecOrMat) + @assert size(v, 1) == ii.len + copy(v) end end @inline function LinearAlgebra.mul!( - v::AbstractVecOrMat, ii::IdentityOperator, u::AbstractVecOrMat) - @assert size(u, 1) == ii.len - copy!(v, u) + w::AbstractVecOrMat, ii::IdentityOperator, v::AbstractVecOrMat) + @assert size(v, 1) == ii.len + copy!(w, v) end -@inline function LinearAlgebra.mul!(v::AbstractVecOrMat, +@inline function LinearAlgebra.mul!(w::AbstractVecOrMat, ii::IdentityOperator, - u::AbstractVecOrMat, + v::AbstractVecOrMat, α, β) - @assert size(u, 1) == ii.len - mul!(v, I, u, α, β) + @assert size(v, 1) == ii.len + mul!(w, I, v, α, β) end -function LinearAlgebra.ldiv!(v::AbstractVecOrMat, ii::IdentityOperator, u::AbstractVecOrMat) - @assert size(u, 1) == ii.len - copy!(v, u) +function LinearAlgebra.ldiv!(w::AbstractVecOrMat, ii::IdentityOperator, v::AbstractVecOrMat) + @assert size(v, 1) == ii.len + copy!(w, v) end -function LinearAlgebra.ldiv!(ii::IdentityOperator, u::AbstractVecOrMat) - @assert size(u, 1) == ii.len - u +function LinearAlgebra.ldiv!(ii::IdentityOperator, v::AbstractVecOrMat) + @assert size(v, 1) 
== ii.len + v +end + +# Out-of-place: v is action vector, u is update vector +function (ii::IdentityOperator)(v::AbstractVecOrMat, u, p, t; kwargs...) + @assert size(v, 1) == ii.len + update_coefficients(ii, u, p, t; kwargs...) + copy(v) +end + +# In-place: w is destination, v is action vector, u is update vector +function (ii::IdentityOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...) + @assert size(v, 1) == ii.len + update_coefficients!(ii, u, p, t; kwargs...) + copy!(w, v) +end + +# In-place with scaling: w = α*(ii*v) + β*w +function (ii::IdentityOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...) + @assert size(v, 1) == ii.len + update_coefficients!(ii, u, p, t; kwargs...) + mul!(w, I, v, α, β) end # operator fusion with identity returns operator itself @@ -95,7 +116,7 @@ end """ $(TYPEDEF) -Operator representing the null function `n(u) = 0 * u` +Operator representing the null function `n(v) = 0 * v` """ struct NullOperator <: AbstractSciMLOperator{Bool} len::Int @@ -130,20 +151,40 @@ has_adjoint(::NullOperator) = true has_mul!(::NullOperator) = true # operator application -Base.:*(nn::NullOperator, u::AbstractVecOrMat) = (@assert size(u, 1) == nn.len; zero(u)) +Base.:*(nn::NullOperator, v::AbstractVecOrMat) = (@assert size(v, 1) == nn.len; zero(v)) -function LinearAlgebra.mul!(v::AbstractVecOrMat, nn::NullOperator, u::AbstractVecOrMat) - @assert size(u, 1) == size(v, 1) == nn.len - lmul!(false, v) +function LinearAlgebra.mul!(w::AbstractVecOrMat, nn::NullOperator, v::AbstractVecOrMat) + @assert size(v, 1) == size(w, 1) == nn.len + lmul!(false, w) end -function LinearAlgebra.mul!(v::AbstractVecOrMat, +function LinearAlgebra.mul!(w::AbstractVecOrMat, nn::NullOperator, - u::AbstractVecOrMat, + v::AbstractVecOrMat, α, β) - @assert size(u, 1) == size(v, 1) == nn.len - lmul!(β, v) + @assert size(v, 1) == size(w, 1) == nn.len + lmul!(β, w) +end + +# Out-of-place: v is action vector, u is update vector +function 
(nn::NullOperator)(v::AbstractVecOrMat, u, p, t; kwargs...) + @assert size(v, 1) == nn.len + zero(v) +end + +# In-place: w is destination, v is action vector, u is update vector +function (nn::NullOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...) + @assert size(v, 1) == nn.len + lmul!(false, w) + w +end + +# In-place with scaling: w = α*(nn*v) + β*w +function (nn::NullOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...) + @assert size(v, 1) == nn.len + lmul!(β, w) + w end # operator fusion, composition @@ -177,7 +218,7 @@ $TYPEDEF ScaledOperator - (λ L)*(u) = λ * L(u) + (λ L)*(v) = λ * L(v) """ struct ScaledOperator{T, λType, @@ -218,6 +259,26 @@ for T in SCALINGNUMBERTYPES end end +# Special cases for constant scalars. These simplify the structure when applicable +for T in SCALINGNUMBERTYPES[2:end] + @eval function Base.:*(α::$T, L::ScaledOperator) + isconstant(L.λ) && return ScaledOperator(α * L.λ, L.L) + return ScaledOperator(L.λ, α * L.L) # Try to propagate the rule + end + @eval function Base.:*(L::ScaledOperator, α::$T) + isconstant(L.λ) && return ScaledOperator(α * L.λ, L.L) + return ScaledOperator(L.λ, α * L.L) # Try to propagate the rule + end + @eval function Base.:*(α::$T, L::MatrixOperator) + isconstant(L) && return MatrixOperator(α * L.A) + return ScaledOperator(α, L) # Going back to the generic case + end + @eval function Base.:*(L::MatrixOperator, α::$T) + isconstant(L) && return MatrixOperator(α * L.A) + return ScaledOperator(α, L) # Going back to the generic case + end +end + Base.:-(L::AbstractSciMLOperator) = ScaledOperator(-true, L) Base.:+(L::AbstractSciMLOperator) = L @@ -264,9 +325,9 @@ has_mul!(L::ScaledOperator) = has_mul!(L.L) has_ldiv(L::ScaledOperator) = has_ldiv(L.L) & !iszero(L.λ) has_ldiv!(L::ScaledOperator) = has_ldiv!(L.L) & !iszero(L.λ) -function cache_internals(L::ScaledOperator, u::AbstractVecOrMat) - @reset L.L = cache_operator(L.L, u) - @reset L.λ = cache_operator(L.λ, u) +function 
cache_internals(L::ScaledOperator, v::AbstractVecOrMat) + @reset L.L = cache_operator(L.L, v) + @reset L.λ = cache_operator(L.λ, v) L end @@ -286,40 +347,77 @@ for fact in (:lu, :lu!, end # operator application, inversion -Base.:*(L::ScaledOperator, u::AbstractVecOrMat) = L.λ * (L.L * u) -Base.:\(L::ScaledOperator, u::AbstractVecOrMat) = L.λ \ (L.L \ u) +Base.:*(L::ScaledOperator, v::AbstractVecOrMat) = L.λ * (L.L * v) +Base.:\(L::ScaledOperator, v::AbstractVecOrMat) = L.λ \ (L.L \ v) @inline function LinearAlgebra.mul!( - v::AbstractVecOrMat, L::ScaledOperator, u::AbstractVecOrMat) - iszero(L.λ) && return lmul!(false, v) + w::AbstractVecOrMat, L::ScaledOperator, v::AbstractVecOrMat) + iszero(L.λ) && return lmul!(false, w) a = convert(Number, L.λ) - mul!(v, L.L, u, a, false) + mul!(w, L.L, v, a, false) end -@inline function LinearAlgebra.mul!(v::AbstractVecOrMat, +@inline function LinearAlgebra.mul!(w::AbstractVecOrMat, L::ScaledOperator, - u::AbstractVecOrMat, + v::AbstractVecOrMat, α, β) - iszero(L.λ) && return lmul!(β, v) + iszero(L.λ) && return lmul!(β, w) a = convert(Number, L.λ * α) - mul!(v, L.L, u, a, β) + mul!(w, L.L, v, a, β) +end + +function LinearAlgebra.ldiv!(w::AbstractVecOrMat, L::ScaledOperator, v::AbstractVecOrMat) + ldiv!(w, L.L, v) + ldiv!(L.λ, w) end -function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::ScaledOperator, u::AbstractVecOrMat) - ldiv!(v, L.L, u) +function LinearAlgebra.ldiv!(L::ScaledOperator, v::AbstractVecOrMat) ldiv!(L.λ, v) + ldiv!(L.L, v) +end + +# Out-of-place: v is action vector, u is update vector +function (L::ScaledOperator)(v::AbstractVecOrMat, u, p, t; kwargs...) + L = update_coefficients(L, u, p, t; kwargs...) + if iszero(L.λ) + return zero(v) + else + return L.λ * (L.L * v) + end +end + +# In-place: w is destination, v is action vector, u is update vector +function (L::ScaledOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...) + update_coefficients!(L, u, p, t; kwargs...) 
+ if iszero(L.λ) + lmul!(false, w) + return w + else + a = convert(Number, L.λ) + mul!(w, L.L, v, a, false) + return w + end end -function LinearAlgebra.ldiv!(L::ScaledOperator, u::AbstractVecOrMat) - ldiv!(L.λ, u) - ldiv!(L.L, u) +# In-place with scaling: w = α*(L*v) + β*w +function (L::ScaledOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...) + update_coefficients!(L, u, p, t; kwargs...) + if iszero(L.λ) + lmul!(β, w) + return w + else + a = convert(Number, L.λ * α) + mul!(w, L.L, v, a, β) + return w + end end + """ Lazy operator addition - (A1 + A2 + A3...)u = A1*u + A2*u + A3*u .... + (A1 + A2 + A3...)v = A1*v + A2*v + A3*v .... """ struct AddedOperator{T, O <: Tuple{Vararg{AbstractSciMLOperator}} @@ -472,12 +570,12 @@ islinear(L::AddedOperator) = all(islinear, getops(L)) Base.iszero(L::AddedOperator) = all(iszero, getops(L)) has_adjoint(L::AddedOperator) = all(has_adjoint, L.ops) -@generated function cache_internals(L::AddedOperator, u::AbstractVecOrMat) +@generated function cache_internals(L::AddedOperator, v::AbstractVecOrMat) ops_types = L.parameters[2].parameters N = length(ops_types) quote Base.@nexprs $N i->begin - @reset L.ops[i] = cache_operator(L.ops[i], u) + @reset L.ops[i] = cache_operator(L.ops[i], v) end L end @@ -486,46 +584,74 @@ end getindex(L::AddedOperator, i::Int) = sum(op -> op[i], L.ops) getindex(L::AddedOperator, I::Vararg{Int, N}) where {N} = sum(op -> op[I...], L.ops) -function Base.:*(L::AddedOperator, u::AbstractVecOrMat) - sum(op -> iszero(op) ? zero(u) : op * u, L.ops) +function Base.:*(L::AddedOperator, v::AbstractVecOrMat) + sum(op -> iszero(op) ? 
zero(v) : op * v, L.ops) end @generated function LinearAlgebra.mul!( - v::AbstractVecOrMat, L::AddedOperator, u::AbstractVecOrMat) + w::AbstractVecOrMat, L::AddedOperator, v::AbstractVecOrMat) ops_types = L.parameters[2].parameters N = length(ops_types) quote - mul!(v, L.ops[1], u) + mul!(w, L.ops[1], v) Base.@nexprs $(N - 1) i->begin - mul!(v, L.ops[i + 1], u, true, true) + mul!(w, L.ops[i + 1], v, true, true) end - v + w end end -@generated function LinearAlgebra.mul!(v::AbstractVecOrMat, +@generated function LinearAlgebra.mul!(w::AbstractVecOrMat, L::AddedOperator, - u::AbstractVecOrMat, + v::AbstractVecOrMat, α, β) ops_types = L.parameters[2].parameters N = length(ops_types) quote - lmul!(β, v) + lmul!(β, w) Base.@nexprs $(N) i->begin - mul!(v, L.ops[i], u, α, true) + mul!(w, L.ops[i], v, α, true) + end + w + end +end +# Out-of-place: v is action vector, u is update vector +function (L::AddedOperator)(v::AbstractVecOrMat, u, p, t; kwargs...) + L = update_coefficients(L, u, p, t; kwargs...) + sum(op -> iszero(op) ? zero(v) : op(v, u, p, t; kwargs...), L.ops) +end + +# In-place: w is overwritten with L*v; v is action vector, u is update vector +function (L::AddedOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...) + update_coefficients!(L, u, p, t; kwargs...) + lmul!(false, w) # reset destination first so stale contents of w do not leak into the sum + for op in L.ops + # accumulate w += op*v, using the same `true, true` scaling as `mul!` for AddedOperator + iszero(op) || op(w, v, u, p, t, true, true; kwargs...) + end + w +end + +# In-place with scaling: w = α*(L*v) + β*w +function (L::AddedOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...) + update_coefficients!(L, u, p, t; kwargs...) + lmul!(β, w) + for op in L.ops + if !iszero(op) + op(w, v, u, p, t, α, 1.0; kwargs...) 
end - v end + w end """ Lazy operator composition - ∘(A, B, C)(u) = A(B(C(u))) + ∘(A, B, C)(v) = A(B(C(v))) ops = (A, B, C) - cache = (B*C*u , C*u) + cache = (B*C*v , C*v) """ struct ComposedOperator{T, O, C} <: AbstractSciMLOperator{T} """ Tuple of N operators to be applied in reverse""" @@ -670,8 +796,7 @@ end #Base.:*(L::ComposedOperator, u::AbstractVecOrMat) = foldl((acc, op) -> op * acc, reverse(L.ops); init=u) #Base.:\(L::ComposedOperator, u::AbstractVecOrMat) = foldl((acc, op) -> op \ acc, L.ops; init=u) -function Base.:\(L::ComposedOperator, u::AbstractVecOrMat) - v = u +function Base.:\(L::ComposedOperator, v::AbstractVecOrMat) for op in L.ops v = op \ v end @@ -679,8 +804,7 @@ function Base.:\(L::ComposedOperator, u::AbstractVecOrMat) v end -function Base.:*(L::ComposedOperator, u::AbstractVecOrMat) - v = u +function Base.:*(L::ComposedOperator, v::AbstractVecOrMat) for op in reverse(L.ops) v = op * v end @@ -688,15 +812,15 @@ function Base.:*(L::ComposedOperator, u::AbstractVecOrMat) v end -function cache_self(L::ComposedOperator, u::AbstractVecOrMat) - K = size(u, 2) - cache = (zero(u),) +function cache_self(L::ComposedOperator, v::AbstractVecOrMat) + K = size(v, 2) + cache = (zero(v),) for i in reverse(2:length(L.ops)) op = L.ops[i] M = size(op, 1) - sz = u isa AbstractMatrix ? (M, K) : (M,) + sz = v isa AbstractMatrix ? (M, K) : (M,) T = if op isa FunctionOperator # # FunctionOperator isn't guaranteed to play by the rules of @@ -707,16 +831,16 @@ function cache_self(L::ComposedOperator, u::AbstractVecOrMat) promote_type(eltype.((op, cache[1]))...) end - cache = (similar(u, T, sz), cache...) + cache = (similar(v, T, sz), cache...) 
end @reset L.cache = cache L end -function cache_internals(L::ComposedOperator, u::AbstractVecOrMat) +function cache_internals(L::ComposedOperator, v::AbstractVecOrMat) if isnothing(L.cache) - L = cache_self(L, u) + L = cache_self(L, v) end ops = () @@ -727,49 +851,84 @@ function cache_internals(L::ComposedOperator, u::AbstractVecOrMat) @reset L.ops = ops end -function LinearAlgebra.mul!(v::AbstractVecOrMat, L::ComposedOperator, u::AbstractVecOrMat) +function LinearAlgebra.mul!(w::AbstractVecOrMat, L::ComposedOperator, v::AbstractVecOrMat) @assert iscached(L) """cache needs to be set up for operator of type - $L. Set up cache by calling `cache_operator(L, u)`""" + $L. Set up cache by calling `cache_operator(L, v)`""" - vecs = (v, L.cache[1:(end - 1)]..., u) + vecs = (w, L.cache[1:(end - 1)]..., v) for i in reverse(1:length(L.ops)) mul!(vecs[i], L.ops[i], vecs[i + 1]) end - v + w end -function LinearAlgebra.mul!(v::AbstractVecOrMat, +function LinearAlgebra.mul!(w::AbstractVecOrMat, L::ComposedOperator, - u::AbstractVecOrMat, + v::AbstractVecOrMat, α, β) @assert iscached(L) """cache needs to be set up for operator of type - $L. Set up cache by calling `cache_operator(L, u)`.""" + $L. Set up cache by calling `cache_operator(L, v)`.""" cache = L.cache[end] - copy!(cache, v) + copy!(cache, w) - mul!(v, L, u) - lmul!(α, v) - axpy!(β, cache, v) + mul!(w, L, v) + lmul!(α, w) + axpy!(β, cache, w) end -function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::ComposedOperator, u::AbstractVecOrMat) +function LinearAlgebra.ldiv!(w::AbstractVecOrMat, L::ComposedOperator, v::AbstractVecOrMat) @assert iscached(L) """cache needs to be set up for operator of type - $L. Set up cache by calling `cache_operator(L, u)`.""" + $L. 
Set up cache by calling `cache_operator(L, v)`.""" - vecs = (u, reverse(L.cache[1:(end - 1)])..., v) + vecs = (v, reverse(L.cache[1:(end - 1)])..., w) for i in 1:length(L.ops) ldiv!(vecs[i + 1], L.ops[i], vecs[i]) end - v + w end -function LinearAlgebra.ldiv!(L::ComposedOperator, u::AbstractVecOrMat) +function LinearAlgebra.ldiv!(L::ComposedOperator, v::AbstractVecOrMat) for i in 1:length(L.ops) - ldiv!(L.ops[i], u) + ldiv!(L.ops[i], v) + end + v +end + +# Out-of-place: v is action vector, u is update vector +function (L::ComposedOperator)(v::AbstractVecOrMat, u, p, t; kwargs...) + L = update_coefficients(L, u, p, t; kwargs...) + result = v + for op in reverse(L.ops) + result = op(result, u, p, t; kwargs...) + end + result +end + +# In-place: w is destination, v is action vector, u is update vector +function (L::ComposedOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...) + update_coefficients!(L, u, p, t; kwargs...) + @assert iscached(L) "Cache needs to be set up for ComposedOperator. Call cache_operator(L, u) first." + + vecs = (w, L.cache[1:(end-1)]..., v) + for i in reverse(1:length(L.ops)) + L.ops[i](vecs[i], vecs[i+1], u, p, t; kwargs...) end - u + w +end + +# In-place with scaling: w = α*(L*v) + β*w +function (L::ComposedOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...) + update_coefficients!(L, u, p, t; kwargs...) + @assert iscached(L) "Cache needs to be set up for ComposedOperator. Call cache_operator(L, u) first." + + cache = L.cache[end] + copy!(cache, w) + + L(w, v, u, p, t; kwargs...) 
+ lmul!(α, w) + axpy!(β, cache, w) end """ @@ -860,33 +1019,58 @@ function cache_internals(L::InvertedOperator, u::AbstractVecOrMat) L end -function LinearAlgebra.mul!(v::AbstractVecOrMat, L::InvertedOperator, u::AbstractVecOrMat) - ldiv!(v, L.L, u) +function LinearAlgebra.mul!(w::AbstractVecOrMat, L::InvertedOperator, v::AbstractVecOrMat) + ldiv!(w, L.L, v) end -function LinearAlgebra.mul!(v::AbstractVecOrMat, +function LinearAlgebra.mul!(w::AbstractVecOrMat, L::InvertedOperator, - u::AbstractVecOrMat, + v::AbstractVecOrMat, α, β) @assert iscached(L) """cache needs to be set up for operator of type - $L. Set up cache by calling `cache_operator(L, u)`.""" + $L. Set up cache by calling `cache_operator(L, v)`.""" - copy!(L.cache, v) - ldiv!(v, L.L, u) - lmul!(α, v) - axpy!(β, L.cache, v) + copy!(L.cache, w) + ldiv!(w, L.L, v) + lmul!(α, w) + axpy!(β, L.cache, w) end -function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::InvertedOperator, u::AbstractVecOrMat) - mul!(v, L.L, u) +function LinearAlgebra.ldiv!(w::AbstractVecOrMat, L::InvertedOperator, v::AbstractVecOrMat) + mul!(w, L.L, v) end -function LinearAlgebra.ldiv!(L::InvertedOperator, u::AbstractVecOrMat) +function LinearAlgebra.ldiv!(L::InvertedOperator, v::AbstractVecOrMat) @assert iscached(L) """cache needs to be set up for operator of type - $L. Set up cache by calling `cache_operator(L, u)`.""" + $L. Set up cache by calling `cache_operator(L, v)`.""" + + copy!(L.cache, v) + mul!(v, L.L, L.cache) +end + +# Out-of-place: v is action vector, u is update vector +function (L::InvertedOperator)(v::AbstractVecOrMat, u, p, t; kwargs...) + L = update_coefficients(L, u, p, t; kwargs...) + L.L \ v +end + +# In-place: w is destination, v is action vector, u is update vector +function (L::InvertedOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...) + update_coefficients!(L, u, p, t; kwargs...) 
+ ldiv!(w, L.L, v) + w +end - copy!(L.cache, u) - mul!(u, L.L, L.cache) +# In-place with scaling: w = α*(L*v) + β*w +function (L::InvertedOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...) + update_coefficients!(L, u, p, t; kwargs...) + @assert iscached(L) "Cache needs to be set up for InvertedOperator. Call cache_operator(L, u) first." + + copy!(L.cache, w) + ldiv!(w, L.L, v) + lmul!(α, w) + axpy!(β, L.cache, w) + w end # diff --git a/src/batch.jl b/src/batch.jl index ea6e8d52..b3cc6ca8 100644 --- a/src/batch.jl +++ b/src/batch.jl @@ -153,4 +153,25 @@ function LinearAlgebra.ldiv!(L::BatchedDiagonalOperator, u::AbstractVecOrMat) u end + +function (L::BatchedDiagonalOperator)(v::AbstractVecOrMat, u, p, t; kwargs...) + L = update_coefficients(L, u, p, t; kwargs...) + L.diag .* v +end + +function (L::BatchedDiagonalOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...) + update_coefficients!(L, u, p, t; kwargs...) + w .= L.diag .* v + return w +end + +function (L::BatchedDiagonalOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...) + update_coefficients!(L, u, p, t; kwargs...) 
+ if β == 0 + w .= α .* (L.diag .* v) + else + w .= α .* (L.diag .* v) .+ β .* w + end + return w +end # diff --git a/src/func.jl b/src/func.jl index 0bc0da97..e852adc9 100644 --- a/src/func.jl +++ b/src/func.jl @@ -4,9 +4,9 @@ Matrix free operator given by a function $(FIELDS) """ -mutable struct FunctionOperator{iip, oop, mul5, T <: Number, F, Fa, Fi, Fai, Tr, P, Tt, +mutable struct FunctionOperator{iip, oop, mul5, T <: Number, F, Fa, Fi, Fai, Tr, U, P, Tt, C, iType, oType} <: AbstractSciMLOperator{T} - """ Function with signature op(u, p, t) and (if isinplace) op(v, u, p, t) """ + """ Function with signature op(v, u, p, t) and (if isinplace) op(w, v, u, p, t) """ op::F """ Adjoint operator""" op_adjoint::Fa @@ -16,6 +16,8 @@ mutable struct FunctionOperator{iip, oop, mul5, T <: Number, F, Fa, Fi, Fai, Tr, op_adjoint_inverse::Fai """ Traits """ traits::Tr + """ State """ + u::U """ Parameters """ p::P """ Time """ @@ -24,7 +26,7 @@ mutable struct FunctionOperator{iip, oop, mul5, T <: Number, F, Fa, Fi, Fai, Tr, cache::C end -function FunctionOperator(op, op_adjoint, op_inverse, op_adjoint_inverse, traits, p, t, +function FunctionOperator(op, op_adjoint, op_inverse, op_adjoint_inverse, traits, u, p, t, cache, ::Type{iType}, ::Type{oType}) where {iType, oType} iip = traits.isinplace oop = traits.outofplace @@ -32,119 +34,141 @@ function FunctionOperator(op, op_adjoint, op_inverse, op_adjoint_inverse, traits T = traits.T return FunctionOperator{iip, oop, mul5, T, typeof(op), typeof(op_adjoint), - typeof(op_inverse), typeof(op_adjoint_inverse), typeof(traits), typeof(p), + typeof(op_inverse), typeof(op_adjoint_inverse), typeof(traits), typeof(u), typeof(p), typeof(t), typeof(cache), iType, oType}(op, op_adjoint, op_inverse, - op_adjoint_inverse, traits, p, t, cache) + op_adjoint_inverse, traits, u, p, t, cache) end function set_op( - f::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, iType, + f::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, 
U, P, Tt, C, iType, oType}, - op) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, iType, oType} + op) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, U, P, Tt, C, iType, oType} return FunctionOperator{ - iip, oop, mul5, T, typeof(op), Fa, Fi, Fai, Tr, P, Tt, C, iType, - oType}(op, f.op_adjoint, f.op_inverse, f.op_adjoint_inverse, f.traits, f.p, f.t, + iip, oop, mul5, T, typeof(op), Fa, Fi, Fai, Tr, U, P, Tt, C, iType, + oType}(op, f.op_adjoint, f.op_inverse, f.op_adjoint_inverse, f.traits, f.u, f.p, f.t, f.cache) end function set_op_adjoint( - f::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, + f::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, U, P, Tt, C, iType, oType}, - op_adjoint) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, + op_adjoint) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, U, P, Tt, C, iType, oType} - return FunctionOperator{iip, oop, mul5, T, F, typeof(op_adjoint), Fi, Fai, Tr, P, Tt, + return FunctionOperator{iip, oop, mul5, T, F, typeof(op_adjoint), Fi, Fai, Tr, U, P, Tt, C, iType, oType}(f.op, op_adjoint, f.op_inverse, f.op_adjoint_inverse, f.traits, - f.p, f.t, f.cache) + f.u, f.p, f.t, f.cache) end function set_op_inverse( - f::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, + f::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, U, P, Tt, C, iType, oType}, - op_inverse) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, + op_inverse) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, U, P, Tt, C, iType, oType} - return FunctionOperator{iip, oop, mul5, T, F, Fa, typeof(op_inverse), Fai, Tr, P, Tt, + return FunctionOperator{iip, oop, mul5, T, F, Fa, typeof(op_inverse), Fai, Tr, U, P, Tt, C, iType, oType}(f.op, f.op_adjoint, op_inverse, f.op_adjoint_inverse, f.traits, - f.p, f.t, f.cache) + f.u, f.p, f.t, f.cache) end function set_op_adjoint_inverse( f::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, - P, Tt, C, iType, oType}, + U, P, Tt, C, iType, oType}, op_adjoint_inverse) 
where {iip, oop, mul5, T, F, Fa, - Fi, Fai, Tr, P, Tt, C, iType, oType} + Fi, Fai, Tr, U, P, Tt, C, iType, oType} return FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, typeof(op_adjoint_inverse), Tr, - P, Tt, C, iType, oType}(f.op, f.op_adjoint, f.op_inverse, op_adjoint_inverse, - f.traits, f.p, f.t, f.cache) + U, P, Tt, C, iType, oType}(f.op, f.op_adjoint, f.op_inverse, op_adjoint_inverse, + f.traits, f.u, f.p, f.t, f.cache) end function set_traits( - f::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, + f::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, U, P, Tt, C, iType, oType}, - traits) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, + traits) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, U, P, Tt, C, iType, oType} - return FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, typeof(traits), P, Tt, + return FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, typeof(traits), U, P, Tt, C, iType, oType}(f.op, f.op_adjoint, f.op_inverse, f.op_adjoint_inverse, traits, - f.p, f.t, f.cache) + f.u, f.p, f.t, f.cache) +end + +function set_u( + f::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, U, P, Tt, C, + iType, oType}, + u) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, U, P, Tt, C, iType, + oType} + return FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, typeof(u), P, Tt, C, iType, + oType}(f.op, f.op_adjoint, f.op_inverse, f.op_adjoint_inverse, f.traits, u, f.p, f.t, + f.cache) end function set_p( - f::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, + f::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, U, P, Tt, C, iType, oType}, - p) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, iType, + p) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, U, P, Tt, C, iType, oType} - return FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, typeof(p), Tt, C, iType, - oType}(f.op, f.op_adjoint, f.op_inverse, f.op_adjoint_inverse, f.traits, p, f.t, + return FunctionOperator{iip, oop, mul5, T, F, 
Fa, Fi, Fai, Tr, U, typeof(p), Tt, C, iType, + oType}(f.op, f.op_adjoint, f.op_inverse, f.op_adjoint_inverse, f.traits, f.u, p, f.t, f.cache) end function set_t( - f::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, iType, + f::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, U, P, Tt, C, iType, oType}, - t) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, iType, oType} - return FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, typeof(t), C, iType, - oType}(f.op, f.op_adjoint, f.op_inverse, f.op_adjoint_inverse, f.traits, f.p, t, + t) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, U, P, Tt, C, iType, oType} + return FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, U, P, typeof(t), C, iType, + oType}(f.op, f.op_adjoint, f.op_inverse, f.op_adjoint_inverse, f.traits, f.u, f.p, t, f.cache) end function set_cache( - f::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, + f::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, U, P, Tt, C, iType, oType}, - cache) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, + cache) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, U, P, Tt, C, iType, oType} - return FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, typeof(cache), + return FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, U, P, Tt, typeof(cache), iType, oType}(f.op, f.op_adjoint, f.op_inverse, f.op_adjoint_inverse, f.traits, - f.p, f.t, cache) + f.u, f.p, f.t, cache) end -function input_eltype(::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, +function input_eltype(::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, U, P, Tt, C, iType, oType -}) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, iType, oType} +}) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, U, P, Tt, C, iType, oType} return iType end -function output_eltype(::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, +function output_eltype(::FunctionOperator{iip, oop, mul5, T, 
F, Fa, Fi, Fai, Tr, U, P, Tt, C, iType, oType -}) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, iType, oType} +}) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, U, P, Tt, C, iType, oType} return oType end """ $(SIGNATURES) -Wrap callable object `op` within an `AbstractSciMLOperator`. `op` -is assumed to have signature +Wrap callable object `op` within an `AbstractSciMLOperator`. + +## Mathematical Description + +```julia +L = FunctionOperator(op, v, w; kwargs...) +``` + +where ``w = L(u,p,t)*v`` is done matrix-free given the function +definition ``w = op(v,u,p,t)``. + +## Arguments + +`op` is assumed to have signature - op(u, p, t; ) -> v + op(v, u, p, t; ) -> w or - op(v, u, p, t; ) -> [modifies v] + op(w, v, u, p, t; ) -> [modifies w] and optionally - op(v, u, p, t, α, β; ) -> [modifies v] + op(w, v, u, p, t, α, β; ) -> [modifies w] -where `u`, `v` are `AbstractArray`s, `p` is a parameter object, and +where `u`, `v`, `w` are `AbstractArray`s, `p` is a parameter object, and `t`, `α`, `β` are scalars. The first signature corresponds to applying the operator with `Base.*`, and the latter two correspond to the three-argument, and the five-argument `mul!` respectively. @@ -154,7 +178,7 @@ determining operator traits such as `eltype`, `size`, and for preallocating cache. If `output` array is not provided, the output is assumed to be of the same type and share as the input. -# Keyword Arguments +## Keyword Arguments Keyword arguments are used to pass in the adjoint evaluation function, `op_adjoint`, the inverse function, `op_inverse`, and the adjoint-inverse @@ -166,6 +190,7 @@ below traits. Keyword arguments are used to set operator traits, which are assumed to be uniform across `op`, `op_adjoint`, `op_inverse`, `op_adjoint_inverse`. +* `u` - Prototype of the state struct passed to the operator during evaluation, i.e. `L(u, p, t)`. `u` is set to `nothing` if no value is provided. 
* `p` - Prototype of parameter struct passed to the operator during evaluation, i.e. `L(u, p, t)`. `p` is set to `nothing` if no value is provided. * `t` - Protype of scalar time variable passed to the operator during evaluation. `t` is set to `zero(T)` if no value is provided. * `accepted_kwargs` - `Tuple` of `Symbol`s corresponding to the keyword arguments accepted by `op*`, and `update_coefficients[!]`. For example, if `op` accepts kwarg `scale`, as in `op(u, p, t; scale)`, then `accepted_kwargs = (:scale,)`. @@ -189,7 +214,7 @@ function FunctionOperator(op, input::AbstractArray, output::AbstractArray = input; op_adjoint = nothing, op_inverse = nothing, - op_adjoint_inverse = nothing, p = nothing, + op_adjoint_inverse = nothing, u = nothing, p = nothing, t::Union{Number, Nothing} = nothing, accepted_kwargs::Union{Nothing, Val, NTuple{N, Symbol}} = nothing, @@ -267,7 +292,7 @@ function FunctionOperator(op, # evaluation signatures _isinplace = if isinplace === nothing - Val(hasmethod(op, typeof((output, input, p, _t)))) + Val(hasmethod(op, typeof((output, input, u, p, _t)))) elseif isinplace isa Bool Val(isinplace) else @@ -275,7 +300,7 @@ function FunctionOperator(op, end _outofplace = if outofplace === nothing - Val(hasmethod(op, typeof((input, p, _t)))) + Val(hasmethod(op, typeof((input, u, p, _t)))) elseif outofplace isa Bool Val(outofplace) else @@ -283,16 +308,16 @@ function FunctionOperator(op, end if !_unwrap_val(_isinplace) & !_unwrap_val(_outofplace) - @error """Please provide a function with signatures `op(u, p, t)` for + @error """Please provide a function with signatures `op(v, u, p, t)` for applying the operator out-of-place, and/or the signature is - `op(v, u, p, t)` for in-place application.""" + `op(w, v, u, p, t)` for in-place application.""" end _has_mul5 = if has_mul5 === nothing - __and_val(__has_mul5(op, output, input, p, _t), - __has_mul5(op_adjoint, input, output, p, _t), - __has_mul5(op_inverse, output, input, p, _t), - 
__has_mul5(op_adjoint_inverse, input, output, p, _t)) + __and_val(__has_mul5(op, output, input, u, p, _t), + __has_mul5(op_adjoint, input, output, u, p, _t), + __has_mul5(op_inverse, output, input, u, p, _t), + __has_mul5(op_adjoint_inverse, input, output, u, p, _t)) elseif has_mul5 isa Bool Val(has_mul5) else @@ -336,9 +361,9 @@ function FunctionOperator(op, L = FunctionOperator{_unwrap_val(_isinplace), _unwrap_val(_outofplace), _unwrap_val(_has_mul5), _T, typeof(op), typeof(_op_adjoint), typeof(op_inverse), - typeof(_op_adjoint_inverse), typeof(traits), typeof(p), typeof(_t), typeof(cache), + typeof(_op_adjoint_inverse), typeof(traits), typeof(u), typeof(p), typeof(_t), typeof(cache), eltype(input), eltype(output)}(op, - _op_adjoint, op_inverse, _op_adjoint_inverse, traits, p, _t, cache) + _op_adjoint, op_inverse, _op_adjoint_inverse, traits, u, p, _t, cache) # create cache @@ -351,15 +376,16 @@ function FunctionOperator(op, return L_cached end -@inline __has_mul5(::Nothing, y, x, p, t) = Val(true) -@inline function __has_mul5(f::F, y, x, p, t) where {F} - return Val(hasmethod(f, typeof((y, x, p, t, t, t)))) +@inline __has_mul5(::Nothing, w, v, u, p, t) = Val(true) +@inline function __has_mul5(f::F, w, v, u, p, t) where {F} + return Val(hasmethod(f, typeof((w, v, u, p, t, t, t)))) end @inline __and_val(vs...) = mapreduce(_unwrap_val, *, vs) function update_coefficients(L::FunctionOperator, u, p, t; kwargs...) - # update p, t + # update u, p, t + L = set_u(L, u) L = set_p(L, p) L = set_t(L, t) @@ -379,7 +405,8 @@ end function update_coefficients!(L::FunctionOperator, u, p, t; kwargs...) 
- # update p, t + # update u, p, t + L.u = u L.p = p L.t = t @@ -445,18 +472,18 @@ function _cache_operator(L::FunctionOperator, u::AbstractArray) end # fix method amg bw AbstractArray, AbstractVecOrMat -cache_self(L::FunctionOperator, u::AbstractArray) = _cache_self(L, u) -cache_self(L::FunctionOperator, u::AbstractVecOrMat) = _cache_self(L, u) +cache_self(L::FunctionOperator, v::AbstractArray) = _cache_self(L, v) +cache_self(L::FunctionOperator, v::AbstractVecOrMat) = _cache_self(L, v) function _cache_self( - L::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, + L::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, U, P, Tt, C, iType, oType}, - u::AbstractArray) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, + v::AbstractArray) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, U, P, Tt, C, iType, oType} - _u = similar(u, iType, L.traits.sizes[1]) - _v = similar(u, oType, L.traits.sizes[2]) + _v = similar(v, iType, L.traits.sizes[1]) + _w = similar(v, oType, L.traits.sizes[2]) - return set_cache(L, (_u, _v)) + return set_cache(L, (_v, _w)) end # fix method amg bw AbstractArray, AbstractVecOrMat @@ -464,18 +491,18 @@ cache_internals(L::FunctionOperator, u::AbstractArray) = _cache_internals(L, u) cache_internals(L::FunctionOperator, u::AbstractVecOrMat) = _cache_internals(L, u) function _cache_internals( - L::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, + L::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, U, P, Tt, C, iType, oType}, u::AbstractArray) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, - P, Tt, C, iType, oType} + U, P, Tt, C, iType, oType} newop = cache_operator(L.op, u) newop_adjoint = cache_operator(L.op_adjoint, u) newop_inverse = cache_operator(L.op_inverse, u) newop_adjoint_inverse = cache_operator(L.op_adjoint_inverse, u) return FunctionOperator{iip, oop, mul5, T, typeof(newop), typeof(newop_adjoint), - typeof(newop_inverse), typeof(newop_adjoint_inverse), Tr, P, Tt, C, iType, oType}( - newop, 
newop_adjoint, newop_inverse, newop_adjoint_inverse, L.traits, L.p, L.t, + typeof(newop_inverse), typeof(newop_adjoint_inverse), Tr, U, P, Tt, C, iType, oType}( + newop, newop_adjoint, newop_inverse, newop_adjoint_inverse, L.traits, L.u, L.p, L.t, L.cache) end @@ -485,10 +512,10 @@ function Base.show(io::IO, L::FunctionOperator) end Base.size(L::FunctionOperator) = L.traits.size -function Base.adjoint(L::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, +function Base.adjoint(L::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, U, P, Tt, C, iType, oType -}) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, iType, oType} +}) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, U, P, Tt, C, iType, oType} (ishermitian(L) | (isreal(L) & issymmetric(L))) && return L has_adjoint(L) || return AdjointOperator(L) @@ -504,16 +531,16 @@ function Base.adjoint(L::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, cache = iscached(L) ? reverse(L.cache) : nothing return FunctionOperator{iip, oop, mul5, T, typeof(op), typeof(op_adjoint), - typeof(op_inverse), typeof(op_adjoint_inverse), typeof(traits), P, Tt, + typeof(op_inverse), typeof(op_adjoint_inverse), typeof(traits), U, P, Tt, typeof(cache), oType, iType}( op, op_adjoint, op_inverse, op_adjoint_inverse, traits, - L.p, L.t, cache) + L.u, L.p, L.t, cache) end -function Base.inv(L::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, +function Base.inv(L::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, U, P, Tt, C, iType, oType -}) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, iType, oType} +}) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, U, P, Tt, C, iType, oType} has_ldiv(L) || return InvertedOperator(L) op = L.op_inverse @@ -535,10 +562,10 @@ function Base.inv(L::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, cache = iscached(L) ? 
reverse(L.cache) : nothing return FunctionOperator{iip, oop, mul5, T, typeof(op), typeof(op_adjoint), - typeof(op_inverse), typeof(op_adjoint_inverse), typeof(traits), P, Tt, + typeof(op_inverse), typeof(op_adjoint_inverse), typeof(traits), U, P, Tt, typeof(cache), oType, iType}( op, op_adjoint, op_inverse, op_adjoint_inverse, traits, - L.p, L.t, cache) + L.u, L.p, L.t, cache) end Base.convert(::Type{AbstractMatrix}, L::FunctionOperator) = convert(AbstractMatrix, L.op) @@ -596,133 +623,131 @@ has_mul!(::FunctionOperator{iip}) where {iip} = iip has_ldiv(L::FunctionOperator{iip}) where {iip} = !(L.op_inverse isa Nothing) has_ldiv!(L::FunctionOperator{iip}) where {iip} = iip & !(L.op_inverse isa Nothing) -function _sizecheck(L::FunctionOperator, u, v) +function _sizecheck(L::FunctionOperator, v, w) sizes = L.traits.sizes if L.traits.batch - if !isnothing(u) - if !isa(u, AbstractVecOrMat) + if !isnothing(v) + if !isa(v, AbstractVecOrMat) msg = """$L constructed with `batch = true` only accept input arrays that are `AbstractVecOrMat`s with - `size(L, 2) == size(u, 1)`. Received $(typeof(u)).""" + `size(L, 2) == size(v, 1)`. Received $(typeof(v)).""" throw(ArgumentError(msg)) end - if size(L, 2) != size(u, 1) + if size(L, 2) != size(v, 1) msg = """$L accepts input `AbstractVecOrMat`s of size - ($(size(L, 2)), K). Received array of size $(size(u)).""" + ($(size(L, 2)), K). Received array of size $(size(v)).""" throw(DimensionMismatch(msg)) end - end # u + end # v - if !isnothing(v) - if !isa(v, AbstractVecOrMat) + if !isnothing(w) + if !isa(w, AbstractVecOrMat) msg = """$L constructed with `batch = true` only returns output arrays that are `AbstractVecOrMat`s with - `size(L, 1) == size(v, 1)`. Received $(typeof(v)).""" + `size(L, 1) == size(w, 1)`. Received $(typeof(w)).""" throw(ArgumentError(msg)) end - if size(L, 1) != size(v, 1) + if size(L, 1) != size(w, 1) msg = """$L accepts output `AbstractVecOrMat`s of size - ($(size(L, 1)), K). 
Received array of size $(size(v)).""" + ($(size(L, 1)), K). Received array of size $(size(w)).""" throw(DimensionMismatch(msg)) end - end # v + end # w - if !isnothing(u) & !isnothing(v) - if size(u, 2) != size(v, 2) - msg = """input array $u, and output array, $v, must have the + if !isnothing(v) & !isnothing(w) + if size(v, 2) != size(w, 2) + msg = """input array $v, and output array, $w, must have the same batch size (i.e. length of second dimension). Got - $(size(u)), $(size(v)). If you encounter this error during + $(size(v)), $(size(w)). If you encounter this error during an in-place evaluation (`LinearAlgebra.mul!`, `ldiv!`), ensure that the operator $L has been cached with an input array of the correct size. Do so by calling - `L = cache_operator(L, u)`.""" + `L = cache_operator(L, v)`.""" throw(DimensionMismatch(msg)) end - end # u, v + end # v, w else # !batch - if !isnothing(u) - if size(u) ∉ (sizes[1], tuple(size(L, 2))) - msg = """$L received input array of size $(size(u)), but only + if !isnothing(v) + if size(v) ∉ (sizes[1], tuple(size(L, 2))) + msg = """$L received input array of size $(size(v)), but only accepts input arrays of size $(sizes[1]), or vectors like - `vec(u)` of size $(tuple(prod(sizes[1]))).""" + `vec(v)` of size $(tuple(prod(sizes[1]))).""" throw(DimensionMismatch(msg)) end - end # u + end # v - if !isnothing(v) - if size(v) ∉ (sizes[2], tuple(size(L, 1))) - msg = """$L received output array of size $(size(v)), but only + if !isnothing(w) + if size(w) ∉ (sizes[2], tuple(size(L, 1))) + msg = """$L received output array of size $(size(w)), but only accepts output arrays of size $(sizes[2]), or vectors like - `vec(u)` of size $(tuple(prod(sizes[2])))""" + `vec(v)` of size $(tuple(prod(sizes[2])))""" throw(DimensionMismatch(msg)) end - end # v + end # w end # batch return end -function _unvec(L::FunctionOperator, u, v) +function _unvec(L::FunctionOperator, v, w) if L.traits.batch - return u, v, false + return v, w, false else sizes = 
L.traits.sizes # no need to vec since expected input/output are AbstractVectors if length(sizes[1]) == 1 - return u, v, false + return v, w, false end - vec_u = isnothing(u) ? false : size(u) != sizes[1] - vec_v = isnothing(v) ? false : size(v) != sizes[2] + vec_v = isnothing(v) ? false : size(v) != sizes[1] + vec_w = isnothing(w) ? false : size(w) != sizes[2] - if !isnothing(u) & !isnothing(v) - if (vec_u & !vec_v) | (!vec_u & vec_v) + if !isnothing(v) & !isnothing(w) + if (vec_v & !vec_w) | (!vec_v & vec_w) msg = """Input / output to $L can either be of sizes $(sizes[1]) / $(sizes[2]), or $(tuple(prod(sizes[1]))) / $(tuple(prod(sizes[2]))). Got - $(size(u)), $(size(v)).""" + $(size(v)), $(size(w)).""" throw(DimensionMismatch(msg)) end end - U = vec_u ? reshape(u, sizes[1]) : u - V = vec_v ? reshape(v, sizes[2]) : v - vec_output = vec_u | vec_v + V = vec_v ? reshape(v, sizes[1]) : v + W = vec_w ? reshape(w, sizes[2]) : w + vec_output = vec_v | vec_w - return U, V, vec_output + return V, W, vec_output end end # operator application -function Base.:*(L::FunctionOperator{iip, true}, u::AbstractArray) where {iip} - _sizecheck(L, u, nothing) - U, _, vec_output = _unvec(L, u, nothing) +function Base.:*(L::FunctionOperator{iip, true}, v::AbstractArray) where {iip} + _sizecheck(L, v, nothing) + V, _, vec_output = _unvec(L, v, nothing) - V = L.op(U, L.p, L.t; L.traits.kwargs...) + W = L.op(V, L.u, L.p, L.t; L.traits.kwargs...) - vec_output ? vec(V) : V + vec_output ? vec(W) : W end function Base.:\(L::FunctionOperator{iip, true}, v::AbstractArray) where {iip} _sizecheck(L, nothing, v) _, V, vec_output = _unvec(L, nothing, v) - U = L.op_inverse(V, L.p, L.t; L.traits.kwargs...) + W = L.op_inverse(V, L.u, L.p, L.t; L.traits.kwargs...) - vec_output ? vec(U) : U + vec_output ? 
vec(W) : W end -function LinearAlgebra.mul!(v::AbstractArray, L::FunctionOperator{true}, u::AbstractArray) - _sizecheck(L, u, v) - U, V, vec_output = _unvec(L, u, v) - - L.op(V, U, L.p, L.t; L.traits.kwargs...) - - vec_output ? vec(V) : V +function LinearAlgebra.mul!(w::AbstractArray, L::FunctionOperator{true}, v::AbstractArray) + _sizecheck(L, v, w) + V, W, vec_output = _unvec(L, v, w) + L.op(W, V, L.u, L.p, L.t; L.traits.kwargs...) + vec_output ? vec(W) : W end function LinearAlgebra.mul!(::AbstractArray, L::FunctionOperator{false}, ::AbstractArray, @@ -730,49 +755,49 @@ function LinearAlgebra.mul!(::AbstractArray, L::FunctionOperator{false}, ::Abstr @error "LinearAlgebra.mul! not defined for out-of-place operator $L" end -function LinearAlgebra.mul!(v::AbstractArray, L::FunctionOperator{true, oop, false}, - u::AbstractArray, α, β) where {oop} +function LinearAlgebra.mul!(w::AbstractArray, L::FunctionOperator{true, oop, false}, + v::AbstractArray, α, β) where {oop} _, Co = L.cache - _sizecheck(L, u, v) - U, V, _ = _unvec(L, u, v) + _sizecheck(L, v, w) + V, W, _ = _unvec(L, v, w) - copy!(Co, V) - L.op(V, U, L.p, L.t; L.traits.kwargs...) # mul!(V, L, U) - axpby!(β, Co, α, V) + copy!(Co, W) + L.op(W, V, L.u, L.p, L.t; L.traits.kwargs...) # mul!(V, L, U) + axpby!(β, Co, α, W) - v + w end -function LinearAlgebra.mul!(v::AbstractArray, L::FunctionOperator{true, oop, true}, - u::AbstractArray, α, β) where {oop} - _sizecheck(L, u, v) - U, V, _ = _unvec(L, u, v) +function LinearAlgebra.mul!(w::AbstractArray, L::FunctionOperator{true, oop, true}, + v::AbstractArray, α, β) where {oop} + _sizecheck(L, v, w) + V, W, _ = _unvec(L, v, w) - L.op(V, U, L.p, L.t, α, β; L.traits.kwargs...) + L.op(W, V, L.u, L.p, L.t, α, β; L.traits.kwargs...) 
- v + w end -function LinearAlgebra.ldiv!(u::AbstractArray, L::FunctionOperator{true}, v::AbstractArray) - _sizecheck(L, u, v) - U, V, _ = _unvec(L, u, v) +function LinearAlgebra.ldiv!(w::AbstractArray, L::FunctionOperator{true}, v::AbstractArray) + _sizecheck(L, v, w) + W, V, _ = _unvec(L, w, v) - L.op_inverse(U, V, L.p, L.t; L.traits.kwargs...) + L.op_inverse(W, V, L.u, L.p, L.t; L.traits.kwargs...) - u + w end -function LinearAlgebra.ldiv!(L::FunctionOperator{true}, u::AbstractArray) - V, _ = L.cache +function LinearAlgebra.ldiv!(L::FunctionOperator{true}, v::AbstractArray) + W, _ = L.cache - _sizecheck(L, u, V) - U, _, _ = _unvec(L, u, nothing) + _sizecheck(L, nothing, v) + V, _, vec_output = _unvec(L, v, nothing) - copy!(V, U) - L.op_inverse(U, V, L.p, L.t; L.traits.kwargs...) # ldiv!(U, L, V) + copy!(W, V) + L.op_inverse(W, V, L.u, L.p, L.t; L.traits.kwargs...) # ldiv!(U, L, V) - u + vec_output ? vec(W) : W end function LinearAlgebra.ldiv!(v::AbstractArray, L::FunctionOperator{false}, u::AbstractArray) @@ -782,4 +807,70 @@ end function LinearAlgebra.ldiv!(L::FunctionOperator{false}, u::AbstractArray) @error "LinearAlgebra.ldiv! not defined for out-of-place $L" end + +# Out-of-place: v is action vector, u is update vector +function (L::FunctionOperator)(v::AbstractArray, u, p, t; kwargs...) + L = update_coefficients(L, u, p, t; kwargs...) + _sizecheck(L, v, nothing) + V, _, vec_output = _unvec(L, v, nothing) + + # Apply the operator to action vector v after updating with u + if L.traits.outofplace + result = L.op(V, L.u, L.p, L.t; L.traits.kwargs...) + return vec_output ? vec(result) : result + else + # For operators without out-of-place methods, use their in-place methods with a temporary + Co = similar(V) + L.op(Co, V, L.u, L.p, L.t; L.traits.kwargs...) + return vec_output ? 
vec(Co) : Co + end + + v +end + +# In-place: w is destination, v is action vector, u is update vector +function (L::FunctionOperator)(w::AbstractArray, v::AbstractArray, u, p, t; kwargs...) + update_coefficients!(L, u, p, t; kwargs...) + + # Check dimensions + _sizecheck(L, v, w) + V, W, _ = _unvec(L, v, w) + + # Apply the operator in-place to action vector v after updating with u + if L.traits.isinplace + L.op(W, V, L.u, L.p, L.t; L.traits.kwargs...) + else + # For operators without in-place methods, use their out-of-place methods + result = L.op(V, L.u, L.p, L.t; L.traits.kwargs...) + copyto!(W, result) + end + + return w +end + +# In-place with scaling: w = α*(L*v) + β*w +function (L::FunctionOperator)(w::AbstractArray, v::AbstractArray, u, p, t, α, β; kwargs...) + update_coefficients!(L, u, p, t; kwargs...) + + # Check dimensions + _sizecheck(L, v, w) + V, W, _ = _unvec(L, v, w) + + # Apply the operator in-place to action vector v with scaling + if L.traits.isinplace && L.traits.has_mul5 + # Direct 5-arg mul! if supported + L.op(W, V, L.u, L.p, L.t, α, β; L.traits.kwargs...) + elseif L.traits.isinplace + # Use temporary for regular in-place + temp = copy(W) + L.op(W, V, L.u, L.p, L.t; L.traits.kwargs...) + axpby!(β, temp, α, W) + else + # Out-of-place with scaling + result = L.op(V, L.u, L.p, L.t; L.traits.kwargs...) 
+ axpby!(β, W, α, result) + end + + return w +end # diff --git a/src/interface.jl b/src/interface.jl index ab025fad..2aebd262 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -43,7 +43,7 @@ $(UPDATE_COEFFS_WARNING) # Example ``` -using SciMLOperator +using SciMLOperators mat_update_func = (A, u, p, t; scale = 1.0) -> p * p' * scale * t @@ -56,8 +56,12 @@ u = rand(4) p = rand(4) t = 1.0 +# Update the operator to `(u,p,t)` and apply it to `v` L = update_coefficients(L, u, p, t; scale = 2.0) -L * u +result = L * v + +# Or use the interface which separates the update from the application +result = L(v, u, p, t; scale = 2.0) ``` """ @@ -78,7 +82,7 @@ $(UPDATE_COEFFS_WARNING) # Example ``` -using SciMLOperator +using SciMLOperators _A = rand(4, 4) mat_update_func! = (L, u, p, t; scale = 1.0) -> copy!(A, _A) @@ -92,7 +96,7 @@ p = rand(4) t = 1.0 update_coefficients!(L, u, p, t) -L * u +L * v ``` """ @@ -109,18 +113,21 @@ end # operator evaluation interface ### -function (L::AbstractSciMLOperator)(u, p, t; kwargs...) - update_coefficients(L, u, p, t; kwargs...) * u +# Out-of-place: v is action vector, u is update vector +function (L::AbstractSciMLOperator)(v, u, p, t; kwargs...) + update_coefficients(L, u, p, t; kwargs...) * v end -function (L::AbstractSciMLOperator)(du, u, p, t; kwargs...) - (update_coefficients!(L, u, p, t; kwargs...); mul!(du, L, u)) +# In-place: w is destination, v is action vector, u is update vector +function (L::AbstractSciMLOperator)(w, v, u, p, t; kwargs...) + (update_coefficients!(L, u, p, t; kwargs...); mul!(w, L, v)) end -function (L::AbstractSciMLOperator)(du, u, p, t, α, β; kwargs...) - (update_coefficients!(L, u, p, t; kwargs...); mul!(du, L, u, α, β)) +# In-place with scaling: w = α*(L*v) + β*w +function (L::AbstractSciMLOperator)(w, v, u, p, t, α, β; kwargs...) + (update_coefficients!(L, u, p, t; kwargs...); mul!(w, L, v, α, β)) end -function (L::AbstractSciMLOperator)(du::Number, u::Number, p, t, args...; kwargs...) 
- msg = """Nonallocating L(v, u, p, t) type methods are not available for + msg = """Nonallocating L(w, v, u, p, t) type methods are not available for subtypes of `Number`.""" throw(ArgumentError(msg)) end @@ -193,14 +200,14 @@ has_adjoint(L::AbstractSciMLOperator) = false # L', adjoint(L) """ $SIGNATURES -Check if `expmv!(v, L, u, t)`, equivalent to `mul!(v, exp(t * A), u)`, is -defined for `Number` `t`, and `AbstractArray`s `u, v` of appropriate sizes. +Check if `expmv!(w, L, v, t)`, equivalent to `mul!(w, exp(t * A), v)`, is +defined for `Number` `t`, and `AbstractArray`s `w, v` of appropriate sizes. """ has_expmv!(L::AbstractSciMLOperator) = false # expmv!(v, L, t, u) """ $SIGNATURES -Check if `expmv(L, u, t)`, equivalent to `exp(t * A) * u`, is defined for +Check if `expmv(L, v, t)`, equivalent to `exp(t * A) * v`, is defined for `Number` `t`, and `AbstractArray` `u` of appropriate size. """ has_expmv(L::AbstractSciMLOperator) = false # v = exp(L, t, u) @@ -213,26 +220,26 @@ has_exp(L::AbstractSciMLOperator) = islinear(L) """ $SIGNATURES -Check if `L * u` is defined for `AbstractArray` `u` of appropriate size. +Check if `L * v` is defined for `AbstractArray` `v` of appropriate size. """ has_mul(L::AbstractSciMLOperator) = true # du = L*u """ $SIGNATURES -Check if `mul!(v, L, u)` is defined for `AbstractArray`s `u, v` of +Check if `mul!(w, L, v)` is defined for `AbstractArray`s `w, v` of appropriate sizes. """ has_mul!(L::AbstractSciMLOperator) = true # mul!(du, L, u) """ $SIGNATURES -Check if `L \\ u` is defined for `AbstractArray` `u` of appropriate size. +Check if `L \\ v` is defined for `AbstractArray` `v` of appropriate size. """ has_ldiv(L::AbstractSciMLOperator) = false # du = L\u """ $SIGNATURES -Check if `ldiv!(v, L, u)` is defined for `AbstractArray`s `u, v` of +Check if `ldiv!(w, L, v)` is defined for `AbstractArray`s `w, v` of appropriate sizes. 
""" has_ldiv!(L::AbstractSciMLOperator) = false # ldiv!(du, L, u) diff --git a/src/left.jl b/src/left.jl index b176e0d4..90055fa4 100644 --- a/src/left.jl +++ b/src/left.jl @@ -157,4 +157,56 @@ for (op, LType, VType) in ((:adjoint, :AdjointOperator, :AbstractAdjointVecOrMat u end end + + +# For AdjointOperator +# Out-of-place: v is action vector, u is update vector +function (L::AdjointOperator)(v::AbstractVecOrMat, u, p, t; kwargs...) + # Adjoint operator applied to v means L.L' * v + # For matrices: (A')v = (v'A)' + # This means we need to compute L.L(v', u, p, t)' + # Update the operator first, then apply adjoint operator + L_updated = update_coefficients(L.L, u, p, t; kwargs...) + # (A')v = (v'A)' where v'A is computed by A'*v' + return (L_updated' * v')' +end + +# In-place: w is destination, v is action vector, u is update vector +function (L::AdjointOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...) + # Update the operator in-place + update_coefficients!(L.L, u, p, t; kwargs...) + # Use direct in-place multiplication for adjoints + mul!(w', v', L.L) + return w +end + +# In-place with scaling: w = α*(L*v) + β*w +function (L::AdjointOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...) + # Update the operator in-place + update_coefficients!(L.L, u, p ,t; kwargs...) + mul!(w', v', L.L, α, β) + return w +end + +# For TransposedOperator +# Out-of-place +function (L::TransposedOperator)(v::AbstractVecOrMat, u, p, t; kwargs...) + L_updated = update_coefficients(L.L, u, p, t; kwargs...) + # (A^T)v = (v'A)' where v'A is computed by A'*v' + return (L_updated' * v')' +end + +# In-place +function (L::TransposedOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...) + update_coefficients!(L.L, u, p, t; kwargs...) + mul!(w', v', L.L) + return w +end + +# In-place with scaling +function (L::TransposedOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...) 
+ update_coefficients!(L.L, u, p, t; kwargs...) + mul!(w', v', L.L, α, β) + return w +end # diff --git a/src/matrix.jl b/src/matrix.jl index 6fcc10c6..a04c65b5 100644 --- a/src/matrix.jl +++ b/src/matrix.jl @@ -4,7 +4,7 @@ $SIGNATURES Represents a linear operator given by an `AbstractMatrix` that may be applied to an `AbstractVecOrMat`. Its state is updated by the user-provided -`update_func` during operator evaluation (`L([v,], u, p, t)`), or by calls +`update_func` during operator evaluation (`L([w,], v, u, p, t)`), or by calls to `update_coefficients[!](L, u, p, t)`. Both recursively call the `update_function`, `update_func` which is assumed to have the signature @@ -30,33 +30,48 @@ adjoints, transposes. Out-of-place update and usage ``` +v = rand(4) u = rand(4) p = rand(4, 4) t = rand() mat_update = (A, u, p, t; scale = 0.0) -> t * p -M = MatrixOperator(0.0; update_func = mat_update; accepted_kwargs = (:scale,)) +M = MatrixOperator(0.0; update_func = mat_update, accepted_kwargs = (:scale,)) L = M * M + 3I -L = cache_operator(M, u) +L = cache_operator(L, v) -# update L and evaluate -v = L(u, p, t; scale = 1.0) +# update and evaluate +w = L(v, u, p, t; scale = 1.0) + +# In-place evaluation +w = similar(v) +L(w, v, u, p, t; scale = 1.0) + +# In-place with scaling +β = 0.5 +L(w, v, u, p, t, 2.0, β; scale = 1.0) # w = 2.0*(L*v) + 0.5*w ``` In-place update and usage ``` -v = zero(4) +w = zeros(4) +v = zeros(4) u = rand(4) -p = nothing +p = rand(4) # Must be non-nothing t = rand() -mat_update! = (A, u, p, t; scale = 0.0) -> (copy!(A, p); lmul!(t, A)) -M = MatrixOperator(zeros(4, 4); update_func! = val_update!; accepted_kwargs = (:scale,)) +mat_update! = (A, u, p, t; scale = 0.0) -> (A .= t * p * u' * scale) +M = MatrixOperator(zeros(4, 4); update_func! 
= mat_update!, accepted_kwargs = (:scale,)) L = M * M + 3I +L = cache_operator(L, v) # update L in-place and evaluate -L(v, u, p, t; scale = 1.0) +update_coefficients!(L, u, p, t; scale = 1.0) +mul!(w, L, v) + +# Or use the new interface that separates update and application +L(w, v, u, p, t; scale = 1.0) ``` """ struct MatrixOperator{T, AT <: AbstractMatrix{T}, F, F!} <: AbstractSciMLOperator{T} @@ -149,7 +164,7 @@ function Base.conj(L::MatrixOperator) accepted_kwargs = NoKwargFilter()) end -has_adjoint(A::MatrixOperator) = has_adjoint(A.A) +has_adjoint(L::MatrixOperator) = has_adjoint(L.A) getops(L::MatrixOperator) = (L.A,) function isconstant(L::MatrixOperator) update_func_isconstant(L.update_func) & update_func_isconstant(L.update_func!) @@ -165,6 +180,25 @@ function update_coefficients!(L::MatrixOperator, u, p, t; kwargs...) nothing end +# Out-of-place: v is action vector, u is update vector +function (L::MatrixOperator)(v::AbstractVecOrMat, u, p, t; kwargs...) + L = update_coefficients(L, u, p, t; kwargs...) + L.A * v +end + +# In-place: w is destination, v is action vector, u is update vector +function (L::MatrixOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...) + update_coefficients!(L, u, p, t; kwargs...) + mul!(w, L.A, v) +end + +# In-place with scaling: w = α*(L*v) + β*w +function (L::MatrixOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...) + update_coefficients!(L, u, p, t; kwargs...) 
+ mul!(w, L.A, v, α, β) +end + + # TODO - add tests for MatrixOperator indexing # propagate_inbounds here for the getindex fallback Base.@propagate_inbounds Base.convert(::Type{AbstractMatrix}, L::MatrixOperator) = convert( @@ -194,23 +228,23 @@ function Base.copy(L::MatrixOperator) end # operator application -Base.:*(L::MatrixOperator, u::AbstractVecOrMat) = L.A * u -Base.:\(L::MatrixOperator, u::AbstractVecOrMat) = L.A \ u +Base.:*(L::MatrixOperator, v::AbstractVecOrMat) = L.A * v +Base.:\(L::MatrixOperator, v::AbstractVecOrMat) = L.A \ v @inline function LinearAlgebra.mul!( - v::AbstractVecOrMat, L::MatrixOperator, u::AbstractVecOrMat) - mul!(v, L.A, u) + w::AbstractVecOrMat, L::MatrixOperator, v::AbstractVecOrMat) + mul!(w, L.A, v) end -@inline function LinearAlgebra.mul!(v::AbstractVecOrMat, +@inline function LinearAlgebra.mul!(w::AbstractVecOrMat, L::MatrixOperator, - u::AbstractVecOrMat, + v::AbstractVecOrMat, α, β) - mul!(v, L.A, u, α, β) + mul!(w, L.A, v, α, β) end -function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::MatrixOperator, u::AbstractVecOrMat) - ldiv!(v, L.A, u) +function LinearAlgebra.ldiv!(w::AbstractVecOrMat, L::MatrixOperator, v::AbstractVecOrMat) + ldiv!(w, L.A, v) end -LinearAlgebra.ldiv!(L::MatrixOperator, u::AbstractVecOrMat) = ldiv!(L.A, u) +LinearAlgebra.ldiv!(L::MatrixOperator, v::AbstractVecOrMat) = ldiv!(L.A, v) """ $SIGNATURES @@ -218,16 +252,16 @@ $SIGNATURES Represents an elementwise scaling (diagonal-scaling) operation that may be applied to an `AbstractVecOrMat`. When `diag` is an `AbstractVector` of length N, `L = DiagonalOperator(diag, ...)` can be applied to -`AbstractArray`s with `size(u, 1) == N`. Each column of the `u` will be -scaled by `diag`, as in `LinearAlgebra.Diagonal(diag) * u`. +`AbstractArray`s with `size(u, 1) == N`. Each column of the `v` will be +scaled by `diag`, as in `LinearAlgebra.Diagonal(diag) * v`. 
When `diag` is a multidimensional array, `L = DiagonalOperator(diag, ...)` forms an operator of size `(N, N)` where `N = size(diag, 1)` is the leading length of `diag`. -`L` then is the elementwise-scaling operation on arrays of `length(u) = length(diag)` +`L` then is the elementwise-scaling operation on arrays of `length(v) = length(diag)` with leading length `size(u, 1) = N`. Its state is updated by the user-provided `update_func` during operator -evaluation (`L([v,], u, p, t)`), or by calls to +evaluation (`L([w,], v, u, p, t)`), or by calls to `update_coefficients[!](L, u, p, t)`. Both recursively call the `update_function`, `update_func` which is assumed to have the signature @@ -341,6 +375,11 @@ function update_coefficients(L::InvertibleOperator, u, p, t) @reset L.F = update_coefficients(L.F, u, p, t) L end +function update_coefficients!(L::InvertibleOperator, u, p, t; kwargs...) + update_coefficients!(L.L, u, p, t; kwargs...) + update_coefficients!(L.F, u, p, t; kwargs...) + nothing +end getops(L::InvertibleOperator) = (L.L, L.F) islinear(L::InvertibleOperator) = islinear(L.L) @@ -361,40 +400,58 @@ isconvertible(L::InvertibleOperator) = isconvertible(L.L) has_ldiv(L::InvertibleOperator) = has_mul(L.F) has_ldiv!(L::InvertibleOperator) = has_ldiv!(L.F) -function cache_internals(L::InvertibleOperator, u::AbstractVecOrMat) - @reset L.L = cache_operator(L.L, u) - @reset L.F = cache_operator(L.F, u) +function cache_internals(L::InvertibleOperator, v::AbstractVecOrMat) + @reset L.L = cache_operator(L.L, v) + @reset L.F = cache_operator(L.F, v) L end # operator application -Base.:*(L::InvertibleOperator, x::AbstractVecOrMat) = L.L * x -Base.:\(L::InvertibleOperator, x::AbstractVecOrMat) = L.F \ x -function LinearAlgebra.mul!(v::AbstractVecOrMat, L::InvertibleOperator, u::AbstractVecOrMat) - mul!(v, L.L, u) +Base.:*(L::InvertibleOperator, v::AbstractVecOrMat) = L.L * v +Base.:\(L::InvertibleOperator, v::AbstractVecOrMat) = L.F \ v +function 
LinearAlgebra.mul!(w::AbstractVecOrMat, L::InvertibleOperator, v::AbstractVecOrMat) + mul!(w, L.L, v) end -function LinearAlgebra.mul!(v::AbstractVecOrMat, +function LinearAlgebra.mul!(w::AbstractVecOrMat, L::InvertibleOperator, - u::AbstractVecOrMat, + v::AbstractVecOrMat, α, β) - mul!(v, L.L, u, α, β) + mul!(w, L.L, v, α, β) end -function LinearAlgebra.ldiv!(v::AbstractVecOrMat, +function LinearAlgebra.ldiv!(w::AbstractVecOrMat, L::InvertibleOperator, - u::AbstractVecOrMat) - ldiv!(v, L.F, u) + v::AbstractVecOrMat) + ldiv!(w, L.F, v) end LinearAlgebra.ldiv!(L::InvertibleOperator, u::AbstractVecOrMat) = ldiv!(L.F, u) +# Out-of-place: v is action vector, u is update vector +function (L::InvertibleOperator)(v::AbstractVecOrMat, u, p, t; kwargs...) + L = update_coefficients(L, u, p, t; kwargs...) + L.L * v +end + +# In-place: w is destination, v is action vector, u is update vector +function (L::InvertibleOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...) + update_coefficients!(L, u, p, t; kwargs...) + mul!(w, L.L, v) +end + +# In-place with scaling: w = α*(L*v) + β*w +function (L::InvertibleOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...) + update_coefficients!(L, u, p, t; kwargs...) + mul!(w, L.L, v, α, β) +end + """ $SIGNATURES -Represents a generalized affine operation (`v = A * u + B * b`) that may +Represents a generalized affine operation (`w = A * v + B * b`) that may be applied to an `AbstractVecOrMat`. The user-provided update functions, `update_func[!]` update the `AbstractVecOrMat` `b`, and are called -during operator evaluation (`L([v,], u, p, t)`), or by calls +during operator evaluation (`L([w,], v, u, p, t)`), or by calls to `update_coefficients[!](L, u, p, t)`. The update functions are assumed to have the syntax @@ -404,8 +461,8 @@ or update_func!(b::AbstractVecOrMat, u ,p , t; ) -> [modifies b] and `B`, `b` are expected to have an appropriate size so that -`A * u + B * b` makes sense. 
Specifically, `size(A, 1) == size(B, 1)`, and -`size(u, 2) == size(b, 2)`. +`A * v + B * b` makes sense. Specifically, `size(A, 1) == size(B, 1)`, and +`size(v, 2) == size(b, 2)`. The set of keyword-arguments accepted by `update_func[!]` must be provided to `AffineOperator` via the kwarg `accepted_kwargs` as a tuple of `Symbol`s. @@ -415,6 +472,7 @@ are not provided. # Example ``` +v = rand(4) u = rand(4) p = rand(4) t = rand() @@ -422,12 +480,12 @@ t = rand() A = MatrixOperator(rand(4, 4)) B = MatrixOperator(rand(4, 4)) -vec_update_func = (b, u, p, t) -> p * t +vec_update_func = (b, u, p, t) -> p .* u * t L = AffineOperator(A, B, zero(4); update_func = vec_update_func) -L = cache_operator(M, u) +L = cache_operator(L, v) # update L and evaluate -v = L(u, p, t) # == A * u + B * (p * t) +w = L(v, u, p, t) # == A * v + B * (p .* u * t) ``` """ @@ -477,7 +535,7 @@ end """ $SIGNATURES -Represents the affine operation `v = I * u + I * b`. The update functions, +Represents the affine operation `w = I * v + I * b`. The update functions, `update_func[!]` update the state of `AbstractVecOrMat ` `b`. See documentation of `AffineOperator` for more details. """ @@ -497,7 +555,7 @@ end """ $SIGNATURES -Represents the affine operation `v = I * u + B * b`. The update functions, +Represents the affine operation `w = I * v + B * b`. The update functions, `update_func[!]` update the state of `AbstractVecOrMat ` `b`. See documentation of `AffineOperator` for more details. 
""" @@ -583,37 +641,62 @@ function cache_internals(L::AffineOperator, u::AbstractVecOrMat) end # operator application -function Base.:*(L::AffineOperator, u::AbstractVecOrMat) - @assert size(L.b, 2) == size(u, 2) - (L.A * u) + (L.B * L.b) +function Base.:*(L::AffineOperator, v::AbstractVecOrMat) + @assert size(L.b, 2) == size(v, 2) + (L.A * v) + (L.B * L.b) end -function Base.:\(L::AffineOperator, u::AbstractVecOrMat) - @assert size(L.b, 2) == size(u, 2) - L.A \ (u - (L.B * L.b)) +function Base.:\(L::AffineOperator, v::AbstractVecOrMat) + @assert size(L.b, 2) == size(v, 2) + L.A \ (v - (L.B * L.b)) end -function LinearAlgebra.mul!(v::AbstractVecOrMat, L::AffineOperator, u::AbstractVecOrMat) - mul!(v, L.B, L.b) - mul!(v, L.A, u, true, true) +function LinearAlgebra.mul!(w::AbstractVecOrMat, L::AffineOperator, v::AbstractVecOrMat) + mul!(w, L.B, L.b) + mul!(w, L.A, v, true, true) end -function LinearAlgebra.mul!(v::AbstractVecOrMat, +function LinearAlgebra.mul!(w::AbstractVecOrMat, L::AffineOperator, - u::AbstractVecOrMat, + v::AbstractVecOrMat, α, β) - mul!(v, L.B, L.b, α, β) - mul!(v, L.A, u, α, true) + mul!(w, L.B, L.b, α, β) + mul!(w, L.A, v, α, true) +end + +function LinearAlgebra.ldiv!(w::AbstractVecOrMat, L::AffineOperator, v::AbstractVecOrMat) + copy!(w, v) + ldiv!(L, w) +end + +function LinearAlgebra.ldiv!(L::AffineOperator, v::AbstractVecOrMat) + mul!(v, L.B, L.b, -1, 1) + ldiv!(L.A, v) +end +# Out-of-place: v is action vector, u is update vector +function (L::AffineOperator)(v::AbstractVecOrMat, u, p, t; kwargs...) + L = update_coefficients(L, u, p, t; kwargs...) + (L.A * v) + (L.B * L.b) end -function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::AffineOperator, u::AbstractVecOrMat) - copy!(v, u) - ldiv!(L, v) +# In-place: w is destination, v is action vector, u is update vector +function (L::AffineOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...) + update_coefficients!(L, u, p, t; kwargs...) 
+ # First calculate A * v + mul!(w, L.A, v) + # Then add B * b + mul!(w, L.B, L.b, true, true) end -function LinearAlgebra.ldiv!(L::AffineOperator, u::AbstractVecOrMat) - mul!(u, L.B, L.b, -1, 1) - ldiv!(L.A, u) +# In-place with scaling: w = α*(L*v) + β*w +function (L::AffineOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...) + update_coefficients!(L, u, p, t; kwargs...) + # Scale the existing w by β + lmul!(β, w) + # Add α * (A * v) + mul!(w, L.A, v, α, true) + # Add α * (B * b) + mul!(w, L.B, L.b, α, true) end # diff --git a/src/scalar.jl b/src/scalar.jl index c228b946..b453e501 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -3,9 +3,9 @@ # AbstractSciMLScalarOperator interface ### -function (L::AbstractSciMLScalarOperator)(u::Number, p, t; kwargs...) +function (L::AbstractSciMLScalarOperator)(v::Number, u, p, t; kwargs...) L = update_coefficients(L, u, p, t; kwargs...) - convert(Number, L) * u + convert(Number, L) * v end SCALINGNUMBERTYPES = (:AbstractSciMLScalarOperator, @@ -112,7 +112,7 @@ $SIGNATURES Represents a linear scaling operator that may be applied to a `Number`, or an `AbstractArray` subtype. Its state is updated by the user-provided -`update_func` during operator evaluation (`L([v,] u, p, t)`), or by +`update_func` during operator evaluation (`L([w,] v, u, p, t)`), or by calls to `update_coefficients[!]`. Both recursively call the update function, `update_func` which is assumed to have the signature: @@ -133,20 +133,27 @@ interface supports lazy addition, subtraction, multiplication and division. 
# Example ``` -v = zero(4) +v = rand(4) u = rand(4) +w = zeros(4) p = nothing t = 0.0 -val_update = (a, u, p, t; scale = 0.0) -> copy(scale) -α = ScalarOperator(0.0; update_func = val_update; accepted_kwargs = (:scale,)) +val_update = (a, u, p, t; scale = 0.0) -> scale +α = ScalarOperator(0.0; update_func = val_update, accepted_kwargs = (:scale,)) β = 2 * α + 3 / α -# update L out-of-place, and evaluate -β(u, p, t; scale = 1.0) +# Update β and evaluate with the new interface +result = β(v, u, p, t; scale = 1.0) -# update L in-place and evaluate -β(v, u, p, t; scale = 1.0) +# In-place application +β(w, v, u, p, t; scale = 1.0) + +# In-place with scaling +w_orig = copy(w) +α_val = 2.0 +β_val = 0.5 +β(w, v, u, p, t, α_val, β_val; scale = 1.0) # w = α_val*(β*v) + β_val*w ``` """ function ScalarOperator(val; @@ -198,6 +205,24 @@ function SciMLOperators.update_coefficients(L::ScalarOperator, u, p, t; kwargs.. return ScalarOperator(L.update_func(L.val, u, p, t; kwargs...), L.update_func) end + +# Add ScalarOperator specific implementations for the new interface +function (α::ScalarOperator)(v::AbstractArray, u, p, t; kwargs...) + α = update_coefficients(α, u, p, t; kwargs...) + convert(Number, α) * v +end + +function (α::ScalarOperator)(w::AbstractArray, v::AbstractArray, u, p, t; kwargs...) + update_coefficients!(α, u, p, t; kwargs...) + mul!(w, α, v) +end + +function (α::ScalarOperator)(w::AbstractArray, v::AbstractArray, u, p, t, a, b; kwargs...) + update_coefficients!(α, u, p, t; kwargs...) + mul!(w, α, v, a, b) +end + + """ $TYPEDEF @@ -257,15 +282,37 @@ function Base.show(io::IO, α::AddedScalarOperator) end Base.conj(L::AddedScalarOperator) = AddedScalarOperator(conj.(L.ops)) -function update_coefficients(L::AddedScalarOperator, u, p, t) +function update_coefficients(L::AddedScalarOperator, u, p, t; kwargs...) 
ops = () for op in L.ops - ops = (ops..., update_coefficients(op, u, p, t)) + ops = (ops..., update_coefficients(op, u, p, t; kwargs...)) end - @reset L.ops = ops + AddedScalarOperator(ops) end +function update_coefficients!(L::AddedScalarOperator, u, p, t; kwargs...) + for op in L.ops + update_coefficients!(op, u, p, t; kwargs...) + end + nothing +end +function (α::AddedScalarOperator)(v::AbstractArray, u, p, t; kwargs...) + α = update_coefficients(α, u, p, t; kwargs...) + convert(Number, α) * v +end + +function (α::AddedScalarOperator)(w::AbstractArray, v::AbstractArray, u, p, t; kwargs...) + update_coefficients!(α, u, p, t; kwargs...) + mul!(w, α, v) +end + +function (α::AddedScalarOperator)(w::AbstractArray, v::AbstractArray, u, p, t, a, b; kwargs...) + update_coefficients!(α, u, p, t; kwargs...) + mul!(w, α, v, a, b) +end + + getops(α::AddedScalarOperator) = α.ops has_ldiv(α::AddedScalarOperator) = !iszero(convert(Number, α)) has_ldiv!(α::AddedScalarOperator) = has_ldiv(α) @@ -351,13 +398,33 @@ end Base.conj(L::ComposedScalarOperator) = ComposedScalarOperator(conj.(L.ops)) Base.:-(α::AbstractSciMLScalarOperator{T}) where {T} = (-one(T)) * α -function update_coefficients(L::ComposedScalarOperator, u, p, t) +function update_coefficients(L::ComposedScalarOperator, u, p, t; kwargs...) ops = () for op in L.ops - ops = (ops..., update_coefficients(op, u, p, t)) + ops = (ops..., update_coefficients(op, u, p, t; kwargs...)) end - @reset L.ops = ops + ComposedScalarOperator(ops) +end +function update_coefficients!(L::ComposedScalarOperator, u, p, t; kwargs...) + for op in L.ops + update_coefficients!(op, u, p, t; kwargs...) + end + nothing +end +function (α::ComposedScalarOperator)(v::AbstractArray, u, p, t; kwargs...) + α = update_coefficients(α, u, p, t; kwargs...) + convert(Number, α) * v +end + +function (α::ComposedScalarOperator)(w::AbstractArray, v::AbstractArray, u, p, t; kwargs...) + update_coefficients!(α, u, p, t; kwargs...) 
+ mul!(w, α, v) +end + +function (α::ComposedScalarOperator)(w::AbstractArray, v::AbstractArray, u, p, t, a, b; kwargs...) + update_coefficients!(α, u, p, t; kwargs...) + mul!(w, α, v, a, b) end getops(α::ComposedScalarOperator) = α.ops @@ -413,11 +480,27 @@ function Base.show(io::IO, α::InvertedScalarOperator) end Base.conj(L::InvertedScalarOperator) = InvertedScalarOperator(conj(L.λ)) -function update_coefficients(L::InvertedScalarOperator, u, p, t) - @reset L.λ = update_coefficients(L.λ, u, p, t) - L +function update_coefficients(L::InvertedScalarOperator, u, p, t; kwargs...) + InvertedScalarOperator(update_coefficients(L.λ, u, p, t; kwargs...)) +end +function update_coefficients!(L::InvertedScalarOperator, u, p, t; kwargs...) + update_coefficients!(L.λ, u, p, t; kwargs...) + nothing +end +function (α::InvertedScalarOperator)(v::AbstractArray, u, p, t; kwargs...) + α = update_coefficients(α, u, p, t; kwargs...) + convert(Number, α) * v +end + +function (α::InvertedScalarOperator)(w::AbstractArray, v::AbstractArray, u, p, t; kwargs...) + update_coefficients!(α, u, p, t; kwargs...) + mul!(w, α, v) end +function (α::InvertedScalarOperator)(w::AbstractArray, v::AbstractArray, u, p, t, a, b; kwargs...) + update_coefficients!(α, u, p, t; kwargs...) + mul!(w, α, v, a, b) +end getops(α::InvertedScalarOperator) = (α.λ,) has_ldiv(α::InvertedScalarOperator) = has_mul(α.λ) has_ldiv!(α::InvertedScalarOperator) = has_ldiv(α) diff --git a/src/tensor.jl b/src/tensor.jl index 8ca4bc05..82978dba 100644 --- a/src/tensor.jl +++ b/src/tensor.jl @@ -10,9 +10,46 @@ product operator. 
TensorProductOperator(A, B) = A ⊗ B TensorProductOperator(A, B, C) = A ⊗ B ⊗ C -(A ⊗ B)(u) = vec(B * reshape(u, M, N) * transpose(A)) +(A ⊗ B)(v) = vec(B * reshape(v, M, N) * transpose(A)) ``` where `M = size(B, 2)`, and `N = size(A, 2)` + +# Example + +``` +using SciMLOperators, LinearAlgebra + +# Create basic operators +A = rand(3, 3) +B = rand(4, 4) +A_op = MatrixOperator(A) +B_op = MatrixOperator(B) + +# Create tensor product operator +T = A_op ⊗ B_op + +# Apply to a vector using the new interface +v = rand(3*4) # Action vector +u = rand(3*4) # Update vector +p = nothing +t = 0.0 + +# Out-of-place application +result = T(v, u, p, t) + +# For in-place operations, need to cache the operator first +T_cached = cache_operator(T, v) + +# In-place application +w = zeros(size(T, 1)) +T_cached(w, v, u, p, t) + +# In-place with scaling +w_orig = copy(w) +α = 2.0 +β = 0.5 +T_cached(w, v, u, p, t, α, β) # w = α*(T*v) + β*w_orig +``` """ """ @@ -126,88 +163,88 @@ has_ldiv!(L::TensorProductOperator) = reduce(&, has_ldiv!.(L.ops)) factorize(L::TensorProductOperator) = TensorProductOperator(factorize.(L.ops)...) # operator application -function Base.:*(L::TensorProductOperator, u::AbstractVecOrMat) +function Base.:*(L::TensorProductOperator, v::AbstractVecOrMat) outer, inner = L.ops _, ni = size(inner) _, no = size(outer) m, n = size(L) - k = size(u, 2) + k = size(v, 2) - U = reshape(u, (ni, no * k)) + U = reshape(v, (ni, no * k)) C = inner * U - V = outer_mul(L, u, C) + V = outer_mul(L, v, C) - u isa AbstractMatrix ? reshape(V, (m, k)) : reshape(V, (m,)) + v isa AbstractMatrix ? 
reshape(V, (m, k)) : reshape(V, (m,)) end -function Base.:\(L::TensorProductOperator, u::AbstractVecOrMat) +function Base.:\(L::TensorProductOperator, v::AbstractVecOrMat) outer, inner = L.ops mi, _ = size(inner) mo, _ = size(outer) m, n = size(L) - k = size(u, 2) + k = size(v, 2) - U = reshape(u, (mi, mo * k)) + U = reshape(v, (mi, mo * k)) C = inner \ U - V = outer_div(L, u, C) + V = outer_div(L, v, C) - u isa AbstractMatrix ? reshape(V, (n, k)) : reshape(V, (n,)) + v isa AbstractMatrix ? reshape(V, (n, k)) : reshape(V, (n,)) end -function cache_self(L::TensorProductOperator, u::AbstractVecOrMat) +function cache_self(L::TensorProductOperator, v::AbstractVecOrMat) outer, inner = L.ops mi, ni = size(inner) mo, no = size(outer) - k = size(u, 2) + k = size(v, 2) # 3 arg mul! - c1 = lmul!(false, similar(u, (mi, no * k))) # c1 = inner * u - c2 = lmul!(false, similar(u, (no, mi, k))) # permut (2, 1, 3) - c3 = lmul!(false, similar(u, (mo, mi * k))) # c3 = outer * c2 + c1 = lmul!(false, similar(v, (mi, no * k))) # c1 = inner * v + c2 = lmul!(false, similar(v, (no, mi, k))) # permut (2, 1, 3) + c3 = lmul!(false, similar(v, (mo, mi * k))) # c3 = outer * c2 # 5 arg mul! - c4 = lmul!(false, similar(u, (mo * mi, k))) # cache v in 5 arg mul! + c4 = lmul!(false, similar(v, (mo * mi, k))) # cache v in 5 arg mul! # 3 arg ldiv! 
if reduce(&, issquare.(L.ops)) c5, c6, c7 = c1, c2, c3 else - c5 = lmul!(false, similar(u, (ni, mo * k))) # c5 = inner \ u - c6 = lmul!(false, similar(u, (mo, ni, k))) # permut (2, 1, 3) - c7 = lmul!(false, similar(u, (no, ni * k))) # c7 = outer \ c6 + c5 = lmul!(false, similar(v, (ni, mo * k))) # c5 = inner \ v + c6 = lmul!(false, similar(v, (mo, ni, k))) # permut (2, 1, 3) + c7 = lmul!(false, similar(v, (no, ni * k))) # c7 = outer \ c6 end @reset L.cache = (c1, c2, c3, c4, c5, c6, c7) L end -function cache_internals(L::TensorProductOperator, u::AbstractVecOrMat) +function cache_internals(L::TensorProductOperator, v::AbstractVecOrMat) if !iscached(L) - L = cache_self(L, u) + L = cache_self(L, v) end outer, inner = L.ops mi, ni = size(inner) _, no = size(outer) - k = size(u, 2) + k = size(v, 2) - uinner = reshape(u, (ni, no * k)) - uouter = reshape(L.cache[2], (no, mi * k)) + vinner = reshape(v, (ni, no * k)) + vouter = reshape(L.cache[2], (no, mi * k)) - @reset L.ops[2] = cache_operator(inner, uinner) - @reset L.ops[1] = cache_operator(outer, uouter) + @reset L.ops[2] = cache_operator(inner, vinner) + @reset L.ops[1] = cache_operator(outer, vouter) L end -function LinearAlgebra.mul!(v::AbstractVecOrMat, +function LinearAlgebra.mul!(w::AbstractVecOrMat, L::TensorProductOperator, - u::AbstractVecOrMat) + v::AbstractVecOrMat) @assert iscached(L) """cache needs to be set up for operator of type $L. 
Set up cache by calling `cache_operator(L, u)`""" @@ -215,28 +252,28 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, _, ni = size(inner) _, no = size(outer) - k = size(u, 2) + k = size(v, 2) C1, C2, C3 = L.cache[1:3] - U = reshape(u, (ni, no * k)) + U = reshape(v, (ni, no * k)) - """ - v .= kron(B, A) * u + #= + v .= kron(B, A) * v V .= A * U * B' - """ + =# # C .= A * U mul!(C1, inner, U) # V .= U * B' <===> V' .= B * C' - outer_mul!(v, L, u) + outer_mul!(w, L, v) - v + w end -function LinearAlgebra.mul!(v::AbstractVecOrMat, +function LinearAlgebra.mul!(w::AbstractVecOrMat, L::TensorProductOperator, - u::AbstractVecOrMat, + v::AbstractVecOrMat, α, β) @assert iscached(L) """cache needs to be set up for operator of type @@ -246,10 +283,10 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, mi, ni = size(inner) mo, no = size(outer) - k = size(u, 2) + k = size(v, 2) C1 = first(L.cache) - U = reshape(u, (ni, no * k)) + U = reshape(v, (ni, no * k)) """ v .= α * kron(B, A) * u + β * v @@ -261,14 +298,14 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, # V = α(C * B') + β(V) c = reshape(C1, (mi * no, k)) - outer_mul!(v, L, c, α, β) + outer_mul!(w, L, c, α, β) - v + w end -function LinearAlgebra.ldiv!(v::AbstractVecOrMat, +function LinearAlgebra.ldiv!(w::AbstractVecOrMat, L::TensorProductOperator, - u::AbstractVecOrMat) + v::AbstractVecOrMat) @assert iscached(L) """cache needs to be set up for operator of type $L. 
Set up cache by calling `cache_operator(L, u)`""" @@ -276,13 +313,13 @@ function LinearAlgebra.ldiv!(v::AbstractVecOrMat, mi, ni = size(inner) mo, no = size(outer) - k = size(u, 2) + k = size(v, 2) C5 = L.cache[5] - U = reshape(u, (mi, mo * k)) + U = reshape(v, (mi, mo * k)) """ - v .= kron(B, A) ldiv u + v .= kron(B, A) ldiv v V .= (A ldiv U) / B' """ @@ -291,12 +328,12 @@ function LinearAlgebra.ldiv!(v::AbstractVecOrMat, # V .= C / B' <==> V' .= B \ C' c = reshape(C5, (ni * mo, k)) - outer_div!(v, L, c) + outer_div!(w, L, c) - v + w end -function LinearAlgebra.ldiv!(L::TensorProductOperator, u::AbstractVecOrMat) +function LinearAlgebra.ldiv!(L::TensorProductOperator, v::AbstractVecOrMat) outer, inner = L.ops msg = "Two-argument ldiv! is only available for square operators" @@ -305,13 +342,13 @@ function LinearAlgebra.ldiv!(L::TensorProductOperator, u::AbstractVecOrMat) @assert issquare(outer) msg @assert iscached(L) """cache needs to be set up for operator of type - $L. Set up cache by calling `cache_operator(L, u)`""" + $L. 
Set up cache by calling `cache_operator(L, v)`""" mi = size(inner, 1) mo = size(outer, 1) - k = size(u, 2) + k = size(v, 2) - U = reshape(u, (mi, mo * k)) + U = reshape(v, (mi, mo * k)) """ u .= kron(B, A) ldiv u @@ -322,24 +359,24 @@ function LinearAlgebra.ldiv!(L::TensorProductOperator, u::AbstractVecOrMat) ldiv!(inner, U) # U .= U / B' <==> U' .= B \ U' - outer_div!(L, u) + outer_div!(L, v) - u + v end # helper functions const PERM = (2, 1, 3) -function outer_mul(L::TensorProductOperator, u::AbstractVecOrMat, C::AbstractVecOrMat) +function outer_mul(L::TensorProductOperator, v::AbstractVecOrMat, C::AbstractVecOrMat) outer, inner = L.ops if outer isa IdentityOperator return C elseif outer isa ScaledOperator - return outer.λ * outer_mul(outer.L, u, C) + return outer.λ * outer_mul(outer.L, v, C) end - k = size(u, 2) + k = size(v, 2) if k == 1 return transpose(outer * transpose(C)) end @@ -359,30 +396,30 @@ function outer_mul(L::TensorProductOperator, u::AbstractVecOrMat, C::AbstractVec V end -function outer_mul!(v::AbstractVecOrMat, L::TensorProductOperator, u::AbstractVecOrMat) +function outer_mul!(w::AbstractVecOrMat, L::TensorProductOperator, v::AbstractVecOrMat) outer, inner = L.ops C1 = first(L.cache) if outer isa IdentityOperator - copyto!(v, C1) - return v + copyto!(w, C1) + return w elseif outer isa ScaledOperator - outer_mul!(v, outer.L, u) - lmul!(outer.λ, v) - return v + outer_mul!(w, outer.L, v) + lmul!(outer.λ, w) + return w end mi, _ = size(inner) mo, no = size(outer) # m , n = size(L) - k = size(u, 2) + k = size(v, 2) if k == 1 - V = reshape(v, (mi, mo)) + W = reshape(w, (mi, mo)) C1 = reshape(C1, (mi, no)) - mul!(transpose(V), outer, transpose(C1)) - return v + mul!(transpose(W), outer, transpose(C1)) + return w end C2, C3 = L.cache[2:3] @@ -392,64 +429,64 @@ function outer_mul!(v::AbstractVecOrMat, L::TensorProductOperator, u::AbstractVe C2 = reshape(C2, (no, mi * k)) mul!(C3, outer, C2) C3 = reshape(C3, (mo, mi, k)) - V = reshape(v, (mi, mo, k)) - 
permutedims!(V, C3, PERM) + W = reshape(w, (mi, mo, k)) + permutedims!(W, C3, PERM) - v + w end -function outer_mul!(v::AbstractVecOrMat, L::TensorProductOperator, - c::AbstractVecOrMat, α, β) +function outer_mul!(w::AbstractVecOrMat, L::TensorProductOperator, + v::AbstractVecOrMat, α, β) outer, inner = L.ops m, _ = size(L) - k = size(c, 2) + k = size(v, 2) if outer isa IdentityOperator - c = reshape(c, (m, k)) - axpby!(α, c, β, v) - return v + v = reshape(v, (m, k)) + axpby!(α, v, β, w) + return w elseif outer isa ScaledOperator a = convert(Number, α * outer.λ) - outer_mul!(v, outer.L, c, a, β) - return v + outer_mul!(w, outer.L, v, a, β) + return w end mi, _ = size(inner) mo, no = size(outer) if k == 1 - V = reshape(v, (mi, mo)) - C = reshape(c, (mi, no)) - mul!(transpose(V), outer, transpose(C), α, β) - return v + W = reshape(w, (mi, mo)) + C = reshape(v, (mi, no)) + mul!(transpose(W), outer, transpose(C), α, β) + return w end C2, C3, c4 = L.cache[2:4] - C = reshape(c, (mi, no, k)) + C = reshape(v, (mi, no, k)) permutedims!(C2, C, PERM) C2 = reshape(C2, (no, mi * k)) mul!(C3, outer, C2) C3 = reshape(C3, (mo, mi, k)) - V = reshape(v, (mi, mo, k)) - copy!(c4, v) - permutedims!(V, C3, PERM) - axpby!(β, c4, α, v) + W = reshape(w, (mi, mo, k)) + copy!(c4, w) + permutedims!(W, C3, PERM) + axpby!(β, c4, α, w) - v + w end -function outer_div(L::TensorProductOperator, u::AbstractVecOrMat, C::AbstractVecOrMat) +function outer_div(L::TensorProductOperator, v::AbstractVecOrMat, C::AbstractVecOrMat) outer, inner = L.ops if outer isa IdentityOperator return C elseif outer isa ScaledOperator - return outer.λ \ outer_div(outer.L, u, C) + return outer.λ \ outer_div(outer.L, v, C) end - k = size(u, 2) + k = size(v, 2) if k == 1 return transpose(outer \ transpose(C)) end @@ -504,26 +541,26 @@ function outer_div!(v::AbstractVecOrMat, L::TensorProductOperator, c::AbstractVe v end -function outer_div!(L::TensorProductOperator, u::AbstractVecOrMat) +function 
outer_div!(L::TensorProductOperator, v::AbstractVecOrMat) outer, inner = L.ops if outer isa IdentityOperator - return u + return v elseif outer isa ScaledOperator - outer_div!(outer.L, u) - ldiv!(outer.λ, u) + outer_div!(outer.L, v) + ldiv!(outer.λ, v) return u end _, ni = size(inner) _, no = size(outer) - k = size(u, 2) + k = size(v, 2) - U = reshape(u, (ni, no * k)) + U = reshape(v, (ni, no * k)) if k == 1 ldiv!(outer, transpose(U)) - return u + return v end C = first(L.cache) @@ -536,6 +573,26 @@ function outer_div!(L::TensorProductOperator, u::AbstractVecOrMat) C = reshape(C, (no, ni, k)) permutedims!(U, C, PERM) - u + v +end + +# Out-of-place: v is action vector, u is update vector +function (L::TensorProductOperator)(v::AbstractVecOrMat, u, p, t; kwargs...) + L = update_coefficients(L, u, p, t; kwargs...) + L * v +end + +# In-place: w is destination, v is action vector, u is update vector +function (L::TensorProductOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...) + update_coefficients!(L, u, p, t; kwargs...) + mul!(w, L, v) + return w +end + +# In-place with scaling: w = α*(L*v) + β*w +function (L::TensorProductOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...) + update_coefficients!(L, u, p, t; kwargs...) 
+ mul!(w, L, v, α, β) + return w end # diff --git a/test/Project.toml b/test/Project.toml index a84a6e5b..8b415525 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,7 @@ [deps] FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/test/basic.jl b/test/basic.jl index 2f9de506..06dfd7bf 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -19,6 +19,10 @@ K = 12 @testset "IdentityOperator" begin A = rand(N, N) |> MatrixOperator u = rand(N, K) + v = rand(N, K) + w = zeros(N, K) + p = nothing + t = 0 α = rand() β = rand() Id = IdentityOperator(N) @@ -41,11 +45,29 @@ K = 12 @test op(Id, u) ≈ u end + # Test with new interface - same update and action vector + @test Id(u, u, p, t) ≈ u + + # Test with different vectors for update and action + @test Id(v, u, p, t) ≈ v + + # Test in-place operation + copy!(w, zeros(N, K)) + Id(w, v, u, p, t) + @test w ≈ v + + # Test in-place with scaling + copy!(w, rand(N, K)) + orig_w = copy(w) + Id(w, v, u, p, t, α, β) + @test w ≈ α * v + β * orig_w + + # Original tests v = rand(N, K) @test mul!(v, Id, u) ≈ u v = rand(N, K) - w = copy(v) - @test mul!(v, Id, u, α, β) ≈ α * (I * u) + β * w + w_orig = copy(v) + @test mul!(v, Id, u, α, β) ≈ α * (I * u) + β * w_orig v = rand(N, K) @test ldiv!(v, Id, u) ≈ u @@ -60,7 +82,11 @@ end @testset "NullOperator" begin A = rand(N, N) |> MatrixOperator - u = rand(N, K) + u = rand(N, K) + v = rand(N, K) + w = zeros(N, K) + p = nothing + t = 0 α = rand() β = rand() Z = NullOperator(N) @@ -81,11 +107,29 @@ end @test Z * u ≈ zero(u) + # Test with new interface - same update and action vector + @test Z(u, u, p, t) ≈ zero(u) + + # Test with different vectors for update and action + @test Z(v, u, p, t) ≈ zero(v) + + # Test in-place operation + 
copy!(w, ones(N, K)) + Z(w, v, u, p, t) + @test w ≈ zero(v) + + # Test in-place with scaling + copy!(w, rand(N, K)) + orig_w = copy(w) + Z(w, v, u, p, t, α, β) + @test w ≈ β * orig_w + + # Original tests v = rand(N, K) @test mul!(v, Z, u) ≈ zero(u) v = rand(N, K) - w = copy(v) - @test mul!(v, Z, u, α, β) ≈ α * (0 * u) + β * w + w_orig = copy(v) + @test mul!(v, Z, u, α, β) ≈ α * (0 * u) + β * w_orig for op in (*, ∘) @test op(Z, A) isa NullOperator @@ -100,7 +144,11 @@ end @testset "ScaledOperator" begin A = rand(N, N) D = Diagonal(rand(N)) - u = rand(N, K) + u = rand(N, K) # Update vector + v = rand(N, K) # Action vector + w = zeros(N, K) # Output vector + p = nothing + t = 0 α = rand() β = rand() a = rand() @@ -114,8 +162,25 @@ end @test issquare(op) @test islinear(op) - @test α * A * u ≈ op * u - @test (β * op) * u ≈ β * α * A * u + @test α * A * v ≈ op * v + @test (β * op) * v ≈ β * α * A * v + + # Test with new interface - same vector for update and action + @test op(u, u, p, t) ≈ α * A * u + + # Test with different vectors for update and action + @test op(v, u, p, t) ≈ α * A * v + + # Test in-place operation + copy!(w, zeros(N, K)) + op(w, v, u, p, t) + @test w ≈ α * A * v + + # Test in-place with scaling + copy!(w, rand(N, K)) + orig_w = copy(w) + op(w, v, u, p, t, a, b) + @test w ≈ a * (α * A * v) + b * orig_w opF = factorize(op) @@ -125,33 +190,30 @@ end @test α * A ≈ convert(AbstractMatrix, op) ≈ convert(AbstractMatrix, opF) - v = rand(N, K) - @test mul!(v, op, u) ≈ α * A * u - v = rand(N, K) - w = copy(v) - @test mul!(v, op, u, a, b) ≈ a * (α * A * u) + b * w + w = rand(N, K) + @test mul!(w, op, v) ≈ α * A * v + w = rand(N, K) + w_orig = copy(w) + @test mul!(w, op, v, a, b) ≈ a * (α * A * v) + b * w_orig op = ScaledOperator(α, MatrixOperator(D)) - v = rand(N, K) - @test ldiv!(v, op, u) ≈ (α * D) \ u - v = copy(u) - @test ldiv!(op, u) ≈ (α * D) \ v -end - -function apply_op!(H, du, u, p, t) - H(du, u, p, t) - return nothing + w = rand(N, K) + @test ldiv!(v, 
op, w) ≈ (α * D) \ w + w = copy(v) + @test ldiv!(op, w) ≈ (α * D) \ v end -test_apply_noalloc(H, du, u, p, t) = @test (@allocations apply_op!(H, du, u, p, t)) == 0 - @testset "AddedOperator" begin A = rand(N, N) |> MatrixOperator B = rand(N, N) |> MatrixOperator C = rand(N, N) |> MatrixOperator α = rand() β = rand() - u = rand(N, K) + u = rand(N, K) # Update vector + v = rand(N, K) # Action vector + w = zeros(N, K) # Output vector + p = nothing + t = 0 for op in (+, -) op1 = op(A, B) @@ -173,6 +235,29 @@ test_apply_noalloc(H, du, u, p, t) = @test (@allocations apply_op!(H, du, u, p, @test op2 * u ≈ op(α * A * u, B * u) @test op3 * u ≈ op(A * u, β * B * u) @test op4 * u ≈ op(α * A * u, β * B * u) + + # Test new interface - combined case + @test op1(u, u, p, t) ≈ op(A * u, B * u) + @test op2(u, u, p, t) ≈ op(α * A * u, B * u) + @test op3(u, u, p, t) ≈ op(A * u, β * B * u) + @test op4(u, u, p, t) ≈ op(α * A * u, β * B * u) + + # Test new interface - separate vectors + @test op1(v, u, p, t) ≈ op(A * v, B * v) + @test op2(v, u, p, t) ≈ op(α * A * v, B * v) + @test op3(v, u, p, t) ≈ op(A * v, β * B * v) + @test op4(v, u, p, t) ≈ op(α * A * v, β * B * v) + + # Test in-place operation + copy!(w, zeros(N, K)) + op1(w, v, u, p, t) + @test w ≈ op(A * v, B * v) + + # Test in-place with scaling + copy!(w, rand(N, K)) + orig_w = copy(w) + op1(w, v, u, p, t, α, β) + @test w ≈ α * op(A * v, B * v) + β * orig_w end op = AddedOperator(A, B) @@ -181,8 +266,8 @@ test_apply_noalloc(H, du, u, p, t) = @test (@allocations apply_op!(H, du, u, p, v = rand(N, K) @test mul!(v, op, u) ≈ (A + B) * u v = rand(N, K) - w = copy(v) - @test mul!(v, op, u, α, β) ≈ α * (A + B) * u + β * w + w_orig = copy(v) + @test mul!(v, op, u, α, β) ≈ α * (A + B) * u + β * w_orig # ensure AddedOperator doesn't nest A = MatrixOperator(rand(N, N)) @@ -192,11 +277,6 @@ test_apply_noalloc(H, du, u, p, t) = @test (@allocations apply_op!(H, du, u, p, @test !isa(op, AddedOperator) end - # Allocations Tests - - 
@allocations apply_op!(op, v, u, (), 1.0) # warmup - test_apply_noalloc(op, v, u, (), 1.0) - ## Time-Dependent Coefficients for T in (Float32, Float64, ComplexF32, ComplexF64) @@ -221,14 +301,10 @@ test_apply_noalloc(H, du, u, p, t) = @test (@allocations apply_op!(H, du, u, p, H_dense = c1 * A1_dense + c2 * A2_dense + c3 * A3_dense u = rand(T, N) + v = rand(T, N) du = similar(u) p = (ω = 0.1,) t = 0.1 - - @allocations apply_op!(H_sparse, du, u, p, t) # warmup - @allocations apply_op!(H_dense, du, u, p, t) # warmup - test_apply_noalloc(H_sparse, du, u, p, t) - test_apply_noalloc(H_dense, du, u, p, t) end end @@ -236,7 +312,11 @@ end A = rand(N, N) B = rand(N, N) C = rand(N, N) - u = rand(N, K) + u = rand(N, K) # Update vector + v = rand(N, K) # Action vector + w = zeros(N, K) # Output vector + p = nothing + t = 0 α = rand() β = rand() @@ -261,16 +341,34 @@ end @test ABCmulu ≈ op * u @test ABCdivu ≈ op \ u ≈ opF \ u + + # Test new interface - combined case + @test op(u, u, p, t) ≈ ABCmulu + + # Test new interface - separate vectors + @test op(v, u, p, t) ≈ (A * B * C) * v @test !iscached(op) op = cache_operator(op, u) @test iscached(op) - + + # Test in-place operation with new interface + copy!(w, zeros(N, K)) + op(w, v, u, p, t) + @test w ≈ (A * B * C) * v + + # Test in-place with scaling with new interface + copy!(w, rand(N, K)) + orig_w = copy(w) + op(w, v, u, p, t, α, β) + @test w ≈ α * ((A * B * C) * v) + β * orig_w + + # Original tests v = rand(N, K) @test mul!(v, op, u) ≈ ABCmulu v = rand(N, K) - w = copy(v) - @test mul!(v, op, u, α, β) ≈ α * ABCmulu + β * w + w_orig = copy(v) + @test mul!(v, op, u, α, β) ≈ α * ABCmulu + β * w_orig A = rand(N) |> Diagonal B = rand(N) |> Diagonal @@ -311,7 +409,11 @@ end (transpose, TransposedOperator, AbstractTransposedVecOrMat)) A = rand(N, N) D = Bidiagonal(rand(N, N), :L) - u = rand(N, K) + u = rand(N, K) # Update vector + v = rand(N, K) # Action vector + w = zeros(N, K) # Output vector + p = nothing + t = 0 α = rand() β = 
rand() a = rand() @@ -341,11 +443,14 @@ end @test op(u) * AAt ≈ op(A * u) @test op(u) / AAt ≈ op(A \ u) + # Not implementing separate test for adjoint/transpose operators + # since they typically rely on the base operator implementations + v = rand(N, K) @test mul!(op(v), op(u), AAt) ≈ op(A * u) v = rand(N, K) - w = copy(v) - @test mul!(op(v), op(u), AAt, α, β) ≈ α * op(A * u) + β * op(w) + w_orig = copy(v) + @test mul!(op(v), op(u), AAt, α, β) ≈ α * op(A * u) + β * op(w_orig) v = rand(N, K) @test ldiv!(op(v), op(u), DDt) ≈ op(D \ u) @@ -358,7 +463,11 @@ end s = rand(N) D = Diagonal(s) |> MatrixOperator Di = InvertedOperator(D) - u = rand(N) + u = rand(N) # Update vector + v = rand(N) # Action vector + w = zeros(N) # Output vector + p = nothing + t = 0 α = rand() β = rand() @@ -371,16 +480,34 @@ end @test islinear(Di) @test Di * u ≈ u ./ s + + # Test new interface - same vectors + @test Di(u, u, p, t) ≈ u ./ s + + # Test new interface - separate vectors + @test Di(v, u, p, t) ≈ v ./ s + + # Test in-place operation + copy!(w, zeros(N)) + Di(w, v, u, p, t) + @test w ≈ v ./ s + + # Test in-place with scaling + copy!(w, rand(N)) + orig_w = copy(w) + Di(w, v, u, p, t, α, β) + @test w ≈ α * (v ./ s) + β * orig_w + + # Original tests v = rand(N) @test mul!(v, Di, u) ≈ u ./ s v = rand(N) - w = copy(v) - @test mul!(v, Di, u, α, β) ≈ α * (u ./ s) + β * w + w_orig = copy(v) + @test mul!(v, Di, u, α, β) ≈ α * (u ./ s) + β * w_orig @test Di \ u ≈ u .* s v = rand(N) @test ldiv!(v, Di, u) ≈ u .* s v = copy(u) @test ldiv!(Di, u) ≈ v .* s -end -# +end \ No newline at end of file diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml new file mode 100644 index 00000000..99beb132 --- /dev/null +++ b/test/downstream/Project.toml @@ -0,0 +1,5 @@ +[deps] +AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a" + +[compat] +AllocCheck = "0.2.2" \ No newline at end of file diff --git a/test/downstream/alloccheck.jl b/test/downstream/alloccheck.jl new file mode 100644 index 
00000000..5a6b125c --- /dev/null +++ b/test/downstream/alloccheck.jl @@ -0,0 +1,62 @@ +using SciMLOperators, AllocCheck, Random, SparseArrays, Test +using SciMLOperators: IdentityOperator, + NullOperator, + ScaledOperator, + AddedOperator +Random.seed!(0) +N = 8 +K = 12 +A = rand(N, N) |> MatrixOperator +B = rand(N, N) |> MatrixOperator +C = rand(N, N) |> MatrixOperator +α = rand() +β = rand() +u = rand(N, K) # Update vector +v = rand(N, K) # Action vector +w = zeros(N, K) # Output vector +p = () +t = 0 +op = AddedOperator(A, B) + +# Define a function to test allocations with the new interface +@check_allocs ignore_throw = true function apply_op!(H, w, v, u, p, t) + H(w, v, u, p, t) + return nothing +end + +if VERSION >= v"1.12-beta" + apply_op!(op, w, v, u, p, t) +else + @test_throws AllocCheckFailure apply_op!(op, w, v, u, p, t) +end + +for T in (Float32, Float64, ComplexF32, ComplexF64) + N = 100 + A1_sparse = MatrixOperator(sprand(T, N, N, 5 / N)) + A2_sparse = MatrixOperator(sprand(T, N, N, 5 / N)) + A3_sparse = MatrixOperator(sprand(T, N, N, 5 / N)) + + A1_dense = MatrixOperator(rand(T, N, N)) + A2_dense = MatrixOperator(rand(T, N, N)) + A3_dense = MatrixOperator(rand(T, N, N)) + + coeff1(a, u, p, t) = sin(p.ω * t) + coeff2(a, u, p, t) = cos(p.ω * t) + coeff3(a, u, p, t) = sin(p.ω * t) * cos(p.ω * t) + + c1 = ScalarOperator(rand(T), coeff1) + c2 = ScalarOperator(rand(T), coeff2) + c3 = ScalarOperator(rand(T), coeff3) + + H_sparse = c1 * A1_sparse + c2 * A2_sparse + c3 * A3_sparse + H_dense = c1 * A1_dense + c2 * A2_dense + c3 * A3_dense + + u = rand(T, N) + v = rand(T, N) + w = similar(u) + p = (ω = 0.1,) + t = 0.1 + + @test_throws AllocCheckFailure apply_op!(H_sparse, w, v, u, p, t) + @test_throws AllocCheckFailure apply_op!(H_dense, w, v, u, p, t) +end diff --git a/test/func.jl b/test/func.jl index cdc1a52c..807e078b 100644 --- a/test/func.jl +++ b/test/func.jl @@ -1,6 +1,6 @@ -# using SciMLOperators, LinearAlgebra using Random +using Test using 
SciMLOperators: ⊗ @@ -13,6 +13,7 @@ NK = N * K N1, N2, N3 = 3, 4, 5 M1, M2, M3 = 4, 5, 6 + u = nothing p = nothing t = 0.0 α = rand() @@ -24,93 +25,111 @@ NK = N * K M = prod(sz_out) A = rand(M, N) - u = rand(sz_in...) - v = rand(sz_out...) + u = nothing + v = rand(sz_in...) # action vector + w = rand(sz_out...) # output vector for in-place tests - _mul(A, u) = reshape(A * vec(u), sz_out) - f(u, p, t) = _mul(A, u) - f(du, u, p, t) = (mul!(vec(du), A, vec(u)); du) + _mul(A, v) = reshape(A * vec(v), sz_out) + f(v, u, p, t) = _mul(A, v) + f(w, v, u, p, t) = (mul!(vec(w), A, vec(v)); w) kw = (;) # FunctionOp kwargs if sz_in == sz_out F = lu(A) _div(A, v) = reshape(A \ vec(v), sz_in) - fi(u, p, t) = _div(A, u) - fi(du, u, p, t) = (ldiv!(vec(du), F, vec(u)); du) + fi(v, u, p, t) = _div(A, v) + fi(w, v, u, p, t) = (ldiv!(vec(w), F, vec(v)); w) kw = (; op_inverse = fi) end - L = FunctionOperator(f, u, v; kw...) - L = cache_operator(L, u) - - # test with ND-arrays - @test _mul(A, u) ≈ L(u, p, t) ≈ L * u ≈ mul!(zero(v), L, u) - @test α * _mul(A, u) + β * v ≈ mul!(copy(v), L, u, α, β) - + L = FunctionOperator(f, v, w; kw...) + L = cache_operator(L, v) + + # test with ND-arrays and new interface + @test _mul(A, v) ≈ L(v, u, p, t) ≈ L * v ≈ mul!(zero(w), L, v) + @test α * _mul(A, v) + β * w ≈ mul!(copy(w), L, v, α, β) + + # Test with different update and action vectors + action_vec = rand(sz_in...) 
+ @test _mul(A, action_vec) ≈ L(action_vec, u, p, t) + if sz_in == sz_out - @test _div(A, v) ≈ L \ v ≈ ldiv!(zero(u), L, v) ≈ ldiv!(L, copy(v)) + @test _div(A, w) ≈ L \ w ≈ ldiv!(zero(v), L, w) ≈ ldiv!(L, copy(w)) end - + # test with vec(Array) - @test vec(_mul(A, u)) ≈ L(vec(u), p, t) ≈ L * vec(u) ≈ mul!(vec(zero(v)), L, vec(u)) - @test vec(α * _mul(A, u) + β * v) ≈ mul!(vec(copy(v)), L, vec(u), α, β) + @test vec(_mul(A, v)) ≈ L(vec(v), u, p, t) ≈ L * vec(v) ≈ mul!(vec(zero(w)), L, vec(v)) + @test vec(α * _mul(A, v) + β * w) ≈ mul!(vec(copy(w)), L, vec(v), α, β) if sz_in == sz_out - @test vec(_div(A, v)) ≈ L \ vec(v) ≈ ldiv!(vec(zero(u)), L, vec(v)) ≈ - ldiv!(L, vec(copy(v))) + @test vec(_div(A, w)) ≈ L \ vec(w) ≈ ldiv!(vec(zero(v)), L, vec(w)) ≈ + ldiv!(L, vec(copy(w))) end - @test_throws DimensionMismatch mul!(vec(v), L, u) - @test_throws DimensionMismatch mul!(v, L, vec(u)) - end # for + # Test in-place with different update and action vectors + output_vec = zeros(sz_out...) + L(output_vec, action_vec, u, p, t) + @test output_vec ≈ _mul(A, action_vec) + + # Test in-place with scaling + output_vec = rand(sz_out...) 
+ orig_output = copy(output_vec) + L(output_vec, action_vec, u, p, t, α, β) + @test output_vec ≈ α * _mul(A, action_vec) + β * orig_output + + @test_throws DimensionMismatch mul!(vec(w), L, v) + @test_throws DimensionMismatch mul!(w, L, vec(v)) + end end @testset "(Unbatched) FunctionOperator" begin - u = rand(N, K) + v = rand(N, K) # action vector + w = zeros(N, K) # Output vector + u = nothing p = nothing t = 0.0 α = rand() β = rand() - _mul(A, u) = reshape(A * vec(u), N, K) - _div(A, u) = reshape(A \ vec(u), N, K) + _mul(A, v) = reshape(A * vec(v), N, K) + _div(A, v) = reshape(A \ vec(v), N, K) A = rand(NK, NK) |> Symmetric F = lu(A) Ai = inv(A) - f1(u, p, t) = _mul(A, u) - f1i(u, p, t) = _div(A, u) + f1(v, u, p, t) = _mul(A, v) + f1i(v, u, p, t) = _div(A, v) - f2(du, u, p, t) = (mul!(vec(du), A, vec(u)); du) - f2(du, u, p, t, α, β) = (mul!(vec(du), A, vec(u), α, β); du) - f2i(du, u, p, t) = (ldiv!(vec(du), F, vec(u)); du) - f2i(du, u, p, t, α, β) = (mul!(vec(du), Ai, vec(u), α, β); du) + f2(w, v, u, p, t) = (mul!(vec(w), A, vec(v)); w) + f2(w, v, u, p, t, α, β) = (mul!(vec(w), A, vec(v), α, β); w) + f2i(w, v, u, p, t) = (ldiv!(vec(w), F, vec(v)); w) + f2i(w, v, u, p, t, α, β) = (mul!(vec(w), Ai, vec(v), α, β); w) # out of place - op1 = FunctionOperator(f1, u; op_inverse = f1i, ifcache = false, islinear = true, + op1 = FunctionOperator(f1, v; op_inverse = f1i, ifcache = false, islinear = true, opnorm = true, issymmetric = true, ishermitian = true, isposdef = true) # in place - op2 = FunctionOperator(f2, u; op_inverse = f2i, ifcache = false, islinear = true, + op2 = FunctionOperator(f2, v, w; op_inverse = f2i, ifcache = false, islinear = true, opnorm = true, issymmetric = true, ishermitian = true, isposdef = true) + # Test traits @test issquare(op1) @test issquare(op2) - @test islinear(op1) @test islinear(op2) - @test op1' === op1 - + + # Test operator properties @test size(op1) == (NK, NK) @test has_adjoint(op1) @test has_mul(op1) @@ -127,37 +146,62 @@ end @test 
!iscached(op1) @test !iscached(op2) - @test !op1.traits.has_mul5 @test op2.traits.has_mul5 # 5-arg mul! (w/o cache) v = rand(N, K) w = copy(v) - @test α * _mul(A, u) + β * w ≈ mul!(v, op2, u, α, β) + @test α * _mul(A, v) + β * w ≈ mul!(w, op2, v, α, β) + + # Create test vectors for new interface + action_vec = rand(N, K) # Action vector + result_vec = zeros(N, K) # Result vector - op1 = cache_operator(op1, u) - op2 = cache_operator(op2, u) + # Cache operators + op1 = cache_operator(op1, v) + op2 = cache_operator(op2, v) @test iscached(op1) @test iscached(op2) - v = rand(N, K) - @test _mul(A, u) ≈ op1 * u ≈ mul!(v, op2, u) - v = rand(N, K) - @test _mul(A, u) ≈ op1(u, p, t) ≈ op2(v, u, p, t) + # Test standard operator operations (from original test) + w = rand(N, K) + @test _mul(A, v) ≈ op1 * v ≈ mul!(w, op2, v) + w = rand(N, K) + @test _mul(A, v) ≈ op1(v, u, p, t) ≈ op2(v, u, p, t) v = rand(N, K) w = copy(v) - @test α * _mul(A, u) + β * w ≈ mul!(v, op2, u, α, β) + @test α * _mul(A, v) + β * w ≈ mul!(w, op2, v, α, β) - v = rand(N, K) - @test _div(A, u) ≈ op1 \ u ≈ ldiv!(v, op2, u) - v = copy(u) - @test _div(A, v) ≈ ldiv!(op2, u) + w = rand(N, K) + @test _div(A, w) ≈ op1 \ w ≈ ldiv!(v, op2, w) + w = copy(v) + @test _div(A, w) ≈ ldiv!(op2, w) + + # Test with new interface - out of place + @test _mul(A, action_vec) ≈ op1(action_vec, u, p, t) + + # Test with new interface - in place + op2(result_vec, action_vec, u, p, t) + @test result_vec ≈ _mul(A, action_vec) + + # Test in-place with scaling + result_vec = rand(N, K) + orig_result = copy(result_vec) + op2(result_vec, action_vec, u, p, t, α, β) + @test result_vec ≈ α * _mul(A, action_vec) + β * orig_result + + # Test inverse operations with new interface + inv_result = zeros(N, K) + @test _div(A, action_vec) ≈ op1 \ action_vec + ldiv!(inv_result, op2, action_vec) + @test inv_result ≈ _div(A, action_vec) end @testset "Batched FunctionOperator" begin - u = rand(N, K) + v = rand(N, K) + u = nothing p = nothing t = 0.0 α = 
rand() @@ -167,16 +211,16 @@ end F = lu(A) Ai = inv(A) - f1(u, p, t) = A * u - f1i(u, p, t) = A \ u + f1(v, u, p, t) = A * v + f1i(v, u, p, t) = A \ v - f2(du, u, p, t) = mul!(du, A, u) - f2(du, u, p, t, α, β) = mul!(du, A, u, α, β) - f2i(du, u, p, t) = ldiv!(du, F, u) - f2i(du, u, p, t, α, β) = mul!(du, Ai, u, α, β) + f2(w, v, u, p, t) = mul!(w, A, v) + f2(w, v, u, p, t, α, β) = mul!(w, A, v, α, β) + f2i(w, v, u, p, t) = ldiv!(w, F, v) + f2i(w, v, u, p, t, α, β) = mul!(w, Ai, v, α, β) # out of place - op1 = FunctionOperator(f1, u, A * u; op_inverse = f1i, ifcache = false, + op1 = FunctionOperator(f1, v, A * v; op_inverse = f1i, ifcache = false, batch = true, islinear = true, opnorm = true, @@ -185,7 +229,7 @@ end isposdef = true) # in place - op2 = FunctionOperator(f2, u, A * u; op_inverse = f2i, ifcache = false, + op2 = FunctionOperator(f2, v, A * v; op_inverse = f2i, ifcache = false, batch = true, islinear = true, opnorm = true, @@ -224,140 +268,141 @@ end # 5-arg mul! (w/o cache) v = rand(N, K) w = copy(v) - @test α * *(A, u) + β * w ≈ mul!(v, op2, u, α, β) + @test α * *(A, v) + β * w ≈ mul!(w, op2, v, α, β) - op1 = cache_operator(op1, u) - op2 = cache_operator(op2, u) + op1 = cache_operator(op1, v) + op2 = cache_operator(op2, v) @test iscached(op1) @test iscached(op2) v = rand(N, K) - @test *(A, u) ≈ op1 * u ≈ mul!(v, op2, u) + @test *(A, v) ≈ op1 * v ≈ mul!(w, op2, v) + + # Test with new interface v = rand(N, K) - @test *(A, u) ≈ op1(u, p, t) ≈ op2(v, u, p, t) + @test *(A, v) ≈ op1(w, v, u, p, t) ≈ op2(w, v, u, p, t) + v = rand(N, K) w = copy(v) - @test α * *(A, u) + β * w ≈ mul!(v, op2, u, α, β) + @test α * *(A, v) + β * w ≈ mul!(w, op2, v, α, β) - v = rand(N, K) - @test \(A, u) ≈ op1 \ u ≈ ldiv!(v, op2, u) - v = copy(u) - @test \(A, v) ≈ ldiv!(op2, u) + # Test old style calls + w = rand(N, K) + @test \(A, w) ≈ op1 \ w ≈ ldiv!(v, op2, w) + w = copy(v) + @test \(A, w) ≈ ldiv!(op2, w) + + # Test new interface ldiv + w = rand(N, K) + ldiv_result = zeros(N, K) + 
ldiv!(ldiv_result, op2, w) + @test ldiv_result ≈ A \ w end @testset "FunctionOperator update test" begin - u = rand(N, K) + u = rand(N, N) # Update vector + v = rand(N, K) # Action vector + w = zeros(N, K) # Result vector p = rand(N) t = rand() scale = rand() # Accept a kwarg "scale" in operator action - f(du, u, p, t; scale = 1.0) = mul!(du, Diagonal(p * t * scale), u) - f(u, p, t; scale = 1.0) = Diagonal(p * t * scale) * u + f(w, v, u, p, t; scale = 1.0) = mul!(w, Diagonal(u * p * t * scale), v) + f(v, u, p, t; scale = 1.0) = Diagonal(u * p * t * scale) * v + # Test with both tuple and Val forms of accepted_kwargs for acc_kw in ((:scale,), Val((:scale,))) - L = FunctionOperator(f, u, u; p = zero(p), t = zero(t), batch = true, - accepted_kwargs = acc_kw, scale = 1.0) + # Function operator with keyword arguments + L = FunctionOperator(f, v, w; + u = u, + p = zero(p), + t = zero(t), + batch = true, + accepted_kwargs = acc_kw, + scale = 1.0) @test_throws ArgumentError FunctionOperator( - f, u, u; p = zero(p), t = zero(t), batch = true, + f, v, w; u = u, p = zero(p), t = zero(t), batch = true, accepted_kwargs = acc_kw) @test size(L) == (N, N) - ans = @. u * p * t * scale - @test L(u, p, t; scale) ≈ ans - v = copy(u) - @test L(v, u, p, t; scale) ≈ ans - - # test that output isn't accidentally mutated by passing an internal cache. 
- - A = Diagonal(p * t * scale) - u1 = rand(N, K) - u2 = rand(N, K) - - v1 = L * u1 - @test v1 ≈ A * u1 - v2 = L * u2 - @test v2 ≈ A * u2 - @test v1 ≈ A * u1 - @test v1 + v2 ≈ A * (u1 + u2) - + # Expected result with scaling + A = Diagonal(u * p * t * scale) + expected = A * v + ans = u * p .* t .* scale + + # Test with new interface + @test L(v, u, p, t; scale) ≈ expected + + # Test in-place with new interface + copy!(w, zeros(N, K)) + L(w, v, u, p, t; scale) + @test w ≈ expected + + # Test in-place with scaling + copy!(w, rand(N, K)) + orig_w = copy(w) + α_val = rand() + β_val = rand() + L(w, v, u, p, t, α_val, β_val; scale) + @test w ≈ α_val * expected + β_val * orig_w + + # Test that outputs aren't accidentally mutated + v1 = rand(N, K) + v2 = rand(N, K) + w1 = rand(N, K) + w2 = rand(N, K) + + # Expected results with different vectors + result1 = A * v1 + result2 = A * v2 + + # Test output consistency + w1 = zeros(N, K) + w2 = zeros(N, K) + + L(w1, v1, u, p, t; scale) + L(w2, v2, u, p, t; scale) + + @test w1 ≈ result1 + @test w2 ≈ result2 + + # Test matrix-vector multiplication + w1 = L * v1 + @test w1 ≈ A * v1 + w2 = L * v2 + @test w2 ≈ A * v2 + @test w1 ≈ A * v1 # Check v1 hasn't changed + @test w1 + w2 ≈ A * (v1 + v2) + + # Test in-place matrix-vector multiplication v1 .= 0.0 v2 .= 0.0 - - mul!(v1, L, u1) - @test v1 ≈ A * u1 - mul!(v2, L, u2) - @test v2 ≈ A * u2 - @test v1 ≈ A * u1 - @test v1 + v2 ≈ A * (u1 + u2) - + + mul!(w1, L, v1) + @test w1 ≈ A * v1 + mul!(w2, L, v2) + @test w2 ≈ A * v2 + @test w1 ≈ A * v1 + @test w1 + w2 ≈ A * (v1 + v2) + + # Test scaling v1 = rand(N, K) w1 = copy(v1) v2 = rand(N, K) w2 = copy(v2) a1, a2, b1, b2 = rand(4) - - mul!(v1, L, u1, a1, b1) - @test v1 ≈ a1 * A * u1 + b1 * w1 - mul!(v2, L, u2, a2, b2) - @test v2 ≈ a2 * A * u2 + b2 * w2 - @test v1 ≈ a1 * A * u1 + b1 * w1 - @test v1 + v2 ≈ (a1 * A * u1 + b1 * w1) + (a2 * A * u2 + b2 * w2) + + res = copy(w1) + mul!(res, L, v1, a1, b1) + @test res ≈ a1 * A * v1 + b1 * w1 + res2 = 
copy(w2) + mul!(res2, L, v2, a2, b2) + @test res2 ≈ a2 * A * v2 + b2 * w2 + @test res ≈ a1 * A * v1 + b1 * w1 + @test res + res2 ≈ (a1 * A * v1 + b1 * w1) + (a2 * A * v2 + b2 * w2) end - - ## Do the same with Val((:scale,)) - - L = FunctionOperator(f, u, u; p = zero(p), t = zero(t), batch = true, - accepted_kwargs = Val((:scale,)), scale = 1.0) - - @test_throws ArgumentError FunctionOperator( - f, u, u; p = zero(p), t = zero(t), batch = true, - accepted_kwargs = Val((:scale,))) - - @test size(L) == (N, N) - - ans = @. u * p * t * scale - @test L(u, p, t; scale) ≈ ans - v = copy(u) - @test L(v, u, p, t; scale) ≈ ans - - # test that output isn't accidentally mutated by passing an internal cache. - - A = Diagonal(p * t * scale) - u1 = rand(N, K) - u2 = rand(N, K) - - v1 = L * u1 - @test v1 ≈ A * u1 - v2 = L * u2 - @test v2 ≈ A * u2 - @test v1 ≈ A * u1 - @test v1 + v2 ≈ A * (u1 + u2) - - v1 .= 0.0 - v2 .= 0.0 - - mul!(v1, L, u1) - @test v1 ≈ A * u1 - mul!(v2, L, u2) - @test v2 ≈ A * u2 - @test v1 ≈ A * u1 - @test v1 + v2 ≈ A * (u1 + u2) - - v1 = rand(N, K) - w1 = copy(v1) - v2 = rand(N, K) - w2 = copy(v2) - a1, a2, b1, b2 = rand(4) - - mul!(v1, L, u1, a1, b1) - @test v1 ≈ a1 * A * u1 + b1 * w1 - mul!(v2, L, u2, a2, b2) - @test v2 ≈ a2 * A * u2 + b2 * w2 - @test v1 ≈ a1 * A * u1 + b1 * w1 - @test v1 + v2 ≈ (a1 * A * u1 + b1 * w1) + (a2 * A * u2 + b2 * w2) -end -# +end \ No newline at end of file diff --git a/test/matrix.jl b/test/matrix.jl index b971edb1..78f058c4 100644 --- a/test/matrix.jl +++ b/test/matrix.jl @@ -1,5 +1,6 @@ using SciMLOperators, LinearAlgebra using Random +using Test using SciMLOperators: InvertibleOperator, InvertedOperator, ⊗, AbstractSciMLOperator using FFTW @@ -9,7 +10,11 @@ N = 8 K = 19 @testset "MatrixOperator, InvertibleOperator" begin - u = rand(N, K) + # Vectors for testing + u = rand(N, K) # Both update and action vector + v = rand(N, K) # Output/action vector + w = zeros(N, K) # Output vector + p = nothing t = 0 α = rand() @@ -50,21 
+55,35 @@ K = 19 @test A ≈ Matrix(AA) ≈ Matrix(FF) @test At ≈ Matrix(AAt) ≈ Matrix(FFt) - @test A * u ≈ AA(u, p, t) - @test At * u ≈ AAt(u, p, t) + # Test with new interface - same vector for update and action + @test A * u ≈ AA(u, u, p, t) + @test At * u ≈ AAt(u, u, p, t) + + # Test with different vectors for update and action + @test A * v ≈ AA(v, u, p, t) + @test At * v ≈ AAt(v, u, p, t) @test A \ u ≈ AA \ u ≈ FF \ u @test At \ u ≈ AAt \ u ≈ FFt \ u - v = rand(N, K) - @test mul!(v, AA, u) ≈ A * u - v = rand(N, K) - w = copy(v) - @test mul!(v, AA, u, α, β) ≈ α * A * u + β * w + # Test in-place operations + copy!(w, zeros(N, K)) + AA(w, v, u, p, t) + @test w ≈ A * v + + # Test in-place with scaling + copy!(w, rand(N, K)) + orig_w = copy(w) + AA(w, v, u, p, t, α, β) + @test w ≈ α * (A * v) + β * orig_w end @testset "InvertibleOperator test" begin - u = rand(N, K) + # Vectors for testing + u = rand(N, K) # Update vector + v = rand(N, K) # Action vector + w = zeros(N, K) # Output vector + p = nothing t = 0 α = rand() @@ -79,26 +98,40 @@ end @test iscached(L) - @test L * u ≈ d .* u + # Test with new interface + @test L(v, u, p, t) ≈ d .* v @test L \ u ≈ d .\ u - v = rand(N, K) - @test mul!(v, L, u) ≈ d .* u - v = rand(N, K) - w = copy(v) - @test mul!(v, L, u, α, β) ≈ α * (d .* u) + β * w - - v = rand(N, K) - @test ldiv!(v, L, u) ≈ d .\ u + # Test in-place operations + copy!(w, zeros(N, K)) + L(w, v, u, p, t) + @test w ≈ d .* v + + # Test in-place with scaling + copy!(w, rand(N, K)) + orig_w = copy(w) + L(w, v, u, p, t, α, β) + @test w ≈ α * (d .* v) + β * orig_w + + # Test division operations + copy!(w, zeros(N, K)) + ldiv!(w, L, u) + @test w ≈ d .\ u + + # Existing test for in-place ldiv! 
v = copy(u) - @test ldiv!(L, v) ≈ d .\ u + ldiv!(L, v) + @test v ≈ d .\ u end @testset "MatrixOperator update test" begin - u = rand(N, K) + # Vectors for testing + u = rand(N, K) # Update vector + v = rand(N, K) # Action vector + w = zeros(N, K) # Output vector + p = rand(N) t = rand() - α = rand() β = rand() @@ -108,17 +141,33 @@ end @test !isconstant(L) + # Expected matrix after update A = p * p' - @test L(u, p, t) ≈ A * u - v = copy(u) - @test L(v, u, p, t) ≈ A * u - v = rand(N, K) - w = copy(v) - @test L(v, u, p, t, α, β) ≈ α * A * u + β * w + + # Test with new interface - same vector for update and action + @test L(u, u, p, t) ≈ A * u + + # Test with different vectors for update and action + @test L(v, u, p, t) ≈ A * v + + # Test in-place operation + copy!(w, zeros(N, K)) + L(w, v, u, p, t) + @test w ≈ A * v + + # Test in-place with scaling + copy!(w, rand(N, K)) + orig_w = copy(w) + L(w, v, u, p, t, α, β) + @test w ≈ α * (A * v) + β * orig_w end @testset "DiagonalOperator update test" begin - u = rand(N, K) + # Vectors for testing + u = rand(N, K) # Update vector + v = rand(N, K) # Action vector + w = zeros(N, K) # Output vector + p = rand(N) t = rand() α = rand() @@ -132,20 +181,35 @@ end @test issquare(D) @test islinear(D) - ans = Diagonal(p * t) * u - @test D(u, p, t) ≈ ans - v = copy(u) - @test D(v, u, p, t) ≈ ans - v = rand(N, K) - w = copy(v) - @test D(v, u, p, t, α, β) ≈ α * ans + β * w + # Expected result after update + expected = (p * t) .* v + + # Test with new interface - different vectors for update and action + @test D(v, u, p, t) ≈ expected + + # Test in-place operation + copy!(w, zeros(N, K)) + D(w, v, u, p, t) + @test w ≈ expected + + # Test in-place with scaling + copy!(w, rand(N, K)) + orig_w = copy(w) + D(w, v, u, p, t, α, β) + @test w ≈ α * expected + β * orig_w end @testset "Batched Diagonal Operator" begin - u = rand(N, K) + # Vectors for testing + u = rand(N, K) # Update vector + v = rand(N, K) # Action vector + w = zeros(N, K) # 
Output vector + d = rand(N, K) α = rand() β = rand() + p = nothing + t = 0.0 L = DiagonalOperator(d) @test isconstant(L) @@ -154,22 +218,39 @@ end @test issquare(L) @test islinear(L) - @test L * u ≈ d .* u - v = rand(N, K) - @test mul!(v, L, u) ≈ d .* u - v = rand(N, K) - w = copy(v) - @test mul!(v, L, u, α, β) ≈ α * (d .* u) + β * w - + # Test with new interface + @test L(v, u, p, t) ≈ d .* v + + # Test in-place operation + copy!(w, zeros(N, K)) + L(w, v, u, p, t) + @test w ≈ d .* v + + # Test in-place with scaling + copy!(w, rand(N, K)) + orig_w = copy(w) + L(w, v, u, p, t, α, β) + @test w ≈ α * (d .* v) + β * orig_w + + # Test division operations @test L \ u ≈ d .\ u - v = rand(N, K) - @test ldiv!(v, L, u) ≈ d .\ u + + copy!(w, zeros(N, K)) + ldiv!(w, L, u) + @test w ≈ d .\ u + + # Existing test for in-place ldiv! v = copy(u) - @test ldiv!(L, u) ≈ d .\ v + ldiv!(L, u) + @test u ≈ d .\ v end @testset "Batched DiagonalOperator update test" begin - u = rand(N, K) + # Vectors for testing + u = rand(N, K) # Update vector + v = rand(N, K) # Action vector + w = zeros(N, K) # Output vector + d = zeros(N, K) p = rand(N, K) t = rand() @@ -182,80 +263,138 @@ end @test issquare(D) @test islinear(D) - ans = (p * t) .* u - @test D(u, p, t) ≈ ans - v = copy(u) - @test D(v, u, p, t) ≈ ans + # Expected result after update + expected = (p * t) .* v + + # Test with new interface - different vectors for update and action + @test D(v, u, p, t) ≈ expected + + # Test in-place operation + copy!(w, zeros(N, K)) + D(w, v, u, p, t) + @test w ≈ expected end @testset "AffineOperator" begin - u = rand(N, K) + # Vectors for testing + u = rand(N, K) # Update vector + v = rand(N, K) # Action vector + w = zeros(N, K) # Output vector + A = rand(N, N) B = rand(N, N) D = Diagonal(A) b = rand(N, K) α = rand() β = rand() + p = nothing + t = 0.0 L = AffineOperator(MatrixOperator(A), MatrixOperator(B), b) @test isconstant(L) @test issquare(L) @test !islinear(L) - @test L * u ≈ A * u + B * b - v = 
rand(N, K) - @test mul!(v, L, u) ≈ A * u + B * b - v = rand(N, K) - w = copy(v) - @test mul!(v, L, u, α, β) ≈ α * (A * u + B * b) + β * w + # Test with new interface + expected = A * v + B * b + @test L(v, u, p, t) ≈ expected + + # Test in-place operation + copy!(w, zeros(N, K)) + L(w, v, u, p, t) + @test w ≈ expected + + # Test in-place with scaling + copy!(w, rand(N, K)) + orig_w = copy(w) + L(w, v, u, p, t, α, β) + @test w ≈ α * expected + β * orig_w L = AffineOperator(MatrixOperator(D), MatrixOperator(B), b) @test issquare(L) @test !islinear(L) + # Test division operations @test L \ u ≈ D \ (u - B * b) - v = rand(N, K) - @test ldiv!(v, L, u) ≈ D \ (u - B * b) + + copy!(w, zeros(N, K)) + ldiv!(w, L, u) + @test w ≈ D \ (u - B * b) + + # Existing test for in-place ldiv! v = copy(u) - @test ldiv!(L, u) ≈ D \ (v - B * b) + ldiv!(L, u) + @test u ≈ D \ (v - B * b) L = AddVector(b) @test issquare(L) @test !islinear(L) - @test L * u ≈ u + b + # Test with new interface + expected = v + b + @test L(v, u, p, t) ≈ expected @test L \ u ≈ u - b - v = rand(N, K) - @test mul!(v, L, u) ≈ u + b - v = rand(N, K) - w = copy(v) - @test mul!(v, L, u, α, β) ≈ α * (u + b) + β * w - v = rand(N, K) - @test ldiv!(v, L, u) ≈ u - b + + # Test in-place operation + copy!(w, zeros(N, K)) + L(w, v, u, p, t) + @test w ≈ expected + + # Test in-place with scaling + copy!(w, rand(N, K)) + orig_w = copy(w) + L(w, v, u, p, t, α, β) + @test w ≈ α * expected + β * orig_w + + # Test division operations + copy!(w, zeros(N, K)) + ldiv!(w, L, u) + @test w ≈ u - b + + # Existing test for in-place ldiv! 
v = copy(u) - @test ldiv!(L, u) ≈ v - b + ldiv!(L, u) + @test u ≈ v - b L = AddVector(MatrixOperator(B), b) @test issquare(L) @test !islinear(L) - @test L * u ≈ u + B * b + # Test with new interface + expected = v + B * b + @test L(v, u, p, t) ≈ expected @test L \ u ≈ u - B * b - v = rand(N, K) - @test mul!(v, L, u) ≈ u + B * b - v = rand(N, K) - w = copy(v) - @test mul!(v, L, u, α, β) ≈ α * (u + B * b) + β * w - v = rand(N, K) - @test ldiv!(v, L, u) ≈ u - B * b + + # Test in-place operation + copy!(w, zeros(N, K)) + L(w, v, u, p, t) + @test w ≈ expected + + # Test in-place with scaling + copy!(w, rand(N, K)) + orig_w = copy(w) + L(w, v, u, p, t, α, β) + @test w ≈ α * expected + β * orig_w + + # Test division operations + copy!(w, zeros(N, K)) + ldiv!(w, L, u) + @test w ≈ u - B * b + + # Existing test for in-place ldiv! v = copy(u) - @test ldiv!(L, u) ≈ v - B * b + ldiv!(L, u) + @test u ≈ v - B * b end @testset "AffineOperator update test" begin + # Vectors for testing + u = rand(N, K) # Update vector + v = rand(N, K) # Action vector + w = zeros(N, K) # Output vector + A = rand(N, N) B = rand(N, N) - u = rand(N, K) p = rand(N, K) t = rand() α = rand() @@ -267,155 +406,201 @@ end @test !isconstant(L) + # Expected updated bias and result b = p * t - ans = A * u + B * b - @test L(u, p, t) ≈ ans - v = copy(u) - @test L(v, u, p, t) ≈ ans - v = rand(N, K) - w = copy(v) - @test L(v, u, p, t, α, β) ≈ α * ans + β * w + expected = A * v + B * b + + # Test with new interface - different vectors for update and action + @test L(v, u, p, t) ≈ expected + + # Test in-place operation + copy!(w, zeros(N, K)) + L(w, v, u, p, t) + @test w ≈ expected + + # Test in-place with scaling + copy!(w, rand(N, K)) + orig_w = copy(w) + L(w, v, u, p, t, α, β) + @test w ≈ α * expected + β * orig_w end -@testset "TensorProductOperator" begin - for square in [false, true] #for K in [1, K] - m1, n1 = 3, 5 - m2, n2 = 7, 11 - m3, n3 = 13, 17 - - if square - n1, n2, n3 = m1, m2, m3 - end - - A = 
rand(m1, n1) - B = rand(m2, n2) - C = rand(m3, n3) - α = rand() - β = rand() - - AB = kron(A, B) - ABC = kron(A, B, C) - - # test Base.kron overload - # ensure kron(mat, mat) is not a TensorProductOperator - @test !isa(AB, AbstractSciMLOperator) - @test !isa(ABC, AbstractSciMLOperator) - - # test Base.kron overload - _A = rand(N, N) - @test kron(_A, MatrixOperator(_A)) isa TensorProductOperator - @test kron(MatrixOperator(_A), _A) isa TensorProductOperator - - @test kron(MatrixOperator(_A), MatrixOperator(_A)) isa TensorProductOperator - - # Inputs - u2 = rand(n1 * n2, K) - u3 = rand(n1 * n2 * n3, K) - # Outputs - v2 = rand(m1 * m2, K) - v3 = rand(m1 * m2 * m3, K) - - # Outputs - v2 = rand(m1 * m2, K) - v3 = rand(m1 * m2 * m3, K) - - opAB = TensorProductOperator(A, B) - opABC = TensorProductOperator(A, B, C) - - @test opAB isa TensorProductOperator - @test opABC isa TensorProductOperator - - @test isconstant(opAB) - @test isconstant(opABC) - - @test islinear(opAB) - @test islinear(opABC) - - if square - @test issquare(opAB) - @test issquare(opABC) - else - @test !issquare(opAB) - @test !issquare(opABC) - end - - @test AB ≈ convert(AbstractMatrix, opAB) - @test ABC ≈ convert(AbstractMatrix, opABC) - - # factorization tests - opAB_F = factorize(opAB) - opABC_F = factorize(opABC) - - @test isconstant(opAB_F) - @test isconstant(opABC_F) - - @test opAB_F isa TensorProductOperator - @test opABC_F isa TensorProductOperator - - @test AB ≈ convert(AbstractMatrix, opAB_F) - @test ABC ≈ convert(AbstractMatrix, opABC_F) - - @test AB * u2 ≈ opAB * u2 - @test ABC * u3 ≈ opABC * u3 - - @test AB \ v2 ≈ opAB \ v2 ≈ opAB_F \ v2 - @test ABC \ v3 ≈ opABC \ v3 ≈ opABC_F \ v3 - - @test !iscached(opAB) - @test !iscached(opABC) - - @test !iscached(opAB_F) - @test !iscached(opABC_F) - - opAB = cache_operator(opAB, u2) - opABC = cache_operator(opABC, u3) +@testset "TensorProductOperator, square = $square" for square in [false, true] + m1, n1 = 3, 5 + m2, n2 = 7, 11 + m3, n3 = 13, 17 - 
opAB_F = cache_operator(opAB_F, u2) - opABC_F = cache_operator(opABC_F, u3) - - @test iscached(opAB) - @test iscached(opABC) - - @test iscached(opAB_F) - @test iscached(opABC_F) - - N2 = n1 * n2 - N3 = n1 * n2 * n3 - M2 = m1 * m2 - M3 = m1 * m2 * m3 - - v2 = rand(M2, K) - @test mul!(v2, opAB, u2) ≈ AB * u2 - v3 = rand(M3, K) - @test mul!(v3, opABC, u3) ≈ ABC * u3 + if square + n1, n2, n3 = m1, m2, m3 + end + A = rand(m1, n1) + B = rand(m2, n2) + C = rand(m3, n3) + α = rand() + β = rand() + p = nothing + t = 0.0 + + AB = kron(A, B) + ABC = kron(A, B, C) + + # test Base.kron overload + # ensure kron(mat, mat) is not a TensorProductOperator + @test !isa(AB, AbstractSciMLOperator) + @test !isa(ABC, AbstractSciMLOperator) + + # test Base.kron overload + _A = rand(N, N) + @test kron(_A, MatrixOperator(_A)) isa TensorProductOperator + @test kron(MatrixOperator(_A), _A) isa TensorProductOperator + + @test kron(MatrixOperator(_A), MatrixOperator(_A)) isa TensorProductOperator + + # Inputs/Update vectors + u2 = rand(n1 * n2, K) + u3 = rand(n1 * n2 * n3, K) + + # Action vectors (same as update vectors initially) + v2 = copy(u2) + v3 = copy(u3) + + # Output vectors + w2 = zeros(m1 * m2, K) + w3 = zeros(m1 * m2 * m3, K) + + opAB = TensorProductOperator(A, B) + opABC = TensorProductOperator(A, B, C) + + @test opAB isa TensorProductOperator + @test opABC isa TensorProductOperator + + @test isconstant(opAB) + @test isconstant(opABC) + + @test islinear(opAB) + @test islinear(opABC) + + if square + @test issquare(opAB) + @test issquare(opABC) + else + @test !issquare(opAB) + @test !issquare(opABC) + end + + @test AB ≈ convert(AbstractMatrix, opAB) + @test ABC ≈ convert(AbstractMatrix, opABC) + + # factorization tests + opAB_F = factorize(opAB) + opABC_F = factorize(opABC) + + @test isconstant(opAB_F) + @test isconstant(opABC_F) + + @test opAB_F isa TensorProductOperator + @test opABC_F isa TensorProductOperator + + @test AB ≈ convert(AbstractMatrix, opAB_F) + @test ABC ≈ 
convert(AbstractMatrix, opABC_F) + + # Test with new interface + @test AB * v2 ≈ opAB(v2, u2, p, t) + @test ABC * v3 ≈ opABC(v3, u3, p, t) + + @test AB \ w2 ≈ opAB \ w2 + @test AB \ w2 ≈ opAB_F \ w2 + @test ABC \ w3 ≈ opABC \ w3 + @test ABC \ w3 ≈ opABC_F \ w3 + + @test !iscached(opAB) + @test !iscached(opABC) + + @test !iscached(opAB_F) + @test !iscached(opABC_F) + + opAB = cache_operator(opAB, u2) + opABC = cache_operator(opABC, u3) + + opAB_F = cache_operator(opAB_F, u2) + opABC_F = cache_operator(opABC_F, u3) + + @test iscached(opAB) + @test iscached(opABC) + + @test iscached(opAB_F) + @test iscached(opABC_F) + + N2 = n1 * n2 + N3 = n1 * n2 * n3 + M2 = m1 * m2 + M3 = m1 * m2 * m3 + + # Test in-place operations with new interface + v2 = rand(n1 * n2, K) # Action vector + w2 = zeros(M2, K) # Output vector + opAB(w2, v2, u2, p, t) + @test w2 ≈ AB * v2 + + v3 = rand(n1 * n2 * n3, K) # Action vector + w3 = zeros(M3, K) # Output vector + opABC(w3, v3, u3, p, t) + @test w3 ≈ ABC * v3 + + # Test in-place with scaling + v2 = rand(n1 * n2, K) # Action vector + w2 = rand(M2, K) # Output vector + orig_w2 = copy(w2) + opAB(w2, v2, u2, p, t, α, β) + @test w2 ≈ α * AB * v2 + β * orig_w2 + + v3 = rand(n1 * n2 * n3, K) # Action vector + w3 = rand(M3, K) # Output vector + orig_w3 = copy(w3) + opABC(w3, v3, u3, p, t, α, β) + @test w3 ≈ α * ABC * v3 + β * orig_w3 + + if square + # Test division operations with new interface + v2 = rand(M2, K) # Action vector (size of output space) + u2 = rand(N2, K) # Update vector (size of input space) + w2 = zeros(N2, K) # Output vector (size of input space) + + # ldiv! with new interface + ldiv!(w2, opAB_F, v2) + @test w2 ≈ AB \ v2 + + v3 = rand(M3, K) # Action vector + u3 = rand(N3, K) # Update vector + w3 = zeros(N3, K) # Output vector + ldiv!(w3, opABC_F, v3) + @test w3 ≈ ABC \ v3 + + # In-place ldiv! 
(original style) v2 = rand(M2, K) - w2 = copy(v2) - @test mul!(v2, opAB, u2, α, β) ≈ α * AB * u2 + β * w2 + u2 = copy(v2) + ldiv!(opAB_F, v2) + @test v2 ≈ AB \ u2 + v3 = rand(M3, K) - w3 = copy(v3) - @test mul!(v3, opABC, u3, α, β) ≈ α * ABC * u3 + β * w3 - - if square - u2 = rand(N2, K) - @test ldiv!(u2, opAB_F, v2) ≈ AB \ v2 - u3 = rand(N3, K) - @test ldiv!(u3, opABC_F, v3) ≈ ABC \ v3 - - v2 = copy(u2) - @test ldiv!(opAB_F, u2) ≈ AB \ v2 - v3 = copy(u3) - @test ldiv!(opABC_F, u3) ≈ ABC \ v3 - else # TODO - u2 = rand(N2, K) - if VERSION < v"1.9-" - @test_broken ldiv!(u2, opAB_F, v2) ≈ AB \ v2 - else - @test ldiv!(u2, opAB_F, v2) ≈ AB \ v2 - end - u3 = rand(N3, K) - @test_broken ldiv!(u3, opABC_F, v3) ≈ ABC \ v3 # errors + u3 = copy(v3) + ldiv!(opABC_F, v3) + @test v3 ≈ ABC \ u3 + else # TODO + v2 = rand(M2, K) # Action vector + u2 = rand(N2, K) # Update vector + w2 = zeros(N2, K) # Output vector + + if VERSION < v"1.9-" + @test_broken ldiv!(w2, opAB_F, v2) ≈ AB \ v2 + else + @test ldiv!(w2, opAB_F, v2) ≈ AB \ v2 end - end #end -end -# + + v3 = rand(M3, K) # Action vector + u3 = rand(N3, K) # Update vector + w3 = zeros(N3, K) # Output vector + @test_broken ldiv!(w3, opABC_F, v3) ≈ ABC \ v3 # errors + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 83a761ff..b79310cf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,22 +1,40 @@ -using SafeTestsets +using SafeTestsets, Test, Pkg +const GROUP = get(ENV, "GROUP", "All") + +function activate_downstream_env() + Pkg.activate("downstream") + Pkg.develop(PackageSpec(path = dirname(@__DIR__))) + Pkg.instantiate() +end @time begin - @time @safetestset "Scalar Operators" begin - include("scalar.jl") - end - @time @safetestset "Basic Operators" begin - include("basic.jl") - end - @time @safetestset "Matrix Operators" begin - include("matrix.jl") - end - @time @safetestset "Function Operator" begin - include("func.jl") - end - @time @safetestset "Full tests" begin - 
include("total.jl") - end - @time @safetestset "Zygote.jl" begin - include("zygote.jl") + @testset "SciMLOperators" begin + if GROUP == "All" || GROUP == "Core" + @time @safetestset "Scalar Operators" begin + include("scalar.jl") + end + @time @safetestset "Basic Operators" begin + include("basic.jl") + end + @time @safetestset "Matrix Operators" begin + include("matrix.jl") + end + @time @safetestset "Function Operator" begin + include("func.jl") + end + @time @safetestset "Full tests" begin + include("total.jl") + end + @time @safetestset "Zygote.jl" begin + include("zygote.jl") + end + end + # Independent `if`, not `elseif`: with GROUP == "All" the Core branch above always matches, so an `elseif` would never run the downstream suite. + if GROUP == "All" || GROUP == "Downstream" + activate_downstream_env() + @time @safetestset "AllocCheck" begin + include("downstream/alloccheck.jl") + end + end end end diff --git a/test/scalar.jl b/test/scalar.jl index ac9282f4..1bd64066 100644 --- a/test/scalar.jl +++ b/test/scalar.jl @@ -1,4 +1,3 @@ -# using SciMLOperators using SciMLOperators: AbstractSciMLScalarOperator, ComposedScalarOperator, @@ -8,13 +7,13 @@ using SciMLOperators: AbstractSciMLScalarOperator, AddedOperator, ScaledOperator -using LinearAlgebra, Random +using LinearAlgebra, Random, Test Random.seed!(0) N = 8 K = 12 -@testset "ScalarOperator" begin +@testset "ScalarOperator Basic Operations" begin a = rand() b = rand() x = rand() @@ -32,11 +31,13 @@ K = 12 @test size(α) == () @test isconstant(α) + # Original lmul!/rmul! tests v = copy(u) @test lmul!(α, u) ≈ v * x v = copy(u) @test rmul!(u, α) ≈ x * v + # Original mul!/ldiv! tests v = rand(N, K) @test mul!(v, α, u) ≈ u * x v = rand(N, K) @@ -48,32 +49,61 @@ K = 12 w = copy(u) @test ldiv!(α, u) ≈ w / x + # Original axpy! 
test X = rand(N, K) Y = rand(N, K) Z = copy(Y) - a = rand() - aa = ScalarOperator(a) - @test axpy!(aa, X, Y) ≈ a * X + Z + a_scalar = rand() + aa = ScalarOperator(a_scalar) + @test axpy!(aa, X, Y) ≈ a_scalar * X + Z + + # Tests with the new interface + v = copy(u) # Action vector + w = zeros(N, K) # Output vector + + # Test with new interface + result = α(v, u, nothing, 0.0) + @test result ≈ v * x + + # Test in-place operations + α(w, v, u, nothing, 0.0) + @test w ≈ v * x + + # Test in-place operations with scaling + orig_w = rand(N, K) + copy!(w, orig_w) + α(w, v, u, nothing, 0.0, a, b) + @test w ≈ a * (x * v) + b * orig_w +end - # Test that ScalarOperator's remain AbstractSciMLScalarOperator's under common ops +@testset "ScalarOperator Combinations" begin + x = rand() + α = ScalarOperator(x) + u = rand(N, K) # Update vector + v = rand(N, K) # Action vector + + # Test scalar operator combinations β = α + α @test β isa AddedScalarOperator - @test β * u ≈ x * u + x * u + @test β(v, u, nothing, 0.0) ≈ x * v + x * v + @test β * u ≈ x * u + x * u # Original style test @inferred convert(Float32, β) @test convert(Number, β) ≈ x + x β = α * α @test β isa ComposedScalarOperator - @test β * u ≈ x * x * u + @test β(v, u, nothing, 0.0) ≈ x * x * v + @test β * u ≈ x * x * u # Original style test @inferred convert(Float32, β) @test convert(Number, β) ≈ x * x β = inv(α) @test β isa InvertedScalarOperator - @test β * u ≈ 1 / x * u + @test β(v, u, nothing, 0.0) ≈ (1 / x) * v + @test β * u ≈ (1 / x) * u # Original style test @inferred convert(Float32, β) @test convert(Number, β) ≈ 1 / x - + β = α * inv(α) @test β isa ComposedScalarOperator @test β * u ≈ u @@ -85,22 +115,33 @@ K = 12 @test β * u ≈ u @inferred convert(Float32, β) @test convert(Number, β) ≈ true - + # Test combination with other operators for op in (MatrixOperator(rand(N, N)), SciMLOperators.IdentityOperator(N)) @test α + op isa SciMLOperators.AddedOperator @test (α + op) * u ≈ x * u + op * u + + L = α + op + @test L 
isa SciMLOperators.AddedOperator + @test L(v, u, nothing, 0.0) ≈ x * v + op * v + @test α * op isa SciMLOperators.ScaledOperator @test (α * op) * u ≈ x * (op * u) + + L = α * op + @test L isa SciMLOperators.ScaledOperator + @test L(v, u, nothing, 0.0) ≈ x * (op * v) + + # Division tests from original @test all(map(T -> (T isa SciMLOperators.ScaledOperator), (α / op, op / α, op \ α, α \ op))) @test (α / op) * u ≈ (op \ α) * u ≈ α * (op \ u) @test (op / α) * u ≈ (α \ op) * u ≈ 1 / α * op * u end - - # ensure composedscalaroperators doesn't nest - α = ScalarOperator(rand()) - L = α * (α * α) * α + + # Test for ComposedScalarOperator nesting (from original) + α_new = ScalarOperator(rand()) + L = α_new * (α_new * α_new) * α_new @test L isa ComposedScalarOperator for op in L.ops @test !isa(op, ComposedScalarOperator) @@ -109,24 +150,26 @@ end @testset "ScalarOperator scalar argument test" begin a = rand() - u = rand() - v = rand() + u = rand() # Update scalar + v = rand() # Action scalar p = nothing t = 0.0 α = ScalarOperator(a) - @test α(u, p, t) ≈ u * a - @test_throws ArgumentError α(v, u, p, t) - @test_throws ArgumentError α(v, u, p, t, 1, 2) + @test_throws MethodError α(u, p, t) ≈ u * a # Original style + @test α(v, u, p, t) ≈ v * a # New interface + @test_throws ArgumentError α(v, u, p, t, 1, 2) # Keep error test end @testset "ScalarOperator update test" begin - u = ones(N, K) - v = zeros(N, K) + u = ones(N, K) # Update vector + v = ones(N, K) # Action vector + w = zeros(N, K) # Output vector p = 2.0 t = 4.0 - a = rand() - b = rand() + + c = rand() + d = rand() α = ScalarOperator(0.0; update_func = (a, u, p, t) -> p) β = ScalarOperator(0.0; update_func = (a, u, p, t) -> t) @@ -140,42 +183,71 @@ end @test convert(Number, α) ≈ 0.0 @test convert(Number, β) ≈ 0.0 + # Test update_coefficients update_coefficients!(α, u, p, t) update_coefficients!(β, u, p, t) @test convert(Number, α) ≈ p @test convert(Number, β) ≈ t - @test α(u, p, t) ≈ p * u - v = rand(N, K) - @test α(v, 
u, p, t) ≈ p * u - v = rand(N, K) - w = copy(v) - @test α(v, u, p, t, a, b) ≈ a * p * u + b * w - - @test β(u, p, t) ≈ t * u - v = rand(N, K) - @test β(v, u, p, t) ≈ t * u - v = rand(N, K) - w = copy(v) - @test β(v, u, p, t, a, b) ≈ a * t * u + b * w - + # Original style tests + @test_throws MethodError α(u, p, t) ≈ p * u + @test_throws MethodError β(u, p, t) ≈ t * u + + # Tests with new interface + @test α(v, u, p, t) ≈ p * v + @test β(v, u, p, t) ≈ t * v + + # Test in-place with scaling + orig_w = rand(N, K) + copy!(w, orig_w) + α(w, v, u, p, t, c, d) + @test w ≈ c * p * v + d * orig_w + + # Retain original test with random vectors + v_rand = rand(N, K) + @test α(v_rand, u, p, t) ≈ p * v_rand + v_rand = rand(N, K) + w_rand = copy(v_rand) + @test_broken α(v_rand, u, p, t, c, d) ≈ c * p * u + d * w_rand + + # Test operator combinations num = α + 2 / β * 3 - 4 val = p + 2 / t * 3 - 4 - + @test convert(Number, num) ≈ val - # Test scalar operator which expects keyword argument to update, - # modeled in the style of a DiffEq W-operator. 
+ # Test with keyword arguments γ = ScalarOperator(0.0; update_func = (args...; dtgamma) -> dtgamma, - accepted_kwargs = (:dtgamma,)) - + accepted_kwargs = (:dtgamma,)) + dtgamma = rand() - @test γ(u, p, t; dtgamma) ≈ dtgamma * u - @test γ(v, u, p, t; dtgamma) ≈ dtgamma * u - + # Original tests + @test_throws MethodError γ(u, p, t; dtgamma) ≈ dtgamma * u + + # New interface tests + @test γ(v, u, p, t; dtgamma) ≈ dtgamma * v + + # In-place test with keywords + w_test = zeros(N, K) + γ(w_test, v, u, p, t; dtgamma) + @test w_test ≈ dtgamma * v + γ_added = γ + α - @test γ_added(u, p, t; dtgamma) ≈ (dtgamma + p) * u - @test γ_added(v, u, p, t; dtgamma) ≈ (dtgamma + p) * u -end -# + # Original tests + @test_throws MethodError γ_added(u, p, t; dtgamma) ≈ (dtgamma + p) * u + + # New interface tests + @test γ_added(v, u, p, t; dtgamma) ≈ (dtgamma + p) * v + + # In-place test with keywords for combined operator + w_test = zeros(N, K) + γ_added(w_test, v, u, p, t; dtgamma) + @test w_test ≈ (dtgamma + p) * v + + # In-place test with scaling and keywords + w_test = rand(N, K) + w_orig = copy(w_test) + γ_added(w_test, v, u, p, t, c, d; dtgamma) + @test w_test ≈ c * (dtgamma + p) * v + d * w_orig +end \ No newline at end of file diff --git a/test/total.jl b/test/total.jl index 6530cb55..6033f5f9 100644 --- a/test/total.jl +++ b/test/total.jl @@ -20,11 +20,11 @@ K = 12 m = length(k) P = plan_rfft(x) - fwd(u, p, t) = P * u - bwd(u, p, t) = P \ u + fwd(v, u, p, t) = P * v + bwd(v, u, p, t) = P \ v - fwd(du, u, p, t) = mul!(du, P, u) - bwd(du, u, p, t) = ldiv!(du, P, u) + fwd(w, v, u, p, t) = mul!(w, P, v) + bwd(w, v, u, p, t) = ldiv!(w, P, v) ftr = FunctionOperator(fwd, x, im * k; T = ComplexF64, op_adjoint = bwd, @@ -39,18 +39,18 @@ K = 12 Dx = cache_operator(Dx, x) D2x = cache_operator(Dx * Dx, x) - u = @. sin(5x)cos(7x) - du = @. 5cos(5x)cos(7x) - 7sin(5x)sin(7x) - d2u = @. 5(-5sin(5x)cos(7x) - 7cos(5x)sin(7x)) + + v = @. sin(5x)cos(7x) + w = @. 
5cos(5x)cos(7x) - 7sin(5x)sin(7x) + w2x = @. 5(-5sin(5x)cos(7x) - 7cos(5x)sin(7x)) + -7(5cos(5x)sin(7x) + 7sin(5x)cos(7x)) - @test ≈(Dx * u, du; atol = 1e-8) - @test ≈(D2x * u, d2u; atol = 1e-8) + @test ≈(Dx * v, w; atol = 1e-8) + @test ≈(D2x * v, w2x; atol = 1e-8) - v = copy(u) - @test ≈(mul!(v, D2x, u), d2u; atol = 1e-8) - v = copy(u) - @test ≈(mul!(v, Dx, u), du; atol = 1e-8) + w2 = zero(w) + @test ≈(mul!(w2, D2x, v), w2x; atol = 1e-8) + w2 = zero(w) + @test ≈(mul!(w2, Dx, v), w; atol = 1e-8) itr = inv(ftr) ftt = ftr' @@ -84,7 +84,8 @@ end @testset "Operator Algebra" begin N2 = N * N - u = rand(N2, K) + v = rand(N2, K) + u = rand() p = rand() t = rand() @@ -95,9 +96,9 @@ end # FunctionOp _C = rand(N, N) |> Symmetric - f(u, p, t) = _C * u - f(v, u, p, t) = mul!(v, _C, u) - C = FunctionOperator(f, zeros(N); batch = true, issymmetric = true, p = p) + f(v, u, p, t) = _C * v + f(w, v, u, p, t) = mul!(w, _C, v) + C = FunctionOperator(f, zeros(N); batch = true, issymmetric = true, p = p, u = u) # Introduce update function for D dependent on kwarg "matrix" D = MatrixOperator(zeros(N, N); @@ -121,83 +122,93 @@ end DD = Diagonal([D1, D2]) op = TT' * DD * TT - op = cache_operator(op, u) + op = cache_operator(op, v) # Update operator @test_nowarn update_coefficients!(op, u, p, t; diag, matrix) + # Form dense operator manually dense_T1 = kron(A, p * ones(N, N)) dense_T2 = kron(_C, (p * t) .* matrix) dense_DD = Diagonal(vcat(p * ones(N2), p * t * diag)) dense_op = hcat(dense_T1', dense_T2') * dense_DD * vcat(dense_T1, dense_T2) + # Test correctness of op - @test op * u ≈ dense_op * u + @test op * v ≈ dense_op * v + # Test consistency with three-arg mul! - v = rand(N2, K) - @test mul!(v, op, u) ≈ op * u + w = rand(N2, K) + @test mul!(w, op, v) ≈ op * v + # Test consistency with in-place five-arg mul! 
- v = rand(N2, K) - w = copy(v) - @test mul!(v, op, u, α, β) ≈ α * (op * u) + β * w - # Test consistency with operator application form - @test op(u, p, t; diag, matrix) ≈ op * u - v = rand(N2, K) - @test op(v, u, p, t; diag, matrix) ≈ op * u + w = rand(N2, K) + w2 = copy(w) + @test mul!(w, op, v, α, β) ≈ α * (op * v) + β * w2 + + # Create a fresh operator for each test + op_fresh = TT' * DD * TT + op_fresh = cache_operator(op_fresh, v) + # Use in-place update directly in test + result1 = similar(v) + mul!(result1, op_fresh, v) + update_coefficients!(op_fresh, u, p, t; diag, matrix) + @test result1 ≈ dense_op * v end + @testset "Resize! test" begin M1 = 4 M2 = 12 - u = rand(N) - u1 = rand(M1) - u2 = rand(M2) + v = rand(N) + v1 = rand(M1) + v2 = rand(M2) - f(u, p, t) = 2 * u - f(v, u, p, t) = (copy!(v, u); lmul!(2, v)) + f(v, u, p, t) = 2 * v + f(w, v, u, p, t) = (copy!(w, v); lmul!(2, w)) - fi(u, p, t) = 0.5 * u - fi(v, u, p, t) = (copy!(v, u); lmul!(0.5, v)) + fi(v, u, p, t) = 0.5 * v + fi(w, v, u, p, t) = (copy!(w, v); lmul!(0.5, w)) - F = FunctionOperator(f, u, u; islinear = true, op_inverse = fi, issymmetric = true) + F = FunctionOperator(f, v, v; islinear = true, op_inverse = fi, issymmetric = true) - multest(L, u) = @test mul!(zero(u), L, u) ≈ L * u + multest(L, v) = @test mul!(zero(v), L, v) ≈ L * v - function multest(L::SciMLOperators.AdjointOperator, u) - @test mul!(adjoint(zero(u)), adjoint(u), L) ≈ adjoint(u) * L + function multest(L::SciMLOperators.AdjointOperator, v) + @test mul!(adjoint(zero(v)), adjoint(v), L) ≈ adjoint(v) * L end - function multest(L::SciMLOperators.TransposedOperator, u) - @test mul!(transpose(zero(u)), transpose(u), L) ≈ transpose(u) * L + function multest(L::SciMLOperators.TransposedOperator, v) + @test mul!(transpose(zero(v)), transpose(v), L) ≈ transpose(v) * L end - function multest(L::SciMLOperators.InvertedOperator, u) - @test ldiv!(zero(u), L, u) ≈ L \ u + function multest(L::SciMLOperators.InvertedOperator, v) + @test 
ldiv!(zero(v), L, v) ≈ L \ v end for (L, LT) in ((F, FunctionOperator), (F + F, SciMLOperators.AddedOperator), (F * 2, SciMLOperators.ScaledOperator), (F ∘ F, SciMLOperators.ComposedOperator), - (AffineOperator(F, F, u), AffineOperator), + (AffineOperator(F, F, v), AffineOperator), (SciMLOperators.AdjointOperator(F), SciMLOperators.AdjointOperator), (SciMLOperators.TransposedOperator(F), SciMLOperators.TransposedOperator), (SciMLOperators.InvertedOperator(F), SciMLOperators.InvertedOperator), (SciMLOperators.InvertibleOperator(F, F), SciMLOperators.InvertibleOperator)) L = deepcopy(L) - L = cache_operator(L, u) + L = cache_operator(L, v) @test L isa LT @test size(L) == (N, N) - multest(L, u) + multest(L, v) resize!(L, M1) @test size(L) == (M1, M1) - multest(L, u1) + multest(L, v1) resize!(L, M2) @test size(L) == (M2, M2) - multest(L, u2) + multest(L, v2) end # InvertedOperator diff --git a/test/zygote.jl b/test/zygote.jl index 8e3c5232..b5d3defa 100644 --- a/test/zygote.jl +++ b/test/zygote.jl @@ -38,8 +38,8 @@ L_mi = MatrixOperator(zeros(N, N); update_func = inv_update_func) L_aff = AffineOperator(L_mat, L_mat, zeros(N, K); update_func = vec_update_func) L_sca = α * L_mat L_inv = InvertibleOperator(L_mat, L_mi) -L_fun = FunctionOperator((u, p, t) -> Diagonal(p) * u, u0, u0; batch = true, - op_inverse = (u, p, t) -> inv(Diagonal(p)) * u) +L_fun = FunctionOperator((v, u, p, t) -> Diagonal(p) * v, u0, u0; batch = true, + op_inverse = (v, u, p, t) -> inv(Diagonal(p)) * v) Ti = MatrixOperator(zeros(n, n); update_func = tsr_update_func) To = deepcopy(Ti) @@ -66,19 +66,26 @@ for (LType, L) in ((IdentityOperator, IdentityOperator(N)), (AddedScalarOperator, α + α), (ComposedScalarOperator, α * α)) @assert L isa LType + + # Cache the operator for efficient application + L_cached = cache_operator(L, u0) + # Updated loss function using the new interface: + # v is the action vector, u0 is the update vector loss_mul = function (p) v = Diagonal(p) * u0 - w = L(v, p, t) + # Use 
new interface: L(v, u, p, t) + w = L_cached(v, u0, p, t) l = sum(w) end loss_div = function (p) v = Diagonal(p) * u0 - - L = update_coefficients(L, v, p, t) - w = L \ v - + + # Update coefficients first, then apply inverse + L_updated = update_coefficients(L_cached, u0, p, t) + w = L_updated \ v + l = sum(w) end @@ -99,4 +106,4 @@ for (LType, L) in ((IdentityOperator, IdentityOperator(N)), @test !isa(g_div, Nothing) end end -end +end \ No newline at end of file