Skip to content

Commit 03d6d39

Browse files
committed
Merge branch 'master' into isconcrete
2 parents 5adda29 + 303622a commit 03d6d39

File tree

4 files changed

+173
-100
lines changed

4 files changed

+173
-100
lines changed

docs/src/tutorials/fftw.md

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,22 @@ x = range(start=-L/2, stop=L/2-dx, length=n) |> Array
1919
u = @. sin(5x)cos(7x);
2020
du = @. 5cos(5x)cos(7x) - 7sin(5x)sin(7x);
2121
22-
k = rfftfreq(n, 2π*n/L) |> Array
23-
m = length(k)
24-
transform = plan_rfft(x)
22+
k = rfftfreq(n, 2π*n/L) |> Array
23+
m = length(k)
24+
P = plan_rfft(x)
25+
26+
F = FunctionOperator(fwd, x, im*k;
27+
T=ComplexF64,
2528
26-
T = FunctionOperator((du,u,p,t) -> mul!(du, transform, u), x, im*k;
27-
isinplace=true,
28-
T=ComplexF64,
29+
op_adjoint = bwd,
30+
op_inverse = bwd,
31+
op_adjoint_inverse = fwd,
2932
30-
op_adjoint = (du,u,p,t) -> ldiv!(du, transform, u),
31-
op_inverse = (du,u,p,t) -> ldiv!(du, transform, u),
32-
op_adjoint_inverse = (du,u,p,t) -> ldiv!(du, transform, u),
33-
)
33+
islinear=true,
34+
)
3435
3536
ik = im * DiagonalOperator(k)
36-
Dx = T \ ik * T
37+
Dx = F \ ik * F
3738
3839
Dx = cache_operator(Dx, x)
3940
@@ -79,18 +80,17 @@ Now we are ready to define our wrapper for the FFT object. To `FunctionOperator`
7980
pass the in-place forward application of the transform,
8081
`(du,u,p,t) -> mul!(du, transform, u)`, its inverse application,
8182
`(du,u,p,t) -> ldiv!(du, transform, u)`, as well as input and output prototype vectors.
82-
We also set the flag `isinplace` to `true` to signal that we intend to use the operator
83-
in a non-allocating way, and pass in the element-type and size of the operator.
8483

8584
```
86-
T = FunctionOperator((du,u,p,t) -> mul!(du, transform, u), x, im*k;
87-
isinplace=true,
88-
T=ComplexF64,
89-
90-
op_adjoint = (du,u,p,t) -> ldiv!(du, transform, u),
91-
op_inverse = (du,u,p,t) -> ldiv!(du, transform, u),
92-
op_adjoint_inverse = (du,u,p,t) -> ldiv!(du, transform, u),
93-
)
85+
F = FunctionOperator(fwd, x, im*k;
86+
T=ComplexF64,
87+
88+
op_adjoint = bwd,
89+
op_inverse = bwd,
90+
op_adjoint_inverse = fwd,
91+
92+
islinear=true,
93+
)
9494
```
9595

9696
After wrapping the FFT with `FunctionOperator`, we are ready to compose it with other
@@ -100,7 +100,7 @@ both in-place, and out-of-place by comparing its output to the analytical deriva
100100

101101
```
102102
ik = im * DiagonalOperator(k)
103-
Dx = T \ ik * T
103+
Dx = F \ ik * F
104104
105105
@show ≈(Dx * u, du; atol=1e-8)
106106
@show ≈(mul!(copy(u), Dx, u), du; atol=1e-8)

src/func.jl

Lines changed: 122 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ function FunctionOperator(op,
175175
msg = """`FunctionOperator` constructed with `batch = true` only
176176
accepts `AbstractVecOrMat` types with
177177
`size(L, 2) == size(u, 1)`."""
178-
ArgumentError(msg) |> throw
178+
throw(ArgumentError(msg))
179179
end
180180

181181
if input isa AbstractMatrix
@@ -186,7 +186,7 @@ function FunctionOperator(op,
186186
array, $(typeof(input)), has size $(size(input)), whereas
187187
output array, $(typeof(output)), has size
188188
$(size(output))."""
189-
ArgumentError(msg) |> throw
189+
throw(ArgumentError(msg))
190190
end
191191
end
192192
end
@@ -343,14 +343,14 @@ function _cache_operator(L::FunctionOperator, u::AbstractArray)
343343
if !isa(u, AbstractVecOrMat)
344344
msg = """$L constructed with `batch = true` only accepts
345345
`AbstractVecOrMat` types with `size(L, 2) == size(u, 1)`."""
346-
ArgumentError(msg) |> throw
346+
throw(ArgumentError(msg))
347347
end
348348

349349
if size(L, 2) != size(u, 1)
350350
msg = """Second dimension of $L of size $(size(L))
351351
is not consistent with first dimension of input array `u`
352352
of size $(size(u))."""
353-
DimensionMismatch(msg) |> throw
353+
throw(DimensionMismatch(msg))
354354
end
355355

356356
M = size(L, 1)
@@ -491,7 +491,7 @@ function Base.resize!(L::FunctionOperator, n::Integer)
491491
if length(L.traits.sizes[1]) != 1
492492
msg = """`Base.resize!` is only supported by $L whose input/output
493493
arrays are `AbstractVector`s."""
494-
MethodError(msg) |> throw
494+
throw(MethodError(msg))
495495
end
496496

497497
for op in getops(L)
@@ -540,131 +540,187 @@ has_ldiv(L::FunctionOperator{iip}) where{iip} = !(L.op_inverse isa Nothing)
540540
has_ldiv!(L::FunctionOperator{iip}) where{iip} = iip & !(L.op_inverse isa Nothing)
541541

542542
function _sizecheck(L::FunctionOperator, u, v)
543-
543+
sizes = L.traits.sizes
544544
if L.traits.batch
545545
if !isnothing(u)
546546
if !isa(u, AbstractVecOrMat)
547547
msg = """$L constructed with `batch = true` only
548548
accept input arrays that are `AbstractVecOrMat`s with
549-
`size(L, 2) == size(u, 1)`."""
550-
ArgumentError(msg) |> throw
549+
`size(L, 2) == size(u, 1)`. Recieved $(typeof(u))."""
550+
throw(ArgumentError(msg))
551551
end
552552

553-
if size(u) != L.traits.sizes[1]
554-
msg = """$L expects input arrays of size $(L.traits.sizes[1]).
555-
Recievd array of size $(size(u))."""
556-
DimensionMismatch(msg) |> throw
553+
if size(L, 2) != size(u, 1)
554+
msg = """$L accepts input `AbstractVecOrMat`s of size
555+
($(size(L, 2)), K). Recievd array of size $(size(u))."""
556+
throw(DimensionMismatch(msg))
557557
end
558-
end
558+
end # u
559559

560560
if !isnothing(v)
561-
if size(v) != L.traits.sizes[2]
562-
msg = """$L expects input arrays of size $(L.traits.sizes[1]).
563-
Recievd array of size $(size(v))."""
564-
DimensionMismatch(msg) |> throw
561+
if !isa(v, AbstractVecOrMat)
562+
msg = """$L constructed with `batch = true` only
563+
returns output arrays that are `AbstractVecOrMat`s with
564+
`size(L, 1) == size(v, 1)`. Recieved $(typeof(v))."""
565+
throw(ArgumentError(msg))
565566
end
566-
end
567+
568+
if size(L, 1) != size(v, 1)
569+
msg = """$L accepts output `AbstractVecOrMat`s of size
570+
($(size(L, 1)), K). Recievd array of size $(size(v))."""
571+
throw(DimensionMismatch(msg))
572+
end
573+
end # v
574+
575+
if !isnothing(u) & !isnothing(v)
576+
if size(u, 2) != size(v, 2)
577+
msg = """input array $u, and output array, $v, must have the
578+
same batch size (i.e. length of second dimension). Got
579+
$(size(u)), $(size(v)). If you encounter this error during
580+
an in-place evaluation (`LinearAlgebra.mul!`, `ldiv!`),
581+
ensure that the operator $L has been cached with an input
582+
array of the correct size. Do so by calling
583+
`L = cache_operator(L, u)`."""
584+
throw(DimensionMismatch(msg))
585+
end
586+
end # u, v
587+
567588
else # !batch
589+
568590
if !isnothing(u)
569-
if size(u) != L.traits.sizes[1]
570-
msg = """$L expects input arrays of size $(L.traits.sizes[1]).
571-
Recievd array of size $(size(u))."""
572-
DimensionMismatch(msg) |> throw
591+
if size(u) (sizes[1], tuple(size(L, 2)),)
592+
msg = """$L recievd input array of size $(size(u)), but only
593+
accepts input arrays of size $(sizes[1]), or vectors like
594+
`vec(u)` of size $(tuple(prod(sizes[1])))."""
595+
throw(DimensionMismatch(msg))
573596
end
574-
end
597+
end # u
575598

576599
if !isnothing(v)
577-
if size(v) != L.traits.sizes[2]
578-
msg = """$L expects input arrays of size $(L.traits.sizes[1]).
579-
Recievd array of size $(size(v))."""
580-
DimensionMismatch(msg) |> throw
600+
if size(v) (sizes[2], tuple(size(L, 1)),)
601+
msg = """$L recievd output array of size $(size(v)), but only
602+
accepts output arrays of size $(sizes[2]), or vectors like
603+
`vec(u)` of size $(tuple(prod(sizes[2])))"""
604+
throw(DimensionMismatch(msg))
581605
end
582-
end
606+
end # v
583607
end # batch
584608

585609
return
586610
end
587611

588-
# operator application
589-
function Base.:*(L::FunctionOperator{iip,true}, u::AbstractArray) where{iip}
590-
_sizecheck(L, u, nothing)
612+
function _unvec(L::FunctionOperator, u, v)
613+
if L.traits.batch
614+
return u, v, false
615+
else
616+
sizes = L.traits.sizes
591617

592-
L.op(u, L.p, L.t; L.traits.kwargs...)
593-
end
618+
# no need to vec since expected input/output are AbstractVectors
619+
if length(sizes[1]) == 1
620+
return u, v, false
621+
end
594622

595-
function Base.:\(L::FunctionOperator{iip,true}, u::AbstractArray) where{iip}
596-
_sizecheck(L, nothing, u)
623+
vec_u = isnothing(u) ? false : size(u) != sizes[1]
624+
vec_v = isnothing(v) ? false : size(v) != sizes[2]
597625

598-
L.op_inverse(u, L.p, L.t; L.traits.kwargs...)
599-
end
626+
if !isnothing(u) & !isnothing(v)
627+
if (vec_u & !vec_v) | (!vec_u & vec_v)
628+
msg = """Input / output to $L can either be of sizes
629+
$(sizes[1]) / $(sizes[2]), or
630+
$(tuple(prod(sizes[1]))) / $(tuple(prod(sizes[2]))). Got
631+
$(size(u)), $(size(v))."""
632+
throw(DimensionMismatch(msg))
633+
end
634+
end
600635

601-
# fallback *, \ for FunctionOperator with no OOP method
636+
U = vec_u ? reshape(u, sizes[1]) : u
637+
V = vec_v ? reshape(v, sizes[2]) : v
638+
vec_output = vec_u | vec_v
602639

603-
function Base.:*(L::FunctionOperator{true,false}, u::AbstractArray)
604-
_, co = L.cache
605-
v = zero(co)
640+
return U, V, vec_output
641+
end
642+
end
606643

607-
_sizecheck(L, u, v)
644+
# operator application
645+
function Base.:*(L::FunctionOperator{iip,true}, u::AbstractArray) where{iip}
646+
_sizecheck(L, u, nothing)
647+
U, _, vec_output = _unvec(L, u, nothing)
608648

609-
L.op(v, u, L.p, L.t; L.traits.kwargs...)
649+
V = L.op(U, L.p, L.t; L.traits.kwargs...)
650+
651+
vec_output ? vec(V) : V
610652
end
611653

612-
function Base.:\(L::FunctionOperator{true,false}, u::AbstractArray)
613-
ci, _ = L.cache
614-
v = zero(ci)
654+
function Base.:\(L::FunctionOperator{iip,true}, v::AbstractArray) where{iip}
655+
_sizecheck(L, nothing, v)
656+
_, V, vec_output = _unvec(L, nothing, v)
615657

616-
_sizecheck(L, v, u)
658+
U = L.op_inverse(V, L.p, L.t; L.traits.kwargs...)
617659

618-
L.op_inverse(v, u, L.p, L.t; L.traits.kwargs...)
660+
vec_output ? vec(U) : U
619661
end
620662

621663
function LinearAlgebra.mul!(v::AbstractArray, L::FunctionOperator{true}, u::AbstractArray)
622-
623664
_sizecheck(L, u, v)
665+
U, V, vec_output = _unvec(L, u, v)
666+
667+
L.op(V, U, L.p, L.t; L.traits.kwargs...)
624668

625-
L.op(v, u, L.p, L.t; L.traits.kwargs...)
669+
vec_output ? vec(V) : V
626670
end
627671

628-
function LinearAlgebra.mul!(v::AbstractArray, L::FunctionOperator{false}, u::AbstractArray, args...)
629-
@error "LinearAlgebra.mul! not defined for out-of-place FunctionOperators"
672+
function LinearAlgebra.mul!(::AbstractArray, L::FunctionOperator{false}, ::AbstractArray, args...)
673+
@error "LinearAlgebra.mul! not defined for out-of-place operator $L"
630674
end
631675

632676
function LinearAlgebra.mul!(v::AbstractArray, L::FunctionOperator{true, oop, false}, u::AbstractArray, α, β) where{oop}
633-
_, co = L.cache
677+
_, Co = L.cache
634678

635679
_sizecheck(L, u, v)
680+
U, V, _ = _unvec(L, u, v)
636681

637-
copy!(co, v)
638-
mul!(v, L, u)
639-
axpby!(β, co, α, v)
682+
copy!(Co, V)
683+
L.op(V, U, L.p, L.t; L.traits.kwargs...) # mul!(V, L, U)
684+
axpby!(β, Co, α, V)
685+
686+
v
640687
end
641688

642689
function LinearAlgebra.mul!(v::AbstractArray, L::FunctionOperator{true, oop, true}, u::AbstractArray, α, β) where{oop}
643-
644690
_sizecheck(L, u, v)
691+
U, V, _ = _unvec(L, u, v)
645692

646-
L.op(v, u, L.p, L.t, α, β; L.traits.kwargs...)
693+
L.op(V, U, L.p, L.t, α, β; L.traits.kwargs...)
694+
695+
v
647696
end
648697

649-
function LinearAlgebra.ldiv!(v::AbstractArray, L::FunctionOperator{true}, u::AbstractArray)
698+
function LinearAlgebra.ldiv!(u::AbstractArray, L::FunctionOperator{true}, v::AbstractArray)
699+
_sizecheck(L, u, v)
700+
U, V, _ = _unvec(L, u, v)
650701

651-
_sizecheck(L, v, u)
702+
L.op_inverse(U, V, L.p, L.t; L.traits.kwargs...)
652703

653-
L.op_inverse(v, u, L.p, L.t; L.traits.kwargs...)
704+
u
654705
end
655706

656707
function LinearAlgebra.ldiv!(L::FunctionOperator{true}, u::AbstractArray)
657-
ci, _ = L.cache
708+
V, _ = L.cache
709+
710+
_sizecheck(L, u, V)
711+
U, _, _ = _unvec(L, u, nothing)
712+
713+
copy!(V, U)
714+
L.op_inverse(U, V, L.p, L.t; L.traits.kwargs...) # ldiv!(U, L, V)
658715

659-
copy!(ci, u)
660-
ldiv!(u, L, ci)
716+
u
661717
end
662718

663719
function LinearAlgebra.ldiv!(v::AbstractArray, L::FunctionOperator{false}, u::AbstractArray)
664-
@error "LinearAlgebra.ldiv! not defined for out-of-place FunctionOperators"
720+
@error "LinearAlgebra.ldiv! not defined for out-of-place $L"
665721
end
666722

667723
function LinearAlgebra.ldiv!(L::FunctionOperator{false}, u::AbstractArray)
668-
@error "LinearAlgebra.ldiv! not defined for out-of-place FunctionOperators"
724+
@error "LinearAlgebra.ldiv! not defined for out-of-place $L"
669725
end
670726
#

test/func.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,25 @@ NK = N * K
4747
L = FunctionOperator(f, u, v; kw...)
4848
L = cache_operator(L, u)
4949

50+
# test with ND-arrays
5051
@test _mul(A, u) L(u, p, t) L * u mul!(zero(v), L, u)
5152
@test α * _mul(A, u)+ β * v mul!(copy(v), L, u, α, β)
5253

5354
if sz_in == sz_out
5455
@test _div(A, v) L \ v ldiv!(zero(u), L, v) ldiv!(L, copy(v))
5556
end
56-
end
57+
58+
# test with vec(Array)
59+
@test vec(_mul(A, u)) L(vec(u), p, t) L * vec(u) mul!(vec(zero(v)), L, vec(u))
60+
@test vec* _mul(A, u)+ β * v) mul!(vec(copy(v)), L, vec(u), α, β)
61+
62+
if sz_in == sz_out
63+
@test vec(_div(A, v)) L \ vec(v) ldiv!(vec(zero(u)), L, vec(v)) ldiv!(L, vec(copy(v)))
64+
end
65+
66+
@test_throws DimensionMismatch mul!(vec(v), L, u)
67+
@test_throws DimensionMismatch mul!(v, L, vec(u))
68+
end # for
5769

5870
end
5971

0 commit comments

Comments
 (0)