Skip to content

Commit 7fa9116

Browse files
Merge pull request #201 from vpuri3/func
Refactor SciMLOps after #200
2 parents bc5a1fe + 8555b5f commit 7fa9116

File tree

13 files changed

+417
-336
lines changed

13 files changed

+417
-336
lines changed

src/SciMLOperators.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ abstract type AbstractSciMLScalarOperator{T} <: AbstractSciMLOperator{T} end
6969
include("utils.jl")
7070
include("interface.jl")
7171
include("left.jl")
72-
include("multidim.jl")
7372

7473
include("scalar.jl")
7574
include("matrix.jl")
@@ -78,7 +77,10 @@ include("batch.jl")
7877
include("func.jl")
7978
include("tensor.jl")
8079

81-
export ScalarOperator,
80+
export
81+
IdentityOperator,
82+
NullOperator,
83+
ScalarOperator,
8284
MatrixOperator,
8385
DiagonalOperator,
8486
InvertibleOperator,

src/basic.jl

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ struct IdentityOperator <: AbstractSciMLOperator{Bool}
88
end
99

1010
# constructors
11-
IdentityOperator(u::AbstractArray) = IdentityOperator(size(u,1))
12-
1311
function Base.one(L::AbstractSciMLOperator)
1412
@assert issquare(L)
1513
N = size(L, 1)
@@ -105,8 +103,6 @@ struct NullOperator <: AbstractSciMLOperator{Bool}
105103
end
106104

107105
# constructors
108-
NullOperator(u::AbstractArray) = NullOperator(size(u,1))
109-
110106
function Base.zero(L::AbstractSciMLOperator)
111107
@assert issquare(L)
112108
N = size(L, 1)
@@ -636,6 +632,7 @@ function cache_self(L::ComposedOperator, u::AbstractVecOrMat)
636632

637633
K = size(u, 2)
638634
cache = (zero(u),)
635+
639636
for i in reverse(2:length(L.ops))
640637
op = L.ops[i]
641638

@@ -644,8 +641,8 @@ function cache_self(L::ComposedOperator, u::AbstractVecOrMat)
644641

645642
T = if op isa FunctionOperator #
646643
# FunctionOperator isn't guaranteed to play by the rules of
647-
# `promote_type`. For example, an rFFT is a complex operation
648-
# that accepts and complex vector and returns a real one.
644+
# `promote_type`. For example, an irFFT is a complex operation
645+
# that accepts complex vector and returns ones.
649646
op.traits.eltypes[2]
650647
else
651648
promote_type(eltype.((op, cache[1]))...)
@@ -672,8 +669,8 @@ function cache_internals(L::ComposedOperator, u::AbstractVecOrMat)
672669
end
673670

674671
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::ComposedOperator, u::AbstractVecOrMat)
675-
@assert iscached(L) "cache needs to be set up for operator of type $(typeof(L)).
676-
set up cache by calling cache_operator(L::AbstractSciMLOperator, u::AbstractArray)"
672+
@assert iscached(L) """cache needs to be set up for operator of type
673+
$L. Set up cache by calling `cache_operator(L, u)`"""
677674

678675
vecs = (v, L.cache[1:end-1]..., u)
679676
for i in reverse(1:length(L.ops))
@@ -683,8 +680,8 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::ComposedOperator, u::Abstrac
683680
end
684681

685682
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::ComposedOperator, u::AbstractVecOrMat, α, β)
686-
@assert iscached(L) "cache needs to be set up for operator of type $(typeof(L)).
687-
set up cache by calling cache_operator(L::AbstractSciMLOperator, u::AbstractArray)"
683+
@assert iscached(L) """cache needs to be set up for operator of type
684+
$L. Set up cache by calling `cache_operator(L, u)`."""
688685

689686
cache = L.cache[end]
690687
copy!(cache, v)
@@ -695,8 +692,8 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::ComposedOperator, u::Abstrac
695692
end
696693

697694
function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::ComposedOperator, u::AbstractVecOrMat)
698-
@assert iscached(L) "cache needs to be set up for operator of type $(typeof(L)).
699-
set up cache by calling cache_operator(L::AbstractSciMLOperator, u::AbstractArray)"
695+
@assert iscached(L) """cache needs to be set up for operator of type
696+
$L. Set up cache by calling `cache_operator(L, u)`."""
700697

701698
vecs = (u, reverse(L.cache[1:end-1])..., v)
702699
for i in 1:length(L.ops)
@@ -801,8 +798,8 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::InvertedOperator, u::Abstrac
801798
end
802799

803800
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::InvertedOperator, u::AbstractVecOrMat, α, β)
804-
@assert iscached(L) "cache needs to be set up for operator of type $(typeof(L)).
805-
set up cache by calling cache_operator(L::AbstractSciMLOperator, u::AbstractArray)"
801+
@assert iscached(L) """cache needs to be set up for operator of type
802+
$L. Set up cache by calling `cache_operator(L, u)`."""
806803

807804
copy!(L.cache, v)
808805
ldiv!(v, L.L, u)
@@ -815,8 +812,8 @@ function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::InvertedOperator, u::Abstra
815812
end
816813

817814
function LinearAlgebra.ldiv!(L::InvertedOperator, u::AbstractVecOrMat)
818-
@assert iscached(L) "cache needs to be set up for operator of type $(typeof(L)).
819-
set up cache by calling cache_operator(L::AbstractSciMLOperator, u::AbstractArray)"
815+
@assert iscached(L) """cache needs to be set up for operator of type
816+
$L. Set up cache by calling `cache_operator(L, u)`."""
820817

821818
copy!(L.cache, u)
822819
mul!(u, L.L, L.cache)

0 commit comments

Comments
 (0)