@@ -108,204 +108,25 @@ end
108108# ##
109109# operator evaluation interface
110110# ##
111-
112- abstract type OperatorMethodTag end
113- struct OutOfPlaceTag <: OperatorMethodTag end
114- struct InPlaceTag <: OperatorMethodTag end
115- struct ScaledInPlaceTag <: OperatorMethodTag end
116-
117- """
118- $SIGNATURES
119-
120- Apply the operator L to the vector x, after updating the coefficients of L using u_update.
121-
122- This method is out-of-place, i.e., it allocates a new vector for the result.
123-
124- # Arguments
125- - `x`: The vector to which the operator is applied
126- - `p`: Parameter object
127- - `t`: Time parameter
128- - `u_update`: Vector used to update the operator coefficients (defaults to `x`)
129- - `kwargs...`: Additional keyword arguments for the update function
130-
131- # Returns
132- The result of applying the operator to x
133-
134- # Example
135- ```julia
136- L = MatrixOperator(zeros(4,4); update_func = some_update_func)
137- x = rand(4)
138- p = some_params
139- t = 1.0
140- u_update = rand(4) # Some reference state for updating coefficients
141-
142- # Update using u_update, then apply operator to x
143- v = L(x, p, t, u_update)
144- ```
145- """
146- function (L:: AbstractSciMLOperator{T} )(x, p, t, u_update= x; kwargs... ) where {T}
147- _check_device_match (x, u_update)
148- _check_size_compatibility (L, u_update, x)
149- update_coefficients (L, u_update, p, t; kwargs... ) * x
111+ # Out-of-place: v is action vector, u is update vector
112+ function (L:: AbstractSciMLOperator )(v, u, p, t; kwargs... )
113+ update_coefficients (L, u, p, t; kwargs... ) * v
150114end
151-
152- """
153- $SIGNATURES
154-
155- Apply the operator L to the vector u in-place, storing the result in du,
156- after updating the coefficients of L using u_update.
157-
158- # Arguments
159- - `du`: The output vector where the result is stored
160- - `u`: The vector to which the operator is applied
161- - `p`: Parameter object
162- - `t`: Time parameter
163- - `u_update`: Vector used to update the operator coefficients (defaults to `u`)
164- - `kwargs...`: Additional keyword arguments for the update function
165-
166- # Example
167- ```julia
168- L = MatrixOperator(zeros(4,4); update_func = some_update_func)
169- u = rand(4)
170- du = similar(u)
171- p = some_params
172- t = 1.0
173- u_update = rand(4) # Some reference state for updating coefficients
174-
175- # Update using u_update, then apply operator to u, storing in du
176- L(du, u, p, t, u_update)
177- ```
178- """
179- function (L:: AbstractSciMLOperator{T} )(du:: AbstractArray , u:: AbstractArray , p, t, u_update= u; kwargs... ) where {T}
180- _check_device_match (du, u, u_update)
181- _check_size_compatibility (L, u_update, u, du)
182- update_coefficients! (L, u_update, p, t; kwargs... )
183- mul! (du, L, u)
184- return du # Explicitly return du
115+ # In-place: w is destination, v is action vector, u is update vector
116+ function (L:: AbstractSciMLOperator )(w, v, u, p, t; kwargs... )
117+ (update_coefficients! (L, u, p, t; kwargs... ); mul! (w, L, v))
185118end
186-
187- """
188- $SIGNATURES
189-
190- Apply the operator L to vector u with scaling factors α and β, computing du = α*L*u + β*du,
191- after updating the coefficients of L using u_update.
192-
193- # Arguments
194- - `du`: The output vector where the result is accumulated
195- - `u`: The vector to which the operator is applied
196- - `p`: Parameter object
197- - `t`: Time parameter
198- - `α`: Scaling factor for L*u
199- - `β`: Scaling factor for the existing value in du
200- - `u_update`: Vector used to update the operator coefficients (defaults to `u`)
201- - `kwargs...`: Additional keyword arguments for the update function
202-
203- # Example
204- ```julia
205- L = MatrixOperator(zeros(4,4); update_func = some_update_func)
206- u = rand(4)
207- du = rand(4)
208- p = some_params
209- t = 1.0
210- α = 2.0
211- β = 1.0
212- u_update = rand(4) # Some reference state for updating coefficients
213-
214- # Compute du = α*L*u + β*du
215- L(du, u, p, t, α, β, u_update)
216- ```
217- """
218- function (L:: AbstractSciMLOperator{T} )(du:: AbstractArray , u:: AbstractArray , p, t, α, β, u_update= u; kwargs... ) where {T}
219- _check_device_match (du, u, u_update)
220- _check_size_compatibility (L, u_update, u, du)
221- update_coefficients! (L, u_update, p, t; kwargs... )
222- mul! (du, L, u, α, β)
223- return du # Explicitly return du
119+ # In-place with scaling: w = α*(L*v) + β*w
120+ function (L:: AbstractSciMLOperator )(w, v, u, p, t, α, β; kwargs... )
121+ (update_coefficients! (L, u, p, t; kwargs... ); mul! (w, L, v, α, β))
224122end
225123
226- function (L:: AbstractSciMLOperator )(du :: Number , u :: Number , p, t, args... ; kwargs... )
227- msg = """ Nonallocating L(v, u, p, t) type methods are not available for
124+ function (L:: AbstractSciMLOperator )(w :: Number , v :: Number , u , p, t, args... ; kwargs... )
125+ msg = """ Nonallocating L(w, v, u, p, t) type methods are not available for
228126 subtypes of `Number`."""
229127 throw (ArgumentError (msg))
230128end
231129
232- """
233- @private
234-
235- Check that all vectors are on the same device (CPU/GPU).
236- This function is a no-op in the standard implementation but can be
237- extended by packages that provide GPU support.
238- """
239- function _check_device_match (args... )
240- # Default implementation - no device checking in base package
241- # This would be extended by GPU-supporting packages
242- nothing
243- end
244-
245- """
246- @private
247-
248- Verify that the sizes of vectors are compatible with the operator and with each other.
249- """
250- function _check_size_compatibility (L:: AbstractSciMLOperator , u_update, u, du= nothing )
251- # Special case for scalar operators which have size() = ()
252- if L isa AbstractSciMLScalarOperator
253- # Scalar operators can operate on any size inputs
254- # Just check batch dimensions if present
255- if u isa AbstractMatrix && u_update isa AbstractMatrix
256- if size (u, 2 ) != size (u_update, 2 )
257- throw (DimensionMismatch (
258- " Batch dimension of u ($(size (u, 2 )) ) must match batch dimension of u_update ($(size (u_update, 2 )) )" ))
259- end
260- end
261-
262- if du != = nothing && u isa AbstractMatrix && du isa AbstractMatrix
263- if size (u, 2 ) != size (du, 2 )
264- throw (DimensionMismatch (
265- " Batch dimension of u ($(size (u, 2 )) ) must match batch dimension of du ($(size (du, 2 )) )" ))
266- end
267- end
268-
269- return nothing
270- end
271-
272- # For regular operators with dimensions
273- # Verify u_update has compatible size for updating operator
274- if size (u_update, 1 ) != size (L, 2 )
275- throw (DimensionMismatch (
276- " Size of u_update ($(size (u_update, 1 )) ) must match the input dimension of operator ($(size (L, 2 )) )" ))
277- end
278-
279- # Verify u has compatible size for operator application
280- if size (u, 1 ) != size (L, 2 )
281- throw (DimensionMismatch (
282- " Size of u ($(size (u, 1 )) ) must match the input dimension of operator ($(size (L, 2 )) )" ))
283- end
284-
285- # If du is provided, verify it has compatible size for storing the result
286- if du != = nothing && size (du, 1 ) != size (L, 1 )
287- throw (DimensionMismatch (
288- " Size of du ($(size (du, 1 )) ) must match the output dimension of operator ($(size (L, 1 )) )" ))
289- end
290-
291- # Verify batch dimensions match if present
292- if u isa AbstractMatrix && u_update isa AbstractMatrix
293- if size (u, 2 ) != size (u_update, 2 )
294- throw (DimensionMismatch (
295- " Batch dimension of u ($(size (u, 2 )) ) must match batch dimension of u_update ($(size (u_update, 2 )) )" ))
296- end
297- end
298-
299- if du != = nothing && u isa AbstractMatrix && du isa AbstractMatrix
300- if size (u, 2 ) != size (du, 2 )
301- throw (DimensionMismatch (
302- " Batch dimension of u ($(size (u, 2 )) ) must match batch dimension of du ($(size (du, 2 )) )" ))
303- end
304- end
305-
306- nothing
307- end
308-
309130# ##
310131# operator caching interface
311132# ##
0 commit comments