Skip to content

Commit 27af56b

Browse files
committed
isconstant(L)
1 parent d7c9a38 commit 27af56b

File tree

1 file changed

+20
-17
lines changed

1 file changed

+20
-17
lines changed

src/matrix.jl

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,13 @@ for op in (
3333
:adjoint,
3434
:transpose,
3535
)
36-
@eval function Base.$op(L::MatrixOperator) # TODO - test this thoroughly
37-
MatrixOperator(
38-
$op(L.A);
39-
update_func= (A,u,p,t) -> $op(L.update_func($op(L.A),u,p,t)) # TODO - test
40-
)
36+
@eval function Base.$op(L::MatrixOperator)
37+
if isconstant(L)
38+
MatrixOperator($op(L.A))
39+
else
40+
update_func = (A,u,p,t) -> $op(L.update_func($op(L.A),u,p,t))
41+
MatrixOperator($op(L.A); update_func = update_func)
42+
end
4143
end
4244
end
4345
Base.conj(L::MatrixOperator) = MatrixOperator(
@@ -48,6 +50,7 @@ Base.conj(L::MatrixOperator) = MatrixOperator(
4850
has_adjoint(A::MatrixOperator) = has_adjoint(A.A)
4951
update_coefficients!(L::MatrixOperator,u,p,t) = (L.update_func(L.A,u,p,t); nothing)
5052

53+
getops(L::MatrixOperator) = (L.A)
5154
isconstant(L::MatrixOperator) = L.update_func == DEFAULT_UPDATE_FUNC
5255
Base.iszero(L::MatrixOperator) = iszero(L.A)
5356

@@ -74,8 +77,6 @@ Base.ndims(::Type{<:MatrixOperator{T,AType}}) where{T,AType} = ndims(AType)
7477
ArrayInterfaceCore.issingular(L::MatrixOperator) = ArrayInterfaceCore.issingular(L.A)
7578
Base.copy(L::MatrixOperator) = MatrixOperator(copy(L.A);update_func=L.update_func)
7679

77-
getops(L::MatrixOperator) = (L.A)
78-
7980
# operator application
8081
Base.:*(L::MatrixOperator, u::AbstractVecOrMat) = L.A * u
8182
Base.:\(L::MatrixOperator, u::AbstractVecOrMat) = L.A \ u
@@ -102,10 +103,11 @@ an operator of size `(N, N)` where `N = size(diag, 1)` is the leading length of
102103
`L` then is the elementwise-scaling operation on arrays of `length(u) = length(diag)`
103104
with leading length `size(u, 1) = N`.
104105
"""
105-
function DiagonalOperator(diag::AbstractVector; update_func=DEFAULT_UPDATE_FUNC)
106-
function diag_update_func(A, u, p, t)
107-
update_func(A.diag, u, p, t)
108-
A
106+
function DiagonalOperator(diag::AbstractVector; update_func = DEFAULT_UPDATE_FUNC)
107+
diag_update_func = if update_func == DEFAULT_UPDATE_FUNC
108+
DEFAULT_UPDATE_FUNC
109+
else
110+
(A, u, p, t) -> (update_func(A.diag, u, p, t); A)
109111
end
110112
MatrixOperator(Diagonal(diag); update_func=diag_update_func)
111113
end
@@ -211,7 +213,7 @@ struct AffineOperator{T,AType,BType,bType,cType,F} <: AbstractSciMLOperator{T}
211213
b::bType
212214

213215
cache::cType
214-
update_func::F
216+
update_func::F # updates b
215217

216218
function AffineOperator(A, B, b, cache, update_func)
217219
T = promote_type(eltype.((A,B,b))...)
@@ -231,7 +233,7 @@ end
231233
function AffineOperator(A::Union{AbstractMatrix,AbstractSciMLOperator},
232234
B::Union{AbstractMatrix,AbstractSciMLOperator},
233235
b::AbstractArray;
234-
update_func=DEFAULT_UPDATE_FUNC,
236+
update_func = DEFAULT_UPDATE_FUNC,
235237
)
236238
@assert size(A, 1) == size(B, 1) "Dimension mismatch: A, B don't output vectors
237239
of same size"
@@ -247,7 +249,7 @@ end
247249
L = AddVector(b[; update_func])
248250
L(u) = u + b
249251
"""
250-
function AddVector(b::AbstractVecOrMat; update_func=DEFAULT_UPDATE_FUNC)
252+
function AddVector(b::AbstractVecOrMat; update_func = DEFAULT_UPDATE_FUNC)
251253
N = size(b, 1)
252254
Id = IdentityOperator{N}()
253255

@@ -258,19 +260,20 @@ end
258260
L = AddVector(B, b[; update_func])
259261
L(u) = u + B*b
260262
"""
261-
function AddVector(B, b::AbstractVecOrMat; update_func=DEFAULT_UPDATE_FUNC)
263+
function AddVector(B, b::AbstractVecOrMat; update_func = DEFAULT_UPDATE_FUNC)
262264
N = size(B, 1)
263265
Id = IdentityOperator{N}()
264266

265267
AffineOperator(Id, B, b; update_func=update_func)
266268
end
267269

268270
getops(L::AffineOperator) = (L.A, L.B, L.b)
269-
Base.size(L::AffineOperator) = size(L.A)
270271

271272
update_coefficients!(L::AffineOperator,u,p,t) = (L.update_func(L.b,u,p,t); nothing)
272-
273+
isconstant(L::AffineOperator) = (L.update_func == DEFAULT_UPDATE_FUNC) & all(isconstant, (L.A, L.B))
273274
islinear(::AffineOperator) = false
275+
276+
Base.size(L::AffineOperator) = size(L.A)
274277
Base.iszero(L::AffineOperator) = all(iszero, getops(L))
275278
has_adjoint(L::AffineOperator) = all(has_adjoint, L.ops)
276279
has_mul(L::AffineOperator) = has_mul(L.A)

0 commit comments

Comments
 (0)