Skip to content

Commit f94c56d

Browse files
authored
Merge pull request #141 from vpuri3/isconstant
create and export isconstant
2 parents 8b9feb2 + 9c0336e commit f94c56d

File tree

7 files changed

+78
-26
lines changed

7 files changed

+78
-26
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SciMLOperators"
22
uuid = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
33
authors = ["xtalax <[email protected]>"]
4-
version = "0.1.17"
4+
version = "0.1.18"
55

66
[deps]
77
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"

src/SciMLOperators.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ export ScalarOperator,
5151
export update_coefficients!,
5252
update_coefficients,
5353

54+
isconstant,
5455
iscached,
5556
cache_operator,
5657

src/interface.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,18 @@ has_ldiv!(L::AbstractSciMLOperator) = false # ldiv!(du, L, u)
107107

108108
### Extra standard assumptions
109109

110-
isconstant(L) = true
110+
isconstant(::Union{
111+
# LinearAlgebra
112+
AbstractMatrix,
113+
UniformScaling,
114+
Factorization,
115+
116+
# Base
117+
Number,
118+
119+
}
120+
) = true
111121
isconstant(L::AbstractSciMLOperator) = all(isconstant, getops(L))
112-
#isconstant(L::AbstractSciMLOperator) = L.update_func = DEFAULT_UPDATE_FUNC
113122

114123
#islinear(L) = false
115124
islinear(::AbstractSciMLOperator) = false

src/matrix.jl

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,13 @@ for op in (
3535
:adjoint,
3636
:transpose,
3737
)
38-
@eval function Base.$op(L::MatrixOperator) # TODO - test this thoroughly
39-
MatrixOperator(
40-
$op(L.A);
41-
update_func= (A,u,p,t) -> $op(L.update_func($op(L.A),u,p,t)) # TODO - test
42-
)
38+
@eval function Base.$op(L::MatrixOperator)
39+
if isconstant(L)
40+
MatrixOperator($op(L.A))
41+
else
42+
update_func = (A,u,p,t) -> $op(L.update_func($op(L.A),u,p,t))
43+
MatrixOperator($op(L.A); update_func = update_func)
44+
end
4345
end
4446
end
4547
Base.conj(L::MatrixOperator) = MatrixOperator(
@@ -50,6 +52,7 @@ Base.conj(L::MatrixOperator) = MatrixOperator(
5052
has_adjoint(A::MatrixOperator) = has_adjoint(A.A)
5153
update_coefficients!(L::MatrixOperator,u,p,t) = (L.update_func(L.A,u,p,t); nothing)
5254

55+
getops(L::MatrixOperator) = (L.A)
5356
isconstant(L::MatrixOperator) = L.update_func == DEFAULT_UPDATE_FUNC
5457
Base.iszero(L::MatrixOperator) = iszero(L.A)
5558

@@ -76,8 +79,6 @@ Base.ndims(::Type{<:MatrixOperator{T,AType}}) where{T,AType} = ndims(AType)
7679
ArrayInterfaceCore.issingular(L::MatrixOperator) = ArrayInterfaceCore.issingular(L.A)
7780
Base.copy(L::MatrixOperator) = MatrixOperator(copy(L.A);update_func=L.update_func)
7881

79-
getops(L::MatrixOperator) = (L.A)
80-
8182
# operator application
8283
Base.:*(L::MatrixOperator, u::AbstractVecOrMat) = L.A * u
8384
Base.:\(L::MatrixOperator, u::AbstractVecOrMat) = L.A \ u
@@ -104,10 +105,11 @@ an operator of size `(N, N)` where `N = size(diag, 1)` is the leading length of
104105
`L` then is the elementwise-scaling operation on arrays of `length(u) = length(diag)`
105106
with leading length `size(u, 1) = N`.
106107
"""
107-
function DiagonalOperator(diag::AbstractVector; update_func=DEFAULT_UPDATE_FUNC)
108-
function diag_update_func(A, u, p, t)
109-
update_func(A.diag, u, p, t)
110-
A
108+
function DiagonalOperator(diag::AbstractVector; update_func = DEFAULT_UPDATE_FUNC)
109+
diag_update_func = if update_func == DEFAULT_UPDATE_FUNC
110+
DEFAULT_UPDATE_FUNC
111+
else
112+
(A, u, p, t) -> (update_func(A.diag, u, p, t); A)
111113
end
112114
MatrixOperator(Diagonal(diag); update_func=diag_update_func)
113115
end
@@ -214,7 +216,7 @@ struct AffineOperator{T,AType,BType,bType,cType,F} <: AbstractSciMLOperator{T}
214216
b::bType
215217

216218
cache::cType
217-
update_func::F
219+
update_func::F # updates b
218220

219221
function AffineOperator(A, B, b, cache, update_func)
220222
T = promote_type(eltype.((A,B,b))...)
@@ -234,7 +236,7 @@ end
234236
function AffineOperator(A::Union{AbstractMatrix,AbstractSciMLOperator},
235237
B::Union{AbstractMatrix,AbstractSciMLOperator},
236238
b::AbstractArray;
237-
update_func=DEFAULT_UPDATE_FUNC,
239+
update_func = DEFAULT_UPDATE_FUNC,
238240
)
239241
@assert size(A, 1) == size(B, 1) "Dimension mismatch: A, B don't output vectors
240242
of same size"
@@ -250,7 +252,7 @@ end
250252
L = AddVector(b[; update_func])
251253
L(u) = u + b
252254
"""
253-
function AddVector(b::AbstractVecOrMat; update_func=DEFAULT_UPDATE_FUNC)
255+
function AddVector(b::AbstractVecOrMat; update_func = DEFAULT_UPDATE_FUNC)
254256
N = size(b, 1)
255257
Id = IdentityOperator{N}()
256258

@@ -261,19 +263,20 @@ end
261263
L = AddVector(B, b[; update_func])
262264
L(u) = u + B*b
263265
"""
264-
function AddVector(B, b::AbstractVecOrMat; update_func=DEFAULT_UPDATE_FUNC)
266+
function AddVector(B, b::AbstractVecOrMat; update_func = DEFAULT_UPDATE_FUNC)
265267
N = size(B, 1)
266268
Id = IdentityOperator{N}()
267269

268270
AffineOperator(Id, B, b; update_func=update_func)
269271
end
270272

271273
getops(L::AffineOperator) = (L.A, L.B, L.b)
272-
Base.size(L::AffineOperator) = size(L.A)
273274

274275
update_coefficients!(L::AffineOperator,u,p,t) = (L.update_func(L.b,u,p,t); nothing)
275-
276+
isconstant(L::AffineOperator) = (L.update_func == DEFAULT_UPDATE_FUNC) & all(isconstant, (L.A, L.B))
276277
islinear(::AffineOperator) = false
278+
279+
Base.size(L::AffineOperator) = size(L.A)
277280
Base.iszero(L::AffineOperator) = all(iszero, getops(L))
278281
has_adjoint(L::AffineOperator) = all(has_adjoint, L.ops)
279282
has_mul(L::AffineOperator) = has_mul(L.A)

test/basic.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ K = 12
3636
@test iscached(Id)
3737
@test size(Id) == (N, N)
3838
@test Id' isa IdentityOperator{N}
39+
@test isconstant(Id)
3940

4041
for op in (
4142
*, \,
@@ -67,6 +68,7 @@ end
6768
@test issquare(Z)
6869
@test islinear(Z)
6970
@test NullOperator(u) isa NullOperator{N}
71+
@test isconstant(Z)
7072
@test zero(A) isa NullOperator{N}
7173
@test convert(AbstractMatrix, Z) == zeros(size(Z))
7274

@@ -105,6 +107,7 @@ end
105107
op = ScaledOperator(α, MatrixOperator(A))
106108

107109
@test op isa ScaledOperator
110+
@test isconstant(op)
108111
@test iscached(op)
109112
@test issquare(op)
110113
@test islinear(op)
@@ -115,6 +118,7 @@ end
115118
opF = factorize(op)
116119

117120
@test opF isa ScaledOperator
121+
@test isconstant(opF)
118122
@test iscached(opF)
119123

120124
@test α * A convert(AbstractMatrix, op) convert(AbstractMatrix, opF)
@@ -148,6 +152,11 @@ end
148152
@test op3 isa AddedOperator
149153
@test op4 isa AddedOperator
150154

155+
@test isconstant(op1)
156+
@test isconstant(op2)
157+
@test isconstant(op3)
158+
@test isconstant(op4)
159+
151160
@test op1 * u op( A*u, B*u)
152161
@test op2 * u op*A*u, B*u)
153162
@test op3 * u op( A*u, β*B*u)
@@ -175,13 +184,16 @@ end
175184
op = (MatrixOperator.((A, B, C))...)
176185

177186
@test op isa ComposedOperator
187+
@test isconstant(op)
188+
178189
@test *(op.ops...) isa ComposedOperator
179190
@test issquare(op)
180191
@test islinear(op)
181192

182193
opF = factorize(op)
183194

184195
@test opF isa ComposedOperator
196+
@test isconstant(opF)
185197
@test issquare(opF)
186198
@test islinear(opF)
187199

@@ -256,6 +268,9 @@ end
256268
AAt = LType(AA)
257269
DDt = LType(DD)
258270

271+
@test isconstant(AAt)
272+
@test isconstant(DDt)
273+
259274
@test AAt.L === AA
260275
@test op(u) isa VType
261276

@@ -280,6 +295,7 @@ end
280295

281296
@test !iscached(Di)
282297
Di = cache_operator(Di, u)
298+
@test isconstant(Di)
283299
@test iscached(Di)
284300

285301
@test issquare(Di)

test/matrix.jl

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ K = 19
2424
@test AA isa MatrixOperator
2525
@test AAt isa MatrixOperator
2626

27+
@test isconstant(AA)
28+
@test isconstant(AAt)
29+
2730
@test issquare(AA)
2831
@test islinear(AA)
2932

@@ -33,6 +36,9 @@ K = 19
3336
@test FF isa InvertibleOperator
3437
@test FFt isa InvertibleOperator
3538

39+
@test isconstant(FF)
40+
@test isconstant(FFt)
41+
3642
@test eachindex(A) === eachindex(AA)
3743
@test eachindex(A') === eachindex(AAt) === eachindex(MatrixOperator(At))
3844

@@ -58,9 +64,11 @@ end
5864
t = rand()
5965

6066
L = MatrixOperator(zeros(N,N);
61-
update_func= (A,u,p,t) -> (A .= p*p'; nothing)
67+
update_func = (A,u,p,t) -> (A .= p*p'; nothing)
6268
)
6369

70+
@test !isconstant(L)
71+
6472
A = p*p'
6573
ans = A * u
6674
@test L(u,p,t) ans
@@ -73,9 +81,10 @@ end
7381
t = rand()
7482

7583
D = DiagonalOperator(zeros(N);
76-
update_func= (diag,u,p,t) -> (diag .= p*t; nothing)
84+
update_func = (diag,u,p,t) -> (diag .= p*t; nothing)
7785
)
7886

87+
@test !isconstant(D)
7988
@test issquare(D)
8089
@test islinear(D)
8190

@@ -91,6 +100,7 @@ end
91100
β = rand()
92101

93102
L = DiagonalOperator(d)
103+
@test isconstant(L)
94104

95105
@test issquare(L)
96106
@test islinear(L)
@@ -114,6 +124,7 @@ end
114124
β = rand()
115125

116126
L = AffineOperator(MatrixOperator(A), MatrixOperator(B), b)
127+
@test isconstant(L)
117128
@test issquare(L)
118129
@test !islinear(L)
119130

@@ -161,9 +172,11 @@ end
161172
t = rand()
162173

163174
L = AffineOperator(A, B, b;
164-
update_func= (b,u,p,t) -> (b .= Diagonal(p*t)*b; nothing)
175+
update_func = (b,u,p,t) -> (b .= Diagonal(p*t)*b; nothing)
165176
)
166177

178+
@test !isconstant(L)
179+
167180
b = Diagonal(p*t)*b
168181
ans = A * u + B * b
169182
@test L(u,p,t) ans
@@ -208,6 +221,12 @@ for square in [false, true] #for K in [1, K]
208221
@test opAB isa TensorProductOperator
209222
@test opABC isa TensorProductOperator
210223

224+
@test isconstant(opAB)
225+
@test isconstant(opABC)
226+
227+
@test islinear(opAB)
228+
@test islinear(opABC)
229+
211230
if square
212231
@test issquare(opAB)
213232
@test issquare(opABC)
@@ -216,16 +235,16 @@ for square in [false, true] #for K in [1, K]
216235
@test !issquare(opABC)
217236
end
218237

219-
@test islinear(opAB)
220-
@test islinear(opABC)
221-
222238
@test AB convert(AbstractMatrix, opAB)
223239
@test ABC convert(AbstractMatrix, opABC)
224240

225241
# factorization tests
226242
opAB_F = factorize(opAB)
227243
opABC_F = factorize(opABC)
228244

245+
@test isconstant(opAB_F)
246+
@test isconstant(opABC_F)
247+
229248
@test opAB_F isa TensorProductOperator
230249
@test opABC_F isa TensorProductOperator
231250

test/scalar.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ K = 12
2222
@test convert(ScalarOperator, a) isa ScalarOperator
2323

2424
@test size(α) == ()
25+
@test isconstant(α)
2526

2627
v=copy(u); @test lmul!(α, u) v * x
2728
v=copy(u); @test rmul!(u, α) x * v
@@ -68,6 +69,9 @@ end
6869
α = ScalarOperator(0.0; update_func=(a,u,p,t) -> p)
6970
β = ScalarOperator(0.0; update_func=(a,u,p,t) -> t)
7071

72+
@test !isconstant(α)
73+
@test !isconstant(β)
74+
7175
@test α(u,p,t) p * u
7276
@test α(v,u,p,t) p * u
7377

0 commit comments

Comments
 (0)