@@ -109,14 +109,118 @@ end
109109# operator evaluation interface
110110# ##
111111
112- function (L:: AbstractSciMLOperator )(u, p, t; kwargs... )
113- update_coefficients (L, u, p, t; kwargs... ) * u
112+ abstract type OperatorMethodTag end
113+ struct OutOfPlaceTag <: OperatorMethodTag end
114+ struct InPlaceTag <: OperatorMethodTag end
115+ struct ScaledInPlaceTag <: OperatorMethodTag end
116+
117+ """
118+ $SIGNATURES
119+
120+ Apply the operator L to the vector x, after updating the coefficients of L using u_update.
121+
122+ This method is out-of-place, i.e., it allocates a new vector for the result.
123+
124+ # Arguments
125+ - `x`: The vector to which the operator is applied
126+ - `p`: Parameter object
127+ - `t`: Time parameter
128+ - `u_update`: Vector used to update the operator coefficients (defaults to `x`)
129+ - `kwargs...`: Additional keyword arguments for the update function
130+
131+ # Returns
132+ The result of applying the operator to x
133+
134+ # Example
135+ ```julia
136+ L = MatrixOperator(zeros(4,4); update_func = some_update_func)
137+ x = rand(4)
138+ p = some_params
139+ t = 1.0
140+ u_update = rand(4) # Some reference state for updating coefficients
141+
142+ # Update using u_update, then apply operator to x
143+ v = L(x, p, t, u_update)
144+ ```
145+ """
146+ function (L:: AbstractSciMLOperator{T} )(x, p, t, u_update= x; kwargs... ) where {T}
147+ _check_device_match (x, u_update)
148+ _check_size_compatibility (L, u_update, x)
149+ update_coefficients (L, u_update, p, t; kwargs... ) * x
114150end
115- function (L:: AbstractSciMLOperator )(du, u, p, t; kwargs... )
116- (update_coefficients! (L, u, p, t; kwargs... ); mul! (du, L, u))
151+
152+ """
153+ $SIGNATURES
154+
155+ Apply the operator L to the vector u in-place, storing the result in du,
156+ after updating the coefficients of L using u_update.
157+
158+ # Arguments
159+ - `du`: The output vector where the result is stored
160+ - `u`: The vector to which the operator is applied
161+ - `p`: Parameter object
162+ - `t`: Time parameter
163+ - `u_update`: Vector used to update the operator coefficients (defaults to `u`)
164+ - `kwargs...`: Additional keyword arguments for the update function
165+
166+ # Example
167+ ```julia
168+ L = MatrixOperator(zeros(4,4); update_func = some_update_func)
169+ u = rand(4)
170+ du = similar(u)
171+ p = some_params
172+ t = 1.0
173+ u_update = rand(4) # Some reference state for updating coefficients
174+
175+ # Update using u_update, then apply operator to u, storing in du
176+ L(du, u, p, t, u_update)
177+ ```
178+ """
179+ function (L:: AbstractSciMLOperator{T} )(du:: AbstractArray , u:: AbstractArray , p, t, u_update= u; kwargs... ) where {T}
180+ _check_device_match (du, u, u_update)
181+ _check_size_compatibility (L, u_update, u, du)
182+ update_coefficients! (L, u_update, p, t; kwargs... )
183+ mul! (du, L, u)
184+ return du # Explicitly return du
117185end
118- function (L:: AbstractSciMLOperator )(du, u, p, t, α, β; kwargs... )
119- (update_coefficients! (L, u, p, t; kwargs... ); mul! (du, L, u, α, β))
186+
187+ """
188+ $SIGNATURES
189+
190+ Apply the operator L to vector u with scaling factors α and β, computing du = α*L*u + β*du,
191+ after updating the coefficients of L using u_update.
192+
193+ # Arguments
194+ - `du`: The output vector where the result is accumulated
195+ - `u`: The vector to which the operator is applied
196+ - `p`: Parameter object
197+ - `t`: Time parameter
198+ - `α`: Scaling factor for L*u
199+ - `β`: Scaling factor for the existing value in du
200+ - `u_update`: Vector used to update the operator coefficients (defaults to `u`)
201+ - `kwargs...`: Additional keyword arguments for the update function
202+
203+ # Example
204+ ```julia
205+ L = MatrixOperator(zeros(4,4); update_func = some_update_func)
206+ u = rand(4)
207+ du = rand(4)
208+ p = some_params
209+ t = 1.0
210+ α = 2.0
211+ β = 1.0
212+ u_update = rand(4) # Some reference state for updating coefficients
213+
214+ # Compute du = α*L*u + β*du
215+ L(du, u, p, t, α, β, u_update)
216+ ```
217+ """
218+ function (L:: AbstractSciMLOperator{T} )(du:: AbstractArray , u:: AbstractArray , p, t, α, β, u_update= u; kwargs... ) where {T}
219+ _check_device_match (du, u, u_update)
220+ _check_size_compatibility (L, u_update, u, du)
221+ update_coefficients! (L, u_update, p, t; kwargs... )
222+ mul! (du, L, u, α, β)
223+ return du # Explicitly return du
120224end
121225
122226function (L:: AbstractSciMLOperator )(du:: Number , u:: Number , p, t, args... ; kwargs... )
@@ -125,6 +229,83 @@ function (L::AbstractSciMLOperator)(du::Number, u::Number, p, t, args...; kwargs
125229 throw (ArgumentError (msg))
126230end
127231
232+ """
233+ @private
234+
235+ Check that all vectors are on the same device (CPU/GPU).
236+ This function is a no-op in the standard implementation but can be
237+ extended by packages that provide GPU support.
238+ """
239+ function _check_device_match (args... )
240+ # Default implementation - no device checking in base package
241+ # This would be extended by GPU-supporting packages
242+ nothing
243+ end
244+
245+ """
246+ @private
247+
248+ Verify that the sizes of vectors are compatible with the operator and with each other.
249+ """
250+ function _check_size_compatibility (L:: AbstractSciMLOperator , u_update, u, du= nothing )
251+ # Special case for scalar operators which have size() = ()
252+ if L isa AbstractSciMLScalarOperator
253+ # Scalar operators can operate on any size inputs
254+ # Just check batch dimensions if present
255+ if u isa AbstractMatrix && u_update isa AbstractMatrix
256+ if size (u, 2 ) != size (u_update, 2 )
257+ throw (DimensionMismatch (
258+ " Batch dimension of u ($(size (u, 2 )) ) must match batch dimension of u_update ($(size (u_update, 2 )) )" ))
259+ end
260+ end
261+
262+ if du != = nothing && u isa AbstractMatrix && du isa AbstractMatrix
263+ if size (u, 2 ) != size (du, 2 )
264+ throw (DimensionMismatch (
265+ " Batch dimension of u ($(size (u, 2 )) ) must match batch dimension of du ($(size (du, 2 )) )" ))
266+ end
267+ end
268+
269+ return nothing
270+ end
271+
272+ # For regular operators with dimensions
273+ # Verify u_update has compatible size for updating operator
274+ if size (u_update, 1 ) != size (L, 2 )
275+ throw (DimensionMismatch (
276+ " Size of u_update ($(size (u_update, 1 )) ) must match the input dimension of operator ($(size (L, 2 )) )" ))
277+ end
278+
279+ # Verify u has compatible size for operator application
280+ if size (u, 1 ) != size (L, 2 )
281+ throw (DimensionMismatch (
282+ " Size of u ($(size (u, 1 )) ) must match the input dimension of operator ($(size (L, 2 )) )" ))
283+ end
284+
285+ # If du is provided, verify it has compatible size for storing the result
286+ if du != = nothing && size (du, 1 ) != size (L, 1 )
287+ throw (DimensionMismatch (
288+ " Size of du ($(size (du, 1 )) ) must match the output dimension of operator ($(size (L, 1 )) )" ))
289+ end
290+
291+ # Verify batch dimensions match if present
292+ if u isa AbstractMatrix && u_update isa AbstractMatrix
293+ if size (u, 2 ) != size (u_update, 2 )
294+ throw (DimensionMismatch (
295+ " Batch dimension of u ($(size (u, 2 )) ) must match batch dimension of u_update ($(size (u_update, 2 )) )" ))
296+ end
297+ end
298+
299+ if du != = nothing && u isa AbstractMatrix && du isa AbstractMatrix
300+ if size (u, 2 ) != size (du, 2 )
301+ throw (DimensionMismatch (
302+ " Batch dimension of u ($(size (u, 2 )) ) must match batch dimension of du ($(size (du, 2 )) )" ))
303+ end
304+ end
305+
306+ nothing
307+ end
308+
128309# ##
129310# operator caching interface
130311# ##
0 commit comments