Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
ff55ed8
SciMLOperators Different defining vectors
divital-coder May 5, 2025
41108ec
Operators in work, modificataions -> operator update interface, opera…
divital-coder May 9, 2025
44f7121
Updated interface for operators, addition of new and modified tests
divital-coder May 12, 2025
17a6a58
Update SciMLOperators.jl
ChrisRackauckas May 12, 2025
06083e9
Update interface.jl
ChrisRackauckas May 12, 2025
81adab6
Modified tests, fixed left.jl as per reviews
divital-coder May 15, 2025
7cd3adf
Fixed Zygote tests
divital-coder May 15, 2025
d38b165
Update runtests.jl
ChrisRackauckas May 15, 2025
51776eb
Added original test lines back
divital-coder May 16, 2025
9926df0
Update test/func.jl
ChrisRackauckas May 16, 2025
7df820d
Update test/func.jl
ChrisRackauckas May 16, 2025
b43e25a
Update test/func.jl
ChrisRackauckas May 16, 2025
d830399
Update test/func.jl
ChrisRackauckas May 16, 2025
ea686c6
Update test/scalar.jl
ChrisRackauckas May 16, 2025
6dc91ef
Update test/scalar.jl
ChrisRackauckas May 16, 2025
3ab5da1
Update test/scalar.jl
ChrisRackauckas May 16, 2025
09c80a2
Update test/scalar.jl
ChrisRackauckas May 16, 2025
efdb2e7
Update test/total.jl
ChrisRackauckas May 16, 2025
3dc0ae0
fix a lot of docstrings
ChrisRackauckas May 16, 2025
7451c30
one more docstring
ChrisRackauckas May 16, 2025
f7a39df
update docs
ChrisRackauckas May 16, 2025
5892a2f
let doc tests run
ChrisRackauckas May 16, 2025
d3a36fc
fix variable names
ChrisRackauckas May 16, 2025
f6f11aa
some more renaming and fix matrix tests
ChrisRackauckas May 16, 2025
9d4d0b5
Fix FunctionOperator definition
ChrisRackauckas May 16, 2025
b70b102
get tests passing
ChrisRackauckas May 17, 2025
b7d2132
fix doc builds
ChrisRackauckas May 17, 2025
fcc3833
fix typo
ChrisRackauckas May 17, 2025
7376381
fix docs
ChrisRackauckas May 17, 2025
3cd098a
alloc test passes on later versions
ChrisRackauckas May 17, 2025
500b89f
fix alloc tests
ChrisRackauckas May 17, 2025
7d105da
split allocs checks
ChrisRackauckas May 17, 2025
098264b
fix new test setup
ChrisRackauckas May 17, 2025
1c47fe6
alloc tests throw
ChrisRackauckas May 17, 2025
596f239
fix pre
ChrisRackauckas May 17, 2025
91e4312
finalize
ChrisRackauckas May 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/SciMLOperators.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
$(README)
"""
module SciMLOperators
module SciMLOperators# Temporary fix while we debug method overwriting issues

using DocStringExtensions

Expand Down
8 changes: 4 additions & 4 deletions src/batch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,12 @@ function LinearAlgebra.ishermitian(L::BatchedDiagonalOperator)
end
LinearAlgebra.isposdef(L::BatchedDiagonalOperator) = isposdef(Diagonal(vec(L.diag)))

function update_coefficients(L::BatchedDiagonalOperator, u, p, t; kwargs...)
@reset L.diag = L.update_func(L.diag, u, p, t; kwargs...)
function update_coefficients(L::BatchedDiagonalOperator, u_update, p, t; kwargs...)
@reset L.diag = L.update_func(L.diag, u_update, p, t; kwargs...)
end

function update_coefficients!(L::BatchedDiagonalOperator, u, p, t; kwargs...)
L.update_func!(L.diag, u, p, t; kwargs...)
function update_coefficients!(L::BatchedDiagonalOperator, u_update, p, t; kwargs...)
L.update_func!(L.diag, u_update, p, t; kwargs...)

nothing
end
Expand Down
14 changes: 7 additions & 7 deletions src/func.jl
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ end
end
@inline __and_val(vs...) = mapreduce(_unwrap_val, *, vs)

function update_coefficients(L::FunctionOperator, u, p, t; kwargs...)
function update_coefficients(L::FunctionOperator, u_update, p, t; kwargs...)

# update p, t
L = set_p(L, p)
Expand All @@ -370,14 +370,14 @@ function update_coefficients(L::FunctionOperator, u, p, t; kwargs...)

isconstant(L) && return L

L = set_op(L, update_coefficients(L.op, u, p, t; filtered_kwargs...))
L = set_op_adjoint(L, update_coefficients(L.op_adjoint, u, p, t; filtered_kwargs...))
L = set_op_inverse(L, update_coefficients(L.op_inverse, u, p, t; filtered_kwargs...))
L = set_op(L, update_coefficients(L.op, u_update, p, t; filtered_kwargs...))
L = set_op_adjoint(L, update_coefficients(L.op_adjoint, u_update, p, t; filtered_kwargs...))
L = set_op_inverse(L, update_coefficients(L.op_inverse, u_update, p, t; filtered_kwargs...))
L = set_op_adjoint_inverse(L,
update_coefficients(L.op_adjoint_inverse, u, p, t; filtered_kwargs...))
update_coefficients(L.op_adjoint_inverse, u_update, p, t; filtered_kwargs...))
end

function update_coefficients!(L::FunctionOperator, u, p, t; kwargs...)
function update_coefficients!(L::FunctionOperator, u_update, p, t; kwargs...)

# update p, t
L.p = p
Expand All @@ -390,7 +390,7 @@ function update_coefficients!(L::FunctionOperator, u, p, t; kwargs...)
isconstant(L) && return

for op in getops(L)
update_coefficients!(op, u, p, t; filtered_kwargs...)
update_coefficients!(op, u_update, p, t; filtered_kwargs...)
end

nothing
Expand Down
193 changes: 187 additions & 6 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,118 @@ end
# operator evaluation interface
###

function (L::AbstractSciMLOperator)(u, p, t; kwargs...)
update_coefficients(L, u, p, t; kwargs...) * u
abstract type OperatorMethodTag end
struct OutOfPlaceTag <: OperatorMethodTag end
struct InPlaceTag <: OperatorMethodTag end
struct ScaledInPlaceTag <: OperatorMethodTag end

"""
$SIGNATURES

Apply the operator L to the vector x, after updating the coefficients of L using u_update.

This method is out-of-place, i.e., it allocates a new vector for the result.

# Arguments
- `x`: The vector to which the operator is applied
- `p`: Parameter object
- `t`: Time parameter
- `u_update`: Vector used to update the operator coefficients (defaults to `x`)
- `kwargs...`: Additional keyword arguments for the update function

# Returns
The result of applying the operator to x

# Example
```julia
L = MatrixOperator(zeros(4,4); update_func = some_update_func)
x = rand(4)
p = some_params
t = 1.0
u_update = rand(4) # Some reference state for updating coefficients

# Update using u_update, then apply operator to x
v = L(x, p, t, u_update)
```
"""
function (L::AbstractSciMLOperator{T})(x, p, t, u_update=x; kwargs...) where {T}
_check_device_match(x, u_update)
_check_size_compatibility(L, u_update, x)
update_coefficients(L, u_update, p, t; kwargs...) * x
end
function (L::AbstractSciMLOperator)(du, u, p, t; kwargs...)
(update_coefficients!(L, u, p, t; kwargs...); mul!(du, L, u))

"""
$SIGNATURES

Apply the operator L to the vector u in-place, storing the result in du,
after updating the coefficients of L using u_update.

# Arguments
- `du`: The output vector where the result is stored
- `u`: The vector to which the operator is applied
- `p`: Parameter object
- `t`: Time parameter
- `u_update`: Vector used to update the operator coefficients (defaults to `u`)
- `kwargs...`: Additional keyword arguments for the update function

# Example
```julia
L = MatrixOperator(zeros(4,4); update_func = some_update_func)
u = rand(4)
du = similar(u)
p = some_params
t = 1.0
u_update = rand(4) # Some reference state for updating coefficients

# Update using u_update, then apply operator to u, storing in du
L(du, u, p, t, u_update)
```
"""
function (L::AbstractSciMLOperator{T})(du::AbstractArray, u::AbstractArray, p, t, u_update=u; kwargs...) where {T}
_check_device_match(du, u, u_update)
_check_size_compatibility(L, u_update, u, du)
update_coefficients!(L, u_update, p, t; kwargs...)
mul!(du, L, u)
return du # Explicitly return du
end
function (L::AbstractSciMLOperator)(du, u, p, t, α, β; kwargs...)
(update_coefficients!(L, u, p, t; kwargs...); mul!(du, L, u, α, β))

"""
$SIGNATURES

Apply the operator L to vector u with scaling factors α and β, computing du = α*L*u + β*du,
after updating the coefficients of L using u_update.

# Arguments
- `du`: The output vector where the result is accumulated
- `u`: The vector to which the operator is applied
- `p`: Parameter object
- `t`: Time parameter
- `α`: Scaling factor for L*u
- `β`: Scaling factor for the existing value in du
- `u_update`: Vector used to update the operator coefficients (defaults to `u`)
- `kwargs...`: Additional keyword arguments for the update function

# Example
```julia
L = MatrixOperator(zeros(4,4); update_func = some_update_func)
u = rand(4)
du = rand(4)
p = some_params
t = 1.0
α = 2.0
β = 1.0
u_update = rand(4) # Some reference state for updating coefficients

# Compute du = α*L*u + β*du
L(du, u, p, t, α, β, u_update)
```
"""
function (L::AbstractSciMLOperator{T})(du::AbstractArray, u::AbstractArray, p, t, α, β, u_update=u; kwargs...) where {T}
_check_device_match(du, u, u_update)
_check_size_compatibility(L, u_update, u, du)
update_coefficients!(L, u_update, p, t; kwargs...)
mul!(du, L, u, α, β)
return du # Explicitly return du
end

function (L::AbstractSciMLOperator)(du::Number, u::Number, p, t, args...; kwargs...)
Expand All @@ -125,6 +229,83 @@ function (L::AbstractSciMLOperator)(du::Number, u::Number, p, t, args...; kwargs
throw(ArgumentError(msg))
end

"""
@private

Check that all vectors are on the same device (CPU/GPU).
This function is a no-op in the standard implementation but can be
extended by packages that provide GPU support.
"""
function _check_device_match(args...)
# Default implementation - no device checking in base package
# This would be extended by GPU-supporting packages
nothing
end

"""
@private

Verify that the sizes of vectors are compatible with the operator and with each other.
"""
function _check_size_compatibility(L::AbstractSciMLOperator, u_update, u, du=nothing)
# Special case for scalar operators which have size() = ()
if L isa AbstractSciMLScalarOperator
# Scalar operators can operate on any size inputs
# Just check batch dimensions if present
if u isa AbstractMatrix && u_update isa AbstractMatrix
if size(u, 2) != size(u_update, 2)
throw(DimensionMismatch(
"Batch dimension of u ($(size(u, 2))) must match batch dimension of u_update ($(size(u_update, 2)))"))
end
end

if du !== nothing && u isa AbstractMatrix && du isa AbstractMatrix
if size(u, 2) != size(du, 2)
throw(DimensionMismatch(
"Batch dimension of u ($(size(u, 2))) must match batch dimension of du ($(size(du, 2)))"))
end
end

return nothing
end

# For regular operators with dimensions
# Verify u_update has compatible size for updating operator
if size(u_update, 1) != size(L, 2)
throw(DimensionMismatch(
"Size of u_update ($(size(u_update, 1))) must match the input dimension of operator ($(size(L, 2)))"))
end

# Verify u has compatible size for operator application
if size(u, 1) != size(L, 2)
throw(DimensionMismatch(
"Size of u ($(size(u, 1))) must match the input dimension of operator ($(size(L, 2)))"))
end

# If du is provided, verify it has compatible size for storing the result
if du !== nothing && size(du, 1) != size(L, 1)
throw(DimensionMismatch(
"Size of du ($(size(du, 1))) must match the output dimension of operator ($(size(L, 1)))"))
end

# Verify batch dimensions match if present
if u isa AbstractMatrix && u_update isa AbstractMatrix
if size(u, 2) != size(u_update, 2)
throw(DimensionMismatch(
"Batch dimension of u ($(size(u, 2))) must match batch dimension of u_update ($(size(u_update, 2)))"))
end
end

if du !== nothing && u isa AbstractMatrix && du isa AbstractMatrix
if size(u, 2) != size(du, 2)
throw(DimensionMismatch(
"Batch dimension of u ($(size(u, 2))) must match batch dimension of du ($(size(du, 2)))"))
end
end

nothing
end

###
# operator caching interface
###
Expand Down
14 changes: 14 additions & 0 deletions src/left.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,4 +157,18 @@ for (op, LType, VType) in ((:adjoint, :AdjointOperator, :AbstractAdjointVecOrMat
u
end
end

###
# Update interfaces for AdjointOperator and TransposedOperator
###

function update_coefficients(L::Union{AdjointOperator,TransposedOperator}, u_update, p, t; kwargs...)
@reset L.L = update_coefficients(L.L, u_update, p, t; kwargs...)
L
end

function update_coefficients!(L::Union{AdjointOperator,TransposedOperator}, u_update, p, t; kwargs...)
update_coefficients!(L.L, u_update, p, t; kwargs...)
nothing
end
#
9 changes: 5 additions & 4 deletions src/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,13 @@ function isconstant(L::MatrixOperator)
update_func_isconstant(L.update_func) & update_func_isconstant(L.update_func!)
end

function update_coefficients(L::MatrixOperator, u, p, t; kwargs...)
@reset L.A = L.update_func(L.A, u, p, t; kwargs...)
function update_coefficients(L::MatrixOperator, u_update, p, t; kwargs...)
@reset L.A = L.update_func(L.A, u_update, p, t; kwargs...)
L
end

function update_coefficients!(L::MatrixOperator, u, p, t; kwargs...)
L.update_func!(L.A, u, p, t; kwargs...)
function update_coefficients!(L::MatrixOperator, u_update, p, t; kwargs...)
L.update_func!(L.A, u_update, p, t; kwargs...)

nothing
end
Expand Down
Loading