Skip to content

Commit 8c981be

Browse files
authored
Merge pull request #175 from vpuri3/resize
overload Base.resize<bang>
2 parents e67e6f3 + 8e85ea5 commit 8c981be

File tree

8 files changed

+133
-3
lines changed

8 files changed

+133
-3
lines changed

src/basic.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ Base.size(ii::IdentityOperator) = (ii.len, ii.len)
2121
Base.adjoint(A::IdentityOperator) = A
2222
Base.transpose(A::IdentityOperator) = A
2323
Base.conj(A::IdentityOperator) = A
24+
2425
LinearAlgebra.opnorm(::IdentityOperator, p::Real=2) = true
2526
for pred in (
2627
:issymmetric, :ishermitian, :isposdef,
@@ -230,6 +231,7 @@ for op in (
230231
@eval Base.$op(L::ScaledOperator) = ScaledOperator($op(L.λ), $op(L.L))
231232
end
232233
Base.conj(L::ScaledOperator) = conj(L.λ) * conj(L.L)
234+
Base.resize!(L::ScaledOperator, n::Integer) = (resize!(L.L, n); L)
233235
LinearAlgebra.opnorm(L::ScaledOperator, p::Real=2) = abs(L.λ) * opnorm(L.L, p)
234236

235237
getops(L::ScaledOperator) = (L.λ, L.L,)
@@ -377,6 +379,12 @@ for op in (
377379
@eval Base.$op(L::AddedOperator) = AddedOperator($op.(L.ops)...)
378380
end
379381
Base.conj(L::AddedOperator) = AddedOperator(conj.(L.ops))
382+
function Base.resize!(L::AddedOperator, n::Integer)
383+
for op in L.ops
384+
resize!(op, n)
385+
end
386+
L
387+
end
380388

381389
getops(L::AddedOperator) = L.ops
382390
islinear(L::AddedOperator) = all(islinear, getops(L))
@@ -509,6 +517,19 @@ for op in (
509517
)
510518
end
511519
Base.conj(L::ComposedOperator) = ComposedOperator(conj.(L.ops); cache=L.cache)
520+
function Base.resize!(L::ComposedOperator, n::Integer)
521+
522+
for op in L.ops
523+
resize!(op, n)
524+
end
525+
526+
for v in L.cache
527+
resize!(v, n)
528+
end
529+
530+
L
531+
end
532+
512533
LinearAlgebra.opnorm(L::ComposedOperator) = prod(opnorm, L.ops)
513534

514535
getops(L::ComposedOperator) = L.ops
@@ -667,6 +688,13 @@ Base.size(L::InvertedOperator) = size(L.L) |> reverse
667688
Base.transpose(L::InvertedOperator) = InvertedOperator(transpose(L.L); cache = iscached(L) ? L.cache' : nothing)
668689
Base.adjoint(L::InvertedOperator) = InvertedOperator(adjoint(L.L); cache = iscached(L) ? L.cache' : nothing)
669690
Base.conj(L::InvertedOperator) = InvertedOperator(conj(L.L); cache=L.cache)
691+
function Base.resize!(L::InvertedOperator, n::Integer)
692+
693+
resize!(L.L, n)
694+
resize!(L.cache, n)
695+
696+
L
697+
end
670698

671699
getops(L::InvertedOperator) = (L.L,)
672700
islinear(L::InvertedOperator) = islinear(L.L)

src/func.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,23 @@ function Base.inv(L::FunctionOperator)
291291
)
292292
end
293293

294+
function Base.resize!(L::FunctionOperator, n::Integer)
295+
296+
for op in getops(L)
297+
if static_hasmethod(resize!, typeof((op, n)))
298+
resize!(op, n)
299+
end
300+
end
301+
302+
for v in L.cache
303+
resize!(v, n)
304+
end
305+
306+
L.traits = (; L.traits..., size = (n, n),)
307+
308+
L
309+
end
310+
294311
function LinearAlgebra.opnorm(L::FunctionOperator, p)
295312
L.traits.opnorm === nothing && error("""
296313
M.opnorm is nothing, please define opnorm as a function that takes one

src/interface.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,10 @@ function Base.getindex(L::AbstractSciMLOperator, I::Vararg{Int, N}) where {N}
243243
convert(AbstractMatrix,L)[I...]
244244
end
245245

246+
function Base.resize!(L::AbstractSciMLOperator, n::Integer)
247+
throw(MethodError(resize!, typeof.((L, n))))
248+
end
249+
246250
LinearAlgebra.exp(L::AbstractSciMLOperator) = exp(Matrix(L))
247251
LinearAlgebra.opnorm(L::AbstractSciMLOperator, p::Real=2) = opnorm(convert(AbstractMatrix,L), p)
248252
for pred in (

src/left.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ end
3232

3333
function LinearAlgebra.ldiv!(u::AbstractVecOrMat, L::AbstractSciMLOperator)
3434
op = (u isa Transpose) ? transpose : adjoint
35-
ldiv!(op(v), op(L), op(u))
36-
v
35+
ldiv!(op(L), op(u))
36+
u
3737
end
3838

3939
###
@@ -60,7 +60,7 @@ AbstractAdjointVecOrMat = Adjoint{ T,<:AbstractVecOrMat} where{T}
6060
AbstractTransposedVecOrMat = Transpose{T,<:AbstractVecOrMat} where{T}
6161

6262
has_adjoint(::AdjointOperator) = true
63-
#has_adjoint(::TransposedOperator) = ??
63+
has_adjoint(L::TransposedOperator) = isreal(L) & has_adjoint(L.L)
6464

6565
islinear(L::AdjointOperator) = islinear(L.L)
6666
islinear(L::TransposedOperator) = islinear(L.L)
@@ -79,6 +79,7 @@ for (op, LType, VType) in (
7979

8080
# traits
8181
@eval Base.size(L::$LType) = size(L.L) |> reverse
82+
@eval Base.resize!(L::$LType, n::Integer) = (resize!(L.L, n); L)
8283
@eval Base.$op(L::$LType) = L.L
8384

8485
@eval getops(L::$LType) = (L.L,)

src/matrix.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ Base.size(L::InvertibleOperator) = size(L.F)
172172
Base.transpose(L::InvertibleOperator) = InvertibleOperator(transpose(L.F))
173173
Base.adjoint(L::InvertibleOperator) = InvertibleOperator(L.F')
174174
Base.conj(L::InvertibleOperator) = InvertibleOperator(conj(L.F))
175+
Base.resize!(L::InvertibleOperator, n::Integer) = (resize!(L.F, n); L)
175176
LinearAlgebra.opnorm(L::InvertibleOperator{T}, p=2) where{T} = one(T) / opnorm(L.F)
176177
LinearAlgebra.issuccess(L::InvertibleOperator) = issuccess(L.F)
177178

@@ -278,6 +279,15 @@ islinear(::AffineOperator) = false
278279

279280
Base.size(L::AffineOperator) = size(L.A)
280281
Base.iszero(L::AffineOperator) = all(iszero, getops(L))
282+
function Base.resize!(L::AffineOperator, n::Integer)
283+
284+
resize!(L.A, n)
285+
resize!(L.B, n)
286+
resize!(L.b, n)
287+
288+
L
289+
end
290+
281291
has_adjoint(L::AffineOperator) = all(has_adjoint, L.ops)
282292
has_mul(L::AffineOperator) = has_mul(L.A)
283293
has_mul!(L::AffineOperator) = has_mul!(L.A)

test/basic.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ K = 12
3737
@test size(Id) == (N, N)
3838
@test Id' isa IdentityOperator
3939
@test isconstant(Id)
40+
@test_throws MethodError resize!(Id, N)
4041

4142
for op in (
4243
*, \,
@@ -69,6 +70,8 @@ end
6970
@test islinear(Z)
7071
@test NullOperator(u) isa NullOperator
7172
@test isconstant(Z)
73+
@test_throws MethodError resize!(Z, N)
74+
7275
@test zero(A) isa NullOperator
7376
@test convert(AbstractMatrix, Z) == zeros(size(Z))
7477

test/matrix.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ K = 19
3939
@test isconstant(FF)
4040
@test isconstant(FFt)
4141

42+
@test_throws MethodError resize!(AA, N)
43+
4244
@test eachindex(A) === eachindex(AA)
4345
@test eachindex(A') === eachindex(AAt) === eachindex(MatrixOperator(At))
4446

@@ -101,6 +103,7 @@ end
101103

102104
L = DiagonalOperator(d)
103105
@test isconstant(L)
106+
@test_throws MethodError resize!(L, N)
104107

105108
@test issquare(L)
106109
@test islinear(L)

test/total.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,4 +97,68 @@ end
9797
v=rand(N2,K); @test mul!(v, op, u) op * u
9898
v=rand(N2,K); w=copy(v); @test mul!(v, op, u, α, β) α*(op * u) + β * w
9999
end
100+
101+
@testset "Resize! test" begin
102+
M1 = 4
103+
M2 = 12
104+
105+
u = rand(N)
106+
u1 = rand(M1)
107+
u2 = rand(M2)
108+
109+
f(u, p, t) = 2 * u
110+
f(v, u, p, t) = (copy!(v, u); lmul!(2, v))
111+
112+
fi(u, p, t) = 0.5 * u
113+
fi(v, u, p, t) = (copy!(v, u); lmul!(0.5, v))
114+
115+
F = FunctionOperator(f, u, u; islinear = true, op_inverse = fi, issymmetric = true)
116+
117+
multest(L, u) = @test mul!(zero(u), L, u) L * u
118+
119+
function multest(L::SciMLOperators.AdjointOperator, u)
120+
@test mul!(adjoint(zero(u)), adjoint(u), L) adjoint(u) * L
121+
end
122+
123+
function multest(L::SciMLOperators.TransposedOperator, u)
124+
@test mul!(transpose(zero(u)), transpose(u), L) transpose(u) * L
125+
end
126+
127+
function multest(L::SciMLOperators.InvertedOperator, u)
128+
@test ldiv!(zero(u), L, u) L \ u
129+
end
130+
131+
for (L, LT) in (
132+
(F, FunctionOperator),
133+
(F + F, SciMLOperators.AddedOperator),
134+
(F * 2, SciMLOperators.ScaledOperator),
135+
(F F, SciMLOperators.ComposedOperator),
136+
(AffineOperator(F, F, u), AffineOperator),
137+
(SciMLOperators.AdjointOperator(F), SciMLOperators.AdjointOperator),
138+
(SciMLOperators.TransposedOperator(F), SciMLOperators.TransposedOperator),
139+
(SciMLOperators.InvertedOperator(F), SciMLOperators.InvertedOperator),
140+
(SciMLOperators.InvertibleOperator(F), SciMLOperators.InvertibleOperator),
141+
)
142+
143+
@info "$LT"
144+
145+
L = deepcopy(L)
146+
L = cache_operator(L, u)
147+
148+
@test L isa LT
149+
@test size(L) == (N, N)
150+
multest(L, u)
151+
152+
resize!(L, M1); @test size(L) == (M1, M1)
153+
multest(L, u1)
154+
155+
resize!(L, M2); @test size(L) == (M2, M2)
156+
multest(L, u2)
157+
158+
end
159+
160+
# InvertedOperator
161+
# AffineOperator
162+
# FunctionOperator
163+
end
100164
#

0 commit comments

Comments
 (0)