Skip to content

Commit 0902fe5

Browse files
Merge pull request #302 from ChrisRackauckas-Claude/add-copy-overloads
Add copy overloads to prevent aliasing for all operators
2 parents 7295fea + fd4d760 commit 0902fe5

File tree

9 files changed

+339
-0
lines changed

9 files changed

+339
-0
lines changed

src/basic.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ end
1616

1717
Base.convert(::Type{AbstractMatrix}, ii::IdentityOperator) = Diagonal(ones(Bool, ii.len))
1818

19+
# Copy method to avoid aliasing - IdentityOperator has no mutable fields, can return self
20+
Base.copy(L::IdentityOperator) = L
21+
1922
# traits
2023
Base.show(io::IO, ii::IdentityOperator) = print(io, "IdentityOperator($(ii.len))")
2124
Base.size(ii::IdentityOperator) = (ii.len, ii.len)
@@ -131,6 +134,9 @@ end
131134

132135
Base.convert(::Type{AbstractMatrix}, nn::NullOperator) = Diagonal(zeros(Bool, nn.len))
133136

137+
# Copy method to avoid aliasing - NullOperator has no mutable fields, can return self
138+
Base.copy(L::NullOperator) = L
139+
134140
# traits
135141
Base.show(io::IO, nn::NullOperator) = print(io, "NullOperator($(nn.len))")
136142
Base.size(nn::NullOperator) = (nn.len, nn.len)
@@ -775,6 +781,15 @@ function update_coefficients(L::ComposedOperator, u, p, t)
775781
end
776782

777783
getops(L::ComposedOperator) = L.ops
784+
785+
# Copy method to avoid aliasing
786+
function Base.copy(L::ComposedOperator)
787+
ComposedOperator(
788+
map(copy, L.ops),
789+
L.cache === nothing ? nothing : deepcopy(L.cache)
790+
)
791+
end
792+
778793
islinear(L::ComposedOperator) = all(islinear, L.ops)
779794
Base.iszero(L::ComposedOperator) = all(iszero, getops(L))
780795
has_adjoint(L::ComposedOperator) = all(has_adjoint, L.ops)
@@ -1015,6 +1030,14 @@ has_ldiv!(L::InvertedOperator) = has_mul!(L.L)
10151030
Base.:*(L::InvertedOperator, u::AbstractVecOrMat) = L.L \ u
10161031
Base.:\(L::InvertedOperator, u::AbstractVecOrMat) = L.L * u
10171032

1033+
# Copy method to avoid aliasing
1034+
function Base.copy(L::InvertedOperator)
1035+
InvertedOperator(
1036+
copy(L.L),
1037+
L.cache === nothing ? nothing : deepcopy(L.cache)
1038+
)
1039+
end
1040+
10181041
function cache_self(L::InvertedOperator, u::AbstractVecOrMat)
10191042
cache = zero(u)
10201043
@reset L.cache = cache

src/batch.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,15 @@ end
107107

108108
getops(L::BatchedDiagonalOperator) = (L.diag,)
109109

110+
# Copy method to avoid aliasing
111+
function Base.copy(L::BatchedDiagonalOperator)
112+
BatchedDiagonalOperator(
113+
copy(L.diag),
114+
L.update_func,
115+
L.update_func!
116+
)
117+
end
118+
110119
function isconstant(L::BatchedDiagonalOperator)
111120
update_func_isconstant(L.update_func) & update_func_isconstant(L.update_func!)
112121
end

src/func.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,23 @@ function iscached(L::FunctionOperator)
439439
L.cache !== nothing
440440
end
441441

442+
# Copy method to avoid aliasing
443+
function Base.copy(L::FunctionOperator)
444+
FunctionOperator(
445+
L.op,
446+
L.op_adjoint,
447+
L.op_inverse,
448+
L.op_adjoint_inverse,
449+
L.traits,
450+
isdefined(L, :u) ? copy(L.u) : nothing,
451+
isdefined(L, :p) ? deepcopy(L.p) : nothing,
452+
L.t,
453+
L.cache === nothing ? nothing : deepcopy(L.cache),
454+
typeof(L).parameters[end-1], # iType
455+
typeof(L).parameters[end] # oType
456+
)
457+
end
458+
442459
# fix method amg bw AbstractArray, AbstractVecOrMat
443460
cache_operator(L::FunctionOperator, u::AbstractArray) = _cache_operator(L, u)
444461
cache_operator(L::FunctionOperator, u::AbstractVecOrMat) = _cache_operator(L, u)

src/left.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ for (op, LType, VType) in ((:adjoint, :AdjointOperator, :AbstractAdjointVecOrMat
106106

107107
@eval getops(L::$LType) = (L.L,)
108108

109+
# Copy method to avoid aliasing
110+
@eval Base.copy(L::$LType) = $LType(copy(L.L))
111+
109112
@eval @forward $LType.L (
110113
# LinearAlgebra
111114
LinearAlgebra.issymmetric,

src/matrix.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,15 @@ function update_coefficients!(L::InvertibleOperator, u, p, t; kwargs...)
405405
end
406406

407407
getops(L::InvertibleOperator) = (L.L, L.F)
408+
409+
# Copy method to avoid aliasing
410+
function Base.copy(L::InvertibleOperator)
411+
InvertibleOperator(
412+
copy(L.L),
413+
copy(L.F)
414+
)
415+
end
416+
408417
islinear(L::InvertibleOperator) = islinear(L.L)
409418
isconvertible(L::InvertibleOperator) = isconvertible(L.L)
410419

@@ -618,6 +627,17 @@ function update_coefficients!(L::AffineOperator, u, p, t; kwargs...)
618627
nothing
619628
end
620629

630+
# Copy method to avoid aliasing
631+
function Base.copy(L::AffineOperator)
632+
AffineOperator(
633+
copy(L.A),
634+
copy(L.B),
635+
copy(L.b);
636+
update_func = L.update_func,
637+
update_func! = L.update_func!
638+
)
639+
end
640+
621641
function isconstant(L::AffineOperator)
622642
update_func_isconstant(L.update_func) &
623643
update_func_isconstant(L.update_func!) &

src/scalar.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,11 @@ function SciMLOperators.update_coefficients(L::ScalarOperator, u, p, t; kwargs..
206206
return ScalarOperator(L.update_func(L.val, u, p, t; kwargs...), L.update_func)
207207
end
208208

209+
# Copy method to avoid aliasing
210+
function Base.copy(L::ScalarOperator)
211+
ScalarOperator(L.val, L.update_func)
212+
end
213+
209214
# Add ScalarOperator specific implementations for the new interface
210215
function::ScalarOperator)(v::AbstractArray, u, p, t; kwargs...)
211216
α = update_coefficients(α, u, p, t; kwargs...)
@@ -313,6 +318,12 @@ function (α::AddedScalarOperator)(
313318
end
314319

315320
getops::AddedScalarOperator) = α.ops
321+
322+
# Copy method to avoid aliasing
323+
function Base.copy(L::AddedScalarOperator)
324+
AddedScalarOperator(map(copy, L.ops))
325+
end
326+
316327
has_ldiv::AddedScalarOperator) = !iszero(convert(Number, α))
317328
has_ldiv!::AddedScalarOperator) = has_ldiv(α)
318329

@@ -432,6 +443,12 @@ function (α::ComposedScalarOperator)(
432443
end
433444

434445
getops::ComposedScalarOperator) = α.ops
446+
447+
# Copy method to avoid aliasing
448+
function Base.copy(L::ComposedScalarOperator)
449+
ComposedScalarOperator(map(copy, L.ops))
450+
end
451+
435452
has_ldiv::ComposedScalarOperator) = all(has_ldiv, α.ops)
436453
has_ldiv!::ComposedScalarOperator) = all(has_ldiv!, α.ops)
437454

@@ -506,6 +523,12 @@ function (α::InvertedScalarOperator)(
506523
mul!(w, α, v, a, b)
507524
end
508525
getops::InvertedScalarOperator) =.λ,)
526+
527+
# Copy method to avoid aliasing
528+
function Base.copy(L::InvertedScalarOperator)
529+
InvertedScalarOperator(copy(L.λ))
530+
end
531+
509532
has_ldiv::InvertedScalarOperator) = has_mul.λ)
510533
has_ldiv!::InvertedScalarOperator) = has_ldiv(α)
511534
#

src/tensor.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,15 @@ function update_coefficients(L::TensorProductOperator, u, p, t)
151151
end
152152

153153
getops(L::TensorProductOperator) = L.ops
154+
155+
# Copy method to avoid aliasing
156+
function Base.copy(L::TensorProductOperator)
157+
TensorProductOperator(
158+
map(copy, L.ops),
159+
L.cache === nothing ? nothing : deepcopy(L.cache)
160+
)
161+
end
162+
154163
islinear(L::TensorProductOperator) = reduce(&, islinear.(L.ops))
155164
isconvertible(::TensorProductOperator) = false
156165
Base.iszero(L::TensorProductOperator) = reduce(|, iszero.(L.ops))

0 commit comments

Comments
 (0)