Skip to content

Commit ddbc3ae

Browse files
Merge pull request #74 from vpuri3/msic
small fixes
2 parents 998a420 + 0324cab commit ddbc3ae

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

src/basic.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,10 @@ function Base.adjoint(α::ScalarOperator) # TODO - test
199199
ScalarOperator(val; update_func=update_func)
200200
end
201201
Base.transpose::ScalarOperator) = α
202+
203+
Base.one(::ScalarOperator{T}) where{T} = ScalarOperator(one(T))
204+
Base.zero(::ScalarOperator{T}) where{T} = ScalarOperator(zero(T))
205+
202206
Base.one(::Type{<:AbstractSciMLOperator}) = ScalarOperator(true)
203207
Base.zero(::Type{<:AbstractSciMLOperator}) = ScalarOperator(false)
204208

@@ -527,6 +531,26 @@ for op in (
527531
@assert size(A, 2) == N
528532
A
529533
end
534+
535+
# null operator
536+
@eval function Base.$op(::NullOperator{N}, A::ComposedOperator) where{N}
537+
@assert size(A, 1) == N
538+
zero(A)
539+
end
540+
541+
@eval function Base.$op(A::ComposedOperator, ::NullOperator{N}) where{N}
542+
@assert size(A, 2) == N
543+
zero(A)
544+
end
545+
546+
# scalar operator
547+
@eval function Base.$op::ScalarOperator, L::ComposedOperator)
548+
ScaledOperator(λ, L)
549+
end
550+
551+
@eval function Base.$op(L::ComposedOperator, λ::ScalarOperator)
552+
ScaledOperator(λ, L)
553+
end
530554
end
531555

532556
Base.convert(::Type{AbstractMatrix}, L::ComposedOperator) = prod(op -> convert(AbstractMatrix, op), L.ops)

src/matrix.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,13 @@ LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::MatrixOperator, u::AbstractVecOrMat)
7979
LinearAlgebra.ldiv!(L::MatrixOperator, u::AbstractVecOrMat) = ldiv!(L.A, u)
8080

8181
""" Diagonal Operator """
82-
DiagonalOperator(u::AbstractVector) = MatrixOperator(Diagonal(u))
82+
function DiagonalOperator(u::AbstractVector; update_func=DEFAULT_UPDATE_FUNC)
83+
function diag_update_func(A, u, p, t)
84+
update_func(A.diag, u, p, t)
85+
A
86+
end
87+
MatrixOperator(Diagonal(u); update_func=diag_update_func)
88+
end
8389
LinearAlgebra.Diagonal(L::MatrixOperator) = MatrixOperator(Diagonal(L.A))
8490

8591
"""

0 commit comments

Comments
 (0)