@@ -33,11 +33,13 @@ for op in (
3333 :adjoint ,
3434 :transpose ,
3535 )
36- @eval function Base. $op (L:: MatrixOperator ) # TODO - test this thoroughly
37- MatrixOperator (
38- $ op (L. A);
39- update_func= (A,u,p,t) -> $ op (L. update_func ($ op (L. A),u,p,t)) # TODO - test
40- )
36+ @eval function Base. $op (L:: MatrixOperator )
37+ if isconstant (L)
38+ MatrixOperator ($ op (L. A))
39+ else
40+ update_func = (A,u,p,t) -> $ op (L. update_func ($ op (L. A),u,p,t))
41+ MatrixOperator ($ op (L. A); update_func = update_func)
42+ end
4143 end
4244end
4345Base. conj (L:: MatrixOperator ) = MatrixOperator (
@@ -48,6 +50,7 @@ Base.conj(L::MatrixOperator) = MatrixOperator(
4850has_adjoint (A:: MatrixOperator ) = has_adjoint (A. A)
4951update_coefficients! (L:: MatrixOperator ,u,p,t) = (L. update_func (L. A,u,p,t); nothing )
5052
53+ getops (L:: MatrixOperator ) = (L. A)
5154isconstant (L:: MatrixOperator ) = L. update_func == DEFAULT_UPDATE_FUNC
5255Base. iszero (L:: MatrixOperator ) = iszero (L. A)
5356
@@ -74,8 +77,6 @@ Base.ndims(::Type{<:MatrixOperator{T,AType}}) where{T,AType} = ndims(AType)
7477ArrayInterfaceCore. issingular (L:: MatrixOperator ) = ArrayInterfaceCore. issingular (L. A)
7578Base. copy (L:: MatrixOperator ) = MatrixOperator (copy (L. A);update_func= L. update_func)
7679
77- getops (L:: MatrixOperator ) = (L. A)
78-
7980# operator application
8081Base.:* (L:: MatrixOperator , u:: AbstractVecOrMat ) = L. A * u
8182Base.:\ (L:: MatrixOperator , u:: AbstractVecOrMat ) = L. A \ u
@@ -102,10 +103,11 @@ an operator of size `(N, N)` where `N = size(diag, 1)` is the leading length of
102103`L` then is the elementwise-scaling operation on arrays of `length(u) = length(diag)`
103104with leading length `size(u, 1) = N`.
104105"""
105- function DiagonalOperator (diag:: AbstractVector ; update_func= DEFAULT_UPDATE_FUNC)
106- function diag_update_func (A, u, p, t)
107- update_func (A. diag, u, p, t)
108- A
106+ function DiagonalOperator (diag:: AbstractVector ; update_func = DEFAULT_UPDATE_FUNC)
107+ diag_update_func = if update_func == DEFAULT_UPDATE_FUNC
108+ DEFAULT_UPDATE_FUNC
109+ else
110+ (A, u, p, t) -> (update_func (A. diag, u, p, t); A)
109111 end
110112 MatrixOperator (Diagonal (diag); update_func= diag_update_func)
111113end
@@ -211,7 +213,7 @@ struct AffineOperator{T,AType,BType,bType,cType,F} <: AbstractSciMLOperator{T}
211213 b:: bType
212214
213215 cache:: cType
214- update_func:: F
216+ update_func:: F # updates b
215217
216218 function AffineOperator (A, B, b, cache, update_func)
217219 T = promote_type (eltype .((A,B,b))... )
231233function AffineOperator (A:: Union{AbstractMatrix,AbstractSciMLOperator} ,
232234 B:: Union{AbstractMatrix,AbstractSciMLOperator} ,
233235 b:: AbstractArray ;
234- update_func= DEFAULT_UPDATE_FUNC,
236+ update_func = DEFAULT_UPDATE_FUNC,
235237 )
236238 @assert size (A, 1 ) == size (B, 1 ) " Dimension mismatch: A, B don't output vectors
237239 of same size"
247249 L = AddVector(b[; update_func])
248250 L(u) = u + b
249251"""
250- function AddVector (b:: AbstractVecOrMat ; update_func= DEFAULT_UPDATE_FUNC)
252+ function AddVector (b:: AbstractVecOrMat ; update_func = DEFAULT_UPDATE_FUNC)
251253 N = size (b, 1 )
252254 Id = IdentityOperator {N} ()
253255
@@ -258,19 +260,20 @@ end
258260 L = AddVector(B, b[; update_func])
259261 L(u) = u + B*b
260262"""
261- function AddVector (B, b:: AbstractVecOrMat ; update_func= DEFAULT_UPDATE_FUNC)
263+ function AddVector (B, b:: AbstractVecOrMat ; update_func = DEFAULT_UPDATE_FUNC)
262264 N = size (B, 1 )
263265 Id = IdentityOperator {N} ()
264266
265267 AffineOperator (Id, B, b; update_func= update_func)
266268end
267269
268270getops (L:: AffineOperator ) = (L. A, L. B, L. b)
269- Base. size (L:: AffineOperator ) = size (L. A)
270271
271272update_coefficients! (L:: AffineOperator ,u,p,t) = (L. update_func (L. b,u,p,t); nothing )
272-
273+ isconstant (L :: AffineOperator ) = (L . update_func == DEFAULT_UPDATE_FUNC) & all (isconstant, (L . A, L . B))
273274islinear (:: AffineOperator ) = false
275+
276+ Base. size (L:: AffineOperator ) = size (L. A)
274277Base. iszero (L:: AffineOperator ) = all (iszero, getops (L))
275278has_adjoint (L:: AffineOperator ) = all (has_adjoint, L. ops)
276279has_mul (L:: AffineOperator ) = has_mul (L. A)
0 commit comments