Skip to content

Commit ff55ed8

Browse files
committed
SciMLOperators Different defining vectors
1 parent 29ea228 commit ff55ed8

File tree

10 files changed

+384
-101
lines changed

10 files changed

+384
-101
lines changed

src/SciMLOperators.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
$(README)
33
"""
4-
module SciMLOperators
4+
module SciMLOperators# Temporary fix while we debug method overwriting issues
55

66
using DocStringExtensions
77

src/batch.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,12 @@ function LinearAlgebra.ishermitian(L::BatchedDiagonalOperator)
8282
end
8383
LinearAlgebra.isposdef(L::BatchedDiagonalOperator) = isposdef(Diagonal(vec(L.diag)))
8484

85-
function update_coefficients(L::BatchedDiagonalOperator, u, p, t; kwargs...)
86-
@reset L.diag = L.update_func(L.diag, u, p, t; kwargs...)
85+
function update_coefficients(L::BatchedDiagonalOperator, u_update, p, t; kwargs...)
86+
@reset L.diag = L.update_func(L.diag, u_update, p, t; kwargs...)
8787
end
8888

89-
function update_coefficients!(L::BatchedDiagonalOperator, u, p, t; kwargs...)
90-
L.update_func!(L.diag, u, p, t; kwargs...)
89+
function update_coefficients!(L::BatchedDiagonalOperator, u_update, p, t; kwargs...)
90+
L.update_func!(L.diag, u_update, p, t; kwargs...)
9191

9292
nothing
9393
end

src/func.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ end
357357
end
358358
@inline __and_val(vs...) = mapreduce(_unwrap_val, *, vs)
359359

360-
function update_coefficients(L::FunctionOperator, u, p, t; kwargs...)
360+
function update_coefficients(L::FunctionOperator, u_update, p, t; kwargs...)
361361

362362
# update p, t
363363
L = set_p(L, p)
@@ -370,14 +370,14 @@ function update_coefficients(L::FunctionOperator, u, p, t; kwargs...)
370370

371371
isconstant(L) && return L
372372

373-
L = set_op(L, update_coefficients(L.op, u, p, t; filtered_kwargs...))
374-
L = set_op_adjoint(L, update_coefficients(L.op_adjoint, u, p, t; filtered_kwargs...))
375-
L = set_op_inverse(L, update_coefficients(L.op_inverse, u, p, t; filtered_kwargs...))
373+
L = set_op(L, update_coefficients(L.op, u_update, p, t; filtered_kwargs...))
374+
L = set_op_adjoint(L, update_coefficients(L.op_adjoint, u_update, p, t; filtered_kwargs...))
375+
L = set_op_inverse(L, update_coefficients(L.op_inverse, u_update, p, t; filtered_kwargs...))
376376
L = set_op_adjoint_inverse(L,
377-
update_coefficients(L.op_adjoint_inverse, u, p, t; filtered_kwargs...))
377+
update_coefficients(L.op_adjoint_inverse, u_update, p, t; filtered_kwargs...))
378378
end
379379

380-
function update_coefficients!(L::FunctionOperator, u, p, t; kwargs...)
380+
function update_coefficients!(L::FunctionOperator, u_update, p, t; kwargs...)
381381

382382
# update p, t
383383
L.p = p
@@ -390,7 +390,7 @@ function update_coefficients!(L::FunctionOperator, u, p, t; kwargs...)
390390
isconstant(L) && return
391391

392392
for op in getops(L)
393-
update_coefficients!(op, u, p, t; filtered_kwargs...)
393+
update_coefficients!(op, u_update, p, t; filtered_kwargs...)
394394
end
395395

396396
nothing

src/interface.jl

Lines changed: 187 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,14 +109,118 @@ end
109109
# operator evaluation interface
110110
###
111111

112-
function (L::AbstractSciMLOperator)(u, p, t; kwargs...)
113-
update_coefficients(L, u, p, t; kwargs...) * u
112+
abstract type OperatorMethodTag end
113+
struct OutOfPlaceTag <: OperatorMethodTag end
114+
struct InPlaceTag <: OperatorMethodTag end
115+
struct ScaledInPlaceTag <: OperatorMethodTag end
116+
117+
"""
118+
$SIGNATURES
119+
120+
Apply the operator L to the vector x, after updating the coefficients of L using u_update.
121+
122+
This method is out-of-place, i.e., it allocates a new vector for the result.
123+
124+
# Arguments
125+
- `x`: The vector to which the operator is applied
126+
- `p`: Parameter object
127+
- `t`: Time parameter
128+
- `u_update`: Vector used to update the operator coefficients (defaults to `x`)
129+
- `kwargs...`: Additional keyword arguments for the update function
130+
131+
# Returns
132+
The result of applying the operator to x
133+
134+
# Example
135+
```julia
136+
L = MatrixOperator(zeros(4,4); update_func = some_update_func)
137+
x = rand(4)
138+
p = some_params
139+
t = 1.0
140+
u_update = rand(4) # Some reference state for updating coefficients
141+
142+
# Update using u_update, then apply operator to x
143+
v = L(x, p, t, u_update)
144+
```
145+
"""
146+
function (L::AbstractSciMLOperator{T})(x, p, t, u_update=x; kwargs...) where {T}
147+
_check_device_match(x, u_update)
148+
_check_size_compatibility(L, u_update, x)
149+
update_coefficients(L, u_update, p, t; kwargs...) * x
114150
end
115-
function (L::AbstractSciMLOperator)(du, u, p, t; kwargs...)
116-
(update_coefficients!(L, u, p, t; kwargs...); mul!(du, L, u))
151+
152+
"""
153+
$SIGNATURES
154+
155+
Apply the operator L to the vector u in-place, storing the result in du,
156+
after updating the coefficients of L using u_update.
157+
158+
# Arguments
159+
- `du`: The output vector where the result is stored
160+
- `u`: The vector to which the operator is applied
161+
- `p`: Parameter object
162+
- `t`: Time parameter
163+
- `u_update`: Vector used to update the operator coefficients (defaults to `u`)
164+
- `kwargs...`: Additional keyword arguments for the update function
165+
166+
# Example
167+
```julia
168+
L = MatrixOperator(zeros(4,4); update_func = some_update_func)
169+
u = rand(4)
170+
du = similar(u)
171+
p = some_params
172+
t = 1.0
173+
u_update = rand(4) # Some reference state for updating coefficients
174+
175+
# Update using u_update, then apply operator to u, storing in du
176+
L(du, u, p, t, u_update)
177+
```
178+
"""
179+
function (L::AbstractSciMLOperator{T})(du::AbstractArray, u::AbstractArray, p, t, u_update=u; kwargs...) where {T}
180+
_check_device_match(du, u, u_update)
181+
_check_size_compatibility(L, u_update, u, du)
182+
update_coefficients!(L, u_update, p, t; kwargs...)
183+
mul!(du, L, u)
184+
return du # Explicitly return du
117185
end
118-
function (L::AbstractSciMLOperator)(du, u, p, t, α, β; kwargs...)
119-
(update_coefficients!(L, u, p, t; kwargs...); mul!(du, L, u, α, β))
186+
187+
"""
188+
$SIGNATURES
189+
190+
Apply the operator L to vector u with scaling factors α and β, computing du = α*L*u + β*du,
191+
after updating the coefficients of L using u_update.
192+
193+
# Arguments
194+
- `du`: The output vector where the result is accumulated
195+
- `u`: The vector to which the operator is applied
196+
- `p`: Parameter object
197+
- `t`: Time parameter
198+
- `α`: Scaling factor for L*u
199+
- `β`: Scaling factor for the existing value in du
200+
- `u_update`: Vector used to update the operator coefficients (defaults to `u`)
201+
- `kwargs...`: Additional keyword arguments for the update function
202+
203+
# Example
204+
```julia
205+
L = MatrixOperator(zeros(4,4); update_func = some_update_func)
206+
u = rand(4)
207+
du = rand(4)
208+
p = some_params
209+
t = 1.0
210+
α = 2.0
211+
β = 1.0
212+
u_update = rand(4) # Some reference state for updating coefficients
213+
214+
# Compute du = α*L*u + β*du
215+
L(du, u, p, t, α, β, u_update)
216+
```
217+
"""
218+
function (L::AbstractSciMLOperator{T})(du::AbstractArray, u::AbstractArray, p, t, α, β, u_update=u; kwargs...) where {T}
219+
_check_device_match(du, u, u_update)
220+
_check_size_compatibility(L, u_update, u, du)
221+
update_coefficients!(L, u_update, p, t; kwargs...)
222+
mul!(du, L, u, α, β)
223+
return du # Explicitly return du
120224
end
121225

122226
function (L::AbstractSciMLOperator)(du::Number, u::Number, p, t, args...; kwargs...)
@@ -125,6 +229,83 @@ function (L::AbstractSciMLOperator)(du::Number, u::Number, p, t, args...; kwargs
125229
throw(ArgumentError(msg))
126230
end
127231

232+
"""
233+
@private
234+
235+
Check that all vectors are on the same device (CPU/GPU).
236+
This function is a no-op in the standard implementation but can be
237+
extended by packages that provide GPU support.
238+
"""
239+
function _check_device_match(args...)
240+
# Default implementation - no device checking in base package
241+
# This would be extended by GPU-supporting packages
242+
nothing
243+
end
244+
245+
"""
246+
@private
247+
248+
Verify that the sizes of vectors are compatible with the operator and with each other.
249+
"""
250+
function _check_size_compatibility(L::AbstractSciMLOperator, u_update, u, du=nothing)
251+
# Special case for scalar operators which have size() = ()
252+
if L isa AbstractSciMLScalarOperator
253+
# Scalar operators can operate on any size inputs
254+
# Just check batch dimensions if present
255+
if u isa AbstractMatrix && u_update isa AbstractMatrix
256+
if size(u, 2) != size(u_update, 2)
257+
throw(DimensionMismatch(
258+
"Batch dimension of u ($(size(u, 2))) must match batch dimension of u_update ($(size(u_update, 2)))"))
259+
end
260+
end
261+
262+
if du !== nothing && u isa AbstractMatrix && du isa AbstractMatrix
263+
if size(u, 2) != size(du, 2)
264+
throw(DimensionMismatch(
265+
"Batch dimension of u ($(size(u, 2))) must match batch dimension of du ($(size(du, 2)))"))
266+
end
267+
end
268+
269+
return nothing
270+
end
271+
272+
# For regular operators with dimensions
273+
# Verify u_update has compatible size for updating operator
274+
if size(u_update, 1) != size(L, 2)
275+
throw(DimensionMismatch(
276+
"Size of u_update ($(size(u_update, 1))) must match the input dimension of operator ($(size(L, 2)))"))
277+
end
278+
279+
# Verify u has compatible size for operator application
280+
if size(u, 1) != size(L, 2)
281+
throw(DimensionMismatch(
282+
"Size of u ($(size(u, 1))) must match the input dimension of operator ($(size(L, 2)))"))
283+
end
284+
285+
# If du is provided, verify it has compatible size for storing the result
286+
if du !== nothing && size(du, 1) != size(L, 1)
287+
throw(DimensionMismatch(
288+
"Size of du ($(size(du, 1))) must match the output dimension of operator ($(size(L, 1)))"))
289+
end
290+
291+
# Verify batch dimensions match if present
292+
if u isa AbstractMatrix && u_update isa AbstractMatrix
293+
if size(u, 2) != size(u_update, 2)
294+
throw(DimensionMismatch(
295+
"Batch dimension of u ($(size(u, 2))) must match batch dimension of u_update ($(size(u_update, 2)))"))
296+
end
297+
end
298+
299+
if du !== nothing && u isa AbstractMatrix && du isa AbstractMatrix
300+
if size(u, 2) != size(du, 2)
301+
throw(DimensionMismatch(
302+
"Batch dimension of u ($(size(u, 2))) must match batch dimension of du ($(size(du, 2)))"))
303+
end
304+
end
305+
306+
nothing
307+
end
308+
128309
###
129310
# operator caching interface
130311
###

src/left.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,4 +157,18 @@ for (op, LType, VType) in ((:adjoint, :AdjointOperator, :AbstractAdjointVecOrMat
157157
u
158158
end
159159
end
160+
161+
###
162+
# Update interfaces for AdjointOperator and TransposedOperator
163+
###
164+
165+
function update_coefficients(L::Union{AdjointOperator,TransposedOperator}, u_update, p, t; kwargs...)
166+
@reset L.L = update_coefficients(L.L, u_update, p, t; kwargs...)
167+
L
168+
end
169+
170+
function update_coefficients!(L::Union{AdjointOperator,TransposedOperator}, u_update, p, t; kwargs...)
171+
update_coefficients!(L.L, u_update, p, t; kwargs...)
172+
nothing
173+
end
160174
#

src/matrix.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,13 @@ function isconstant(L::MatrixOperator)
155155
update_func_isconstant(L.update_func) & update_func_isconstant(L.update_func!)
156156
end
157157

158-
function update_coefficients(L::MatrixOperator, u, p, t; kwargs...)
159-
@reset L.A = L.update_func(L.A, u, p, t; kwargs...)
158+
function update_coefficients(L::MatrixOperator, u_update, p, t; kwargs...)
159+
@reset L.A = L.update_func(L.A, u_update, p, t; kwargs...)
160+
L
160161
end
161162

162-
function update_coefficients!(L::MatrixOperator, u, p, t; kwargs...)
163-
L.update_func!(L.A, u, p, t; kwargs...)
163+
function update_coefficients!(L::MatrixOperator, u_update, p, t; kwargs...)
164+
L.update_func!(L.A, u_update, p, t; kwargs...)
164165

165166
nothing
166167
end

0 commit comments

Comments
 (0)