@@ -35,11 +35,13 @@ for op in (
3535 :adjoint ,
3636 :transpose ,
3737 )
38- @eval function Base. $op (L:: MatrixOperator ) # TODO - test this thoroughly
39- MatrixOperator (
40- $ op (L. A);
41- update_func= (A,u,p,t) -> $ op (L. update_func ($ op (L. A),u,p,t)) # TODO - test
42- )
38+ @eval function Base. $op (L:: MatrixOperator )
39+ if isconstant (L)
40+ MatrixOperator ($ op (L. A))
41+ else
42+ update_func = (A,u,p,t) -> $ op (L. update_func ($ op (L. A),u,p,t))
43+ MatrixOperator ($ op (L. A); update_func = update_func)
44+ end
4345 end
4446end
4547Base. conj (L:: MatrixOperator ) = MatrixOperator (
@@ -50,6 +52,7 @@ Base.conj(L::MatrixOperator) = MatrixOperator(
5052has_adjoint (A:: MatrixOperator ) = has_adjoint (A. A)
5153update_coefficients! (L:: MatrixOperator ,u,p,t) = (L. update_func (L. A,u,p,t); nothing )
5254
55+ getops (L:: MatrixOperator ) = (L. A)
5356isconstant (L:: MatrixOperator ) = L. update_func == DEFAULT_UPDATE_FUNC
5457Base. iszero (L:: MatrixOperator ) = iszero (L. A)
5558
@@ -76,8 +79,6 @@ Base.ndims(::Type{<:MatrixOperator{T,AType}}) where{T,AType} = ndims(AType)
7679ArrayInterfaceCore. issingular (L:: MatrixOperator ) = ArrayInterfaceCore. issingular (L. A)
7780Base. copy (L:: MatrixOperator ) = MatrixOperator (copy (L. A);update_func= L. update_func)
7881
79- getops (L:: MatrixOperator ) = (L. A)
80-
8182# operator application
8283Base.:* (L:: MatrixOperator , u:: AbstractVecOrMat ) = L. A * u
8384Base.:\ (L:: MatrixOperator , u:: AbstractVecOrMat ) = L. A \ u
@@ -104,10 +105,11 @@ an operator of size `(N, N)` where `N = size(diag, 1)` is the leading length of
104105`L` then is the elementwise-scaling operation on arrays of `length(u) = length(diag)`
105106with leading length `size(u, 1) = N`.
106107"""
107- function DiagonalOperator (diag:: AbstractVector ; update_func= DEFAULT_UPDATE_FUNC)
108- function diag_update_func (A, u, p, t)
109- update_func (A. diag, u, p, t)
110- A
108+ function DiagonalOperator (diag:: AbstractVector ; update_func = DEFAULT_UPDATE_FUNC)
109+ diag_update_func = if update_func == DEFAULT_UPDATE_FUNC
110+ DEFAULT_UPDATE_FUNC
111+ else
112+ (A, u, p, t) -> (update_func (A. diag, u, p, t); A)
111113 end
112114 MatrixOperator (Diagonal (diag); update_func= diag_update_func)
113115end
@@ -214,7 +216,7 @@ struct AffineOperator{T,AType,BType,bType,cType,F} <: AbstractSciMLOperator{T}
214216 b:: bType
215217
216218 cache:: cType
217- update_func:: F
219+ update_func:: F # updates b
218220
219221 function AffineOperator (A, B, b, cache, update_func)
220222 T = promote_type (eltype .((A,B,b))... )
234236function AffineOperator (A:: Union{AbstractMatrix,AbstractSciMLOperator} ,
235237 B:: Union{AbstractMatrix,AbstractSciMLOperator} ,
236238 b:: AbstractArray ;
237- update_func= DEFAULT_UPDATE_FUNC,
239+ update_func = DEFAULT_UPDATE_FUNC,
238240 )
239241 @assert size (A, 1 ) == size (B, 1 ) " Dimension mismatch: A, B don't output vectors
240242 of same size"
250252 L = AddVector(b[; update_func])
251253 L(u) = u + b
252254"""
253- function AddVector (b:: AbstractVecOrMat ; update_func= DEFAULT_UPDATE_FUNC)
255+ function AddVector (b:: AbstractVecOrMat ; update_func = DEFAULT_UPDATE_FUNC)
254256 N = size (b, 1 )
255257 Id = IdentityOperator {N} ()
256258
@@ -261,19 +263,20 @@ end
261263 L = AddVector(B, b[; update_func])
262264 L(u) = u + B*b
263265"""
264- function AddVector (B, b:: AbstractVecOrMat ; update_func= DEFAULT_UPDATE_FUNC)
266+ function AddVector (B, b:: AbstractVecOrMat ; update_func = DEFAULT_UPDATE_FUNC)
265267 N = size (B, 1 )
266268 Id = IdentityOperator {N} ()
267269
268270 AffineOperator (Id, B, b; update_func= update_func)
269271end
270272
271273getops (L:: AffineOperator ) = (L. A, L. B, L. b)
272- Base. size (L:: AffineOperator ) = size (L. A)
273274
274275update_coefficients! (L:: AffineOperator ,u,p,t) = (L. update_func (L. b,u,p,t); nothing )
275-
276+ isconstant (L :: AffineOperator ) = (L . update_func == DEFAULT_UPDATE_FUNC) & all (isconstant, (L . A, L . B))
276277islinear (:: AffineOperator ) = false
278+
279+ Base. size (L:: AffineOperator ) = size (L. A)
277280Base. iszero (L:: AffineOperator ) = all (iszero, getops (L))
278281has_adjoint (L:: AffineOperator ) = all (has_adjoint, L. ops)
279282has_mul (L:: AffineOperator ) = has_mul (L. A)
0 commit comments