Skip to content

Commit 41108ec

Browse files
committed
Operators in work, modificataions -> operator update interface, operator application and evaluation interfaces, tests
1 parent ff55ed8 commit 41108ec

18 files changed

+2005
-381
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SciMLOperators"
22
uuid = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
33
authors = ["Vedant Puri <[email protected]>"]
4-
version = "0.3.13"
4+
version = "0.4.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/SciMLOperators.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
$(README)
33
"""
4-
module SciMLOperators# Temporary fix while we debug method overwriting issues
4+
module SciMLOperators
55

66
using DocStringExtensions
77

src/basic.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,26 @@ for T in SCALINGNUMBERTYPES
218218
end
219219
end
220220

221+
# Special cases for constant scalars. These simplify the structure when applicable
222+
for T in SCALINGNUMBERTYPES[2:end]
223+
@eval function Base.:*::$T, L::ScaledOperator)
224+
isconstant(L.λ) && return ScaledOperator* L.λ, L.L)
225+
return ScaledOperator(L.λ, α * L.L) # Try to propagate the rule
226+
end
227+
@eval function Base.:*(L::ScaledOperator, α::$T)
228+
isconstant(L.λ) && return ScaledOperator* L.λ, L.L)
229+
return ScaledOperator(L.λ, α * L.L) # Try to propagate the rule
230+
end
231+
@eval function Base.:*::$T, L::MatrixOperator)
232+
isconstant(L) && return MatrixOperator* L.A)
233+
return ScaledOperator(α, L) # Going back to the generic case
234+
end
235+
@eval function Base.:*(L::MatrixOperator, α::$T)
236+
isconstant(L) && return MatrixOperator* L.A)
237+
return ScaledOperator(α, L) # Going back to the generic case
238+
end
239+
end
240+
221241
Base.:-(L::AbstractSciMLOperator) = ScaledOperator(-true, L)
222242
Base.:+(L::AbstractSciMLOperator) = L
223243

src/batch.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,12 @@ function LinearAlgebra.ishermitian(L::BatchedDiagonalOperator)
8282
end
8383
LinearAlgebra.isposdef(L::BatchedDiagonalOperator) = isposdef(Diagonal(vec(L.diag)))
8484

85-
function update_coefficients(L::BatchedDiagonalOperator, u_update, p, t; kwargs...)
86-
@reset L.diag = L.update_func(L.diag, u_update, p, t; kwargs...)
85+
function update_coefficients(L::BatchedDiagonalOperator, u, p, t; kwargs...)
86+
@reset L.diag = L.update_func(L.diag, u, p, t; kwargs...)
8787
end
8888

89-
function update_coefficients!(L::BatchedDiagonalOperator, u_update, p, t; kwargs...)
90-
L.update_func!(L.diag, u_update, p, t; kwargs...)
89+
function update_coefficients!(L::BatchedDiagonalOperator, u, p, t; kwargs...)
90+
L.update_func!(L.diag, u, p, t; kwargs...)
9191

9292
nothing
9393
end

src/func.jl

Lines changed: 71 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ end
357357
end
358358
@inline __and_val(vs...) = mapreduce(_unwrap_val, *, vs)
359359

360-
function update_coefficients(L::FunctionOperator, u_update, p, t; kwargs...)
360+
function update_coefficients(L::FunctionOperator, u, p, t; kwargs...)
361361

362362
# update p, t
363363
L = set_p(L, p)
@@ -370,14 +370,14 @@ function update_coefficients(L::FunctionOperator, u_update, p, t; kwargs...)
370370

371371
isconstant(L) && return L
372372

373-
L = set_op(L, update_coefficients(L.op, u_update, p, t; filtered_kwargs...))
374-
L = set_op_adjoint(L, update_coefficients(L.op_adjoint, u_update, p, t; filtered_kwargs...))
375-
L = set_op_inverse(L, update_coefficients(L.op_inverse, u_update, p, t; filtered_kwargs...))
373+
L = set_op(L, update_coefficients(L.op, u, p, t; filtered_kwargs...))
374+
L = set_op_adjoint(L, update_coefficients(L.op_adjoint, u, p, t; filtered_kwargs...))
375+
L = set_op_inverse(L, update_coefficients(L.op_inverse, u, p, t; filtered_kwargs...))
376376
L = set_op_adjoint_inverse(L,
377-
update_coefficients(L.op_adjoint_inverse, u_update, p, t; filtered_kwargs...))
377+
update_coefficients(L.op_adjoint_inverse, u, p, t; filtered_kwargs...))
378378
end
379379

380-
function update_coefficients!(L::FunctionOperator, u_update, p, t; kwargs...)
380+
function update_coefficients!(L::FunctionOperator, u, p, t; kwargs...)
381381

382382
# update p, t
383383
L.p = p
@@ -390,7 +390,7 @@ function update_coefficients!(L::FunctionOperator, u_update, p, t; kwargs...)
390390
isconstant(L) && return
391391

392392
for op in getops(L)
393-
update_coefficients!(op, u_update, p, t; filtered_kwargs...)
393+
update_coefficients!(op, u, p, t; filtered_kwargs...)
394394
end
395395

396396
nothing
@@ -782,4 +782,68 @@ end
782782
function LinearAlgebra.ldiv!(L::FunctionOperator{false}, u::AbstractArray)
783783
@error "LinearAlgebra.ldiv! not defined for out-of-place $L"
784784
end
785+
786+
# Out-of-place: v is action vector, u is update vector
787+
function (L::FunctionOperator)(v::AbstractArray, u, p, t; kwargs...)
788+
L = update_coefficients(L, u, p, t; kwargs...)
789+
_sizecheck(L, v, nothing)
790+
V, _, vec_output = _unvec(L, v, nothing)
791+
792+
# Apply the operator to action vector v after updating with u
793+
if L.traits.outofplace
794+
result = L.op(V, L.p, L.t; L.traits.kwargs...)
795+
return vec_output ? vec(result) : result
796+
else
797+
# For operators without out-of-place methods, use their in-place methods with a temporary
798+
Co = similar(V)
799+
L.op(Co, V, L.p, L.t; L.traits.kwargs...)
800+
return vec_output ? vec(Co) : Co
801+
end
802+
end
803+
804+
# In-place: w is destination, v is action vector, u is update vector
805+
function (L::FunctionOperator)(w::AbstractArray, v::AbstractArray, u, p, t; kwargs...)
806+
update_coefficients!(L, u, p, t; kwargs...)
807+
808+
# Check dimensions
809+
_sizecheck(L, v, w)
810+
V, W, _ = _unvec(L, v, w)
811+
812+
# Apply the operator in-place to action vector v after updating with u
813+
if L.traits.isinplace
814+
L.op(W, V, L.p, L.t; L.traits.kwargs...)
815+
else
816+
# For operators without in-place methods, use their out-of-place methods
817+
result = L.op(V, L.p, L.t; L.traits.kwargs...)
818+
copyto!(W, result)
819+
end
820+
821+
return w
822+
end
823+
824+
# In-place with scaling: w = α*(L*v) + β*w
825+
function (L::FunctionOperator)(w::AbstractArray, v::AbstractArray, u, p, t, α, β; kwargs...)
826+
update_coefficients!(L, u, p, t; kwargs...)
827+
828+
# Check dimensions
829+
_sizecheck(L, v, w)
830+
V, W, _ = _unvec(L, v, w)
831+
832+
# Apply the operator in-place to action vector v with scaling
833+
if L.traits.isinplace && L.traits.has_mul5
834+
# Direct 5-arg mul! if supported
835+
L.op(W, V, L.p, L.t, α, β; L.traits.kwargs...)
836+
elseif L.traits.isinplace
837+
# Use temporary for regular in-place
838+
temp = copy(W)
839+
L.op(W, V, L.p, L.t; L.traits.kwargs...)
840+
axpby!(β, temp, α, W)
841+
else
842+
# Out-of-place with scaling
843+
result = L.op(V, L.p, L.t; L.traits.kwargs...)
844+
axpby!(β, W, α, result)
845+
end
846+
847+
return w
848+
end
785849
#

src/interface.jl

Lines changed: 11 additions & 190 deletions
Original file line numberDiff line numberDiff line change
@@ -108,204 +108,25 @@ end
108108
###
109109
# operator evaluation interface
110110
###
111-
112-
abstract type OperatorMethodTag end
113-
struct OutOfPlaceTag <: OperatorMethodTag end
114-
struct InPlaceTag <: OperatorMethodTag end
115-
struct ScaledInPlaceTag <: OperatorMethodTag end
116-
117-
"""
118-
$SIGNATURES
119-
120-
Apply the operator L to the vector x, after updating the coefficients of L using u_update.
121-
122-
This method is out-of-place, i.e., it allocates a new vector for the result.
123-
124-
# Arguments
125-
- `x`: The vector to which the operator is applied
126-
- `p`: Parameter object
127-
- `t`: Time parameter
128-
- `u_update`: Vector used to update the operator coefficients (defaults to `x`)
129-
- `kwargs...`: Additional keyword arguments for the update function
130-
131-
# Returns
132-
The result of applying the operator to x
133-
134-
# Example
135-
```julia
136-
L = MatrixOperator(zeros(4,4); update_func = some_update_func)
137-
x = rand(4)
138-
p = some_params
139-
t = 1.0
140-
u_update = rand(4) # Some reference state for updating coefficients
141-
142-
# Update using u_update, then apply operator to x
143-
v = L(x, p, t, u_update)
144-
```
145-
"""
146-
function (L::AbstractSciMLOperator{T})(x, p, t, u_update=x; kwargs...) where {T}
147-
_check_device_match(x, u_update)
148-
_check_size_compatibility(L, u_update, x)
149-
update_coefficients(L, u_update, p, t; kwargs...) * x
111+
# Out-of-place: v is action vector, u is update vector
112+
function (L::AbstractSciMLOperator)(v, u, p, t; kwargs...)
113+
update_coefficients(L, u, p, t; kwargs...) * v
150114
end
151-
152-
"""
153-
$SIGNATURES
154-
155-
Apply the operator L to the vector u in-place, storing the result in du,
156-
after updating the coefficients of L using u_update.
157-
158-
# Arguments
159-
- `du`: The output vector where the result is stored
160-
- `u`: The vector to which the operator is applied
161-
- `p`: Parameter object
162-
- `t`: Time parameter
163-
- `u_update`: Vector used to update the operator coefficients (defaults to `u`)
164-
- `kwargs...`: Additional keyword arguments for the update function
165-
166-
# Example
167-
```julia
168-
L = MatrixOperator(zeros(4,4); update_func = some_update_func)
169-
u = rand(4)
170-
du = similar(u)
171-
p = some_params
172-
t = 1.0
173-
u_update = rand(4) # Some reference state for updating coefficients
174-
175-
# Update using u_update, then apply operator to u, storing in du
176-
L(du, u, p, t, u_update)
177-
```
178-
"""
179-
function (L::AbstractSciMLOperator{T})(du::AbstractArray, u::AbstractArray, p, t, u_update=u; kwargs...) where {T}
180-
_check_device_match(du, u, u_update)
181-
_check_size_compatibility(L, u_update, u, du)
182-
update_coefficients!(L, u_update, p, t; kwargs...)
183-
mul!(du, L, u)
184-
return du # Explicitly return du
115+
# In-place: w is destination, v is action vector, u is update vector
116+
function (L::AbstractSciMLOperator)(w, v, u, p, t; kwargs...)
117+
(update_coefficients!(L, u, p, t; kwargs...); mul!(w, L, v))
185118
end
186-
187-
"""
188-
$SIGNATURES
189-
190-
Apply the operator L to vector u with scaling factors α and β, computing du = α*L*u + β*du,
191-
after updating the coefficients of L using u_update.
192-
193-
# Arguments
194-
- `du`: The output vector where the result is accumulated
195-
- `u`: The vector to which the operator is applied
196-
- `p`: Parameter object
197-
- `t`: Time parameter
198-
- `α`: Scaling factor for L*u
199-
- `β`: Scaling factor for the existing value in du
200-
- `u_update`: Vector used to update the operator coefficients (defaults to `u`)
201-
- `kwargs...`: Additional keyword arguments for the update function
202-
203-
# Example
204-
```julia
205-
L = MatrixOperator(zeros(4,4); update_func = some_update_func)
206-
u = rand(4)
207-
du = rand(4)
208-
p = some_params
209-
t = 1.0
210-
α = 2.0
211-
β = 1.0
212-
u_update = rand(4) # Some reference state for updating coefficients
213-
214-
# Compute du = α*L*u + β*du
215-
L(du, u, p, t, α, β, u_update)
216-
```
217-
"""
218-
function (L::AbstractSciMLOperator{T})(du::AbstractArray, u::AbstractArray, p, t, α, β, u_update=u; kwargs...) where {T}
219-
_check_device_match(du, u, u_update)
220-
_check_size_compatibility(L, u_update, u, du)
221-
update_coefficients!(L, u_update, p, t; kwargs...)
222-
mul!(du, L, u, α, β)
223-
return du # Explicitly return du
119+
# In-place with scaling: w = α*(L*v) + β*w
120+
function (L::AbstractSciMLOperator)(w, v, u, p, t, α, β; kwargs...)
121+
(update_coefficients!(L, u, p, t; kwargs...); mul!(w, L, v, α, β))
224122
end
225123

226-
function (L::AbstractSciMLOperator)(du::Number, u::Number, p, t, args...; kwargs...)
227-
msg = """Nonallocating L(v, u, p, t) type methods are not available for
124+
function (L::AbstractSciMLOperator)(w::Number, v::Number, u, p, t, args...; kwargs...)
125+
msg = """Nonallocating L(w, v, u, p, t) type methods are not available for
228126
subtypes of `Number`."""
229127
throw(ArgumentError(msg))
230128
end
231129

232-
"""
233-
@private
234-
235-
Check that all vectors are on the same device (CPU/GPU).
236-
This function is a no-op in the standard implementation but can be
237-
extended by packages that provide GPU support.
238-
"""
239-
function _check_device_match(args...)
240-
# Default implementation - no device checking in base package
241-
# This would be extended by GPU-supporting packages
242-
nothing
243-
end
244-
245-
"""
246-
@private
247-
248-
Verify that the sizes of vectors are compatible with the operator and with each other.
249-
"""
250-
function _check_size_compatibility(L::AbstractSciMLOperator, u_update, u, du=nothing)
251-
# Special case for scalar operators which have size() = ()
252-
if L isa AbstractSciMLScalarOperator
253-
# Scalar operators can operate on any size inputs
254-
# Just check batch dimensions if present
255-
if u isa AbstractMatrix && u_update isa AbstractMatrix
256-
if size(u, 2) != size(u_update, 2)
257-
throw(DimensionMismatch(
258-
"Batch dimension of u ($(size(u, 2))) must match batch dimension of u_update ($(size(u_update, 2)))"))
259-
end
260-
end
261-
262-
if du !== nothing && u isa AbstractMatrix && du isa AbstractMatrix
263-
if size(u, 2) != size(du, 2)
264-
throw(DimensionMismatch(
265-
"Batch dimension of u ($(size(u, 2))) must match batch dimension of du ($(size(du, 2)))"))
266-
end
267-
end
268-
269-
return nothing
270-
end
271-
272-
# For regular operators with dimensions
273-
# Verify u_update has compatible size for updating operator
274-
if size(u_update, 1) != size(L, 2)
275-
throw(DimensionMismatch(
276-
"Size of u_update ($(size(u_update, 1))) must match the input dimension of operator ($(size(L, 2)))"))
277-
end
278-
279-
# Verify u has compatible size for operator application
280-
if size(u, 1) != size(L, 2)
281-
throw(DimensionMismatch(
282-
"Size of u ($(size(u, 1))) must match the input dimension of operator ($(size(L, 2)))"))
283-
end
284-
285-
# If du is provided, verify it has compatible size for storing the result
286-
if du !== nothing && size(du, 1) != size(L, 1)
287-
throw(DimensionMismatch(
288-
"Size of du ($(size(du, 1))) must match the output dimension of operator ($(size(L, 1)))"))
289-
end
290-
291-
# Verify batch dimensions match if present
292-
if u isa AbstractMatrix && u_update isa AbstractMatrix
293-
if size(u, 2) != size(u_update, 2)
294-
throw(DimensionMismatch(
295-
"Batch dimension of u ($(size(u, 2))) must match batch dimension of u_update ($(size(u_update, 2)))"))
296-
end
297-
end
298-
299-
if du !== nothing && u isa AbstractMatrix && du isa AbstractMatrix
300-
if size(u, 2) != size(du, 2)
301-
throw(DimensionMismatch(
302-
"Batch dimension of u ($(size(u, 2))) must match batch dimension of du ($(size(du, 2)))"))
303-
end
304-
end
305-
306-
nothing
307-
end
308-
309130
###
310131
# operator caching interface
311132
###

0 commit comments

Comments
 (0)