Skip to content

Commit 664d609

Browse files
committed
allow functionoperator (batch = false) to accept/return vec'd arrays
1 parent ef163d1 commit 664d609

File tree

1 file changed

+122
-66
lines changed

1 file changed

+122
-66
lines changed

src/func.jl

Lines changed: 122 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ function FunctionOperator(op,
173173
msg = """`FunctionOperator` constructed with `batch = true` only
174174
accepts `AbstractVecOrMat` types with
175175
`size(L, 2) == size(u, 1)`."""
176-
ArgumentError(msg) |> throw
176+
throw(ArgumentError(msg))
177177
end
178178

179179
if input isa AbstractMatrix
@@ -184,7 +184,7 @@ function FunctionOperator(op,
184184
array, $(typeof(input)), has size $(size(input)), whereas
185185
output array, $(typeof(output)), has size
186186
$(size(output))."""
187-
ArgumentError(msg) |> throw
187+
throw(ArgumentError(msg))
188188
end
189189
end
190190
end
@@ -340,14 +340,14 @@ function _cache_operator(L::FunctionOperator, u::AbstractArray)
340340
if !isa(u, AbstractVecOrMat)
341341
msg = """$L constructed with `batch = true` only accepts
342342
`AbstractVecOrMat` types with `size(L, 2) == size(u, 1)`."""
343-
ArgumentError(msg) |> throw
343+
throw(ArgumentError(msg))
344344
end
345345

346346
if size(L, 2) != size(u, 1)
347347
msg = """Second dimension of $L of size $(size(L))
348348
is not consistent with first dimension of input array `u`
349349
of size $(size(u))."""
350-
DimensionMismatch(msg) |> throw
350+
throw(DimensionMismatch(msg))
351351
end
352352

353353
M = size(L, 1)
@@ -486,7 +486,7 @@ function Base.resize!(L::FunctionOperator, n::Integer)
486486
if length(L.traits.sizes[1]) != 1
487487
msg = """`Base.resize!` is only supported by $L whose input/output
488488
arrays are `AbstractVector`s."""
489-
MethodError(msg) |> throw
489+
throw(MethodError(msg))
490490
end
491491

492492
for op in getops(L)
@@ -534,131 +534,187 @@ has_ldiv(L::FunctionOperator{iip}) where{iip} = !(L.op_inverse isa Nothing)
534534
has_ldiv!(L::FunctionOperator{iip}) where{iip} = iip & !(L.op_inverse isa Nothing)
535535

536536
function _sizecheck(L::FunctionOperator, u, v)
537-
537+
sizes = L.traits.sizes
538538
if L.traits.batch
539539
if !isnothing(u)
540540
if !isa(u, AbstractVecOrMat)
541541
msg = """$L constructed with `batch = true` only
542542
accept input arrays that are `AbstractVecOrMat`s with
543-
`size(L, 2) == size(u, 1)`."""
544-
ArgumentError(msg) |> throw
543+
`size(L, 2) == size(u, 1)`. Recieved $(typeof(u))."""
544+
throw(ArgumentError(msg))
545545
end
546546

547-
if size(u) != L.traits.sizes[1]
548-
msg = """$L expects input arrays of size $(L.traits.sizes[1]).
549-
Recievd array of size $(size(u))."""
550-
DimensionMismatch(msg) |> throw
547+
if size(L, 2) != size(u, 1)
548+
msg = """$L accepts input `AbstractVecOrMat`s of size
549+
($(size(L, 2)), K). Recievd array of size $(size(u))."""
550+
throw(DimensionMismatch(msg))
551551
end
552-
end
552+
end # u
553553

554554
if !isnothing(v)
555-
if size(v) != L.traits.sizes[2]
556-
msg = """$L expects input arrays of size $(L.traits.sizes[1]).
557-
Recievd array of size $(size(v))."""
558-
DimensionMismatch(msg) |> throw
555+
if !isa(v, AbstractVecOrMat)
556+
msg = """$L constructed with `batch = true` only
557+
returns output arrays that are `AbstractVecOrMat`s with
558+
`size(L, 1) == size(v, 1)`. Recieved $(typeof(v))."""
559+
throw(ArgumentError(msg))
559560
end
560-
end
561+
562+
if size(L, 1) != size(v, 1)
563+
msg = """$L accepts output `AbstractVecOrMat`s of size
564+
($(size(L, 1)), K). Recievd array of size $(size(v))."""
565+
throw(DimensionMismatch(msg))
566+
end
567+
end # v
568+
569+
if !isnothing(u) & !isnothing(v)
570+
if size(u, 2) != size(v, 2)
571+
msg = """input array $u, and output array, $v, must have the
572+
same batch size (i.e. length of second dimension). Got
573+
$(size(u)), $(size(v)). If you encounter this error during
574+
an in-place evaluation (`LinearAlgebra.mul!`, `ldiv!`),
575+
ensure that the operator $L has been cached with an input
576+
array of the correct size. Do so by calling
577+
`L = cache_operator(L, u)`."""
578+
throw(DimensionMismatch(msg))
579+
end
580+
end # u, v
581+
561582
else # !batch
583+
562584
if !isnothing(u)
563-
if size(u) != L.traits.sizes[1]
564-
msg = """$L expects input arrays of size $(L.traits.sizes[1]).
565-
Recievd array of size $(size(u))."""
566-
DimensionMismatch(msg) |> throw
585+
if size(u) (sizes[1], tuple(size(L, 2)),)
586+
msg = """$L recievd input array of size $(size(u)), but only
587+
accepts input arrays of size $(sizes[1]), or vectors like
588+
`vec(u)` of size $(tuple(prod(sizes[1])))."""
589+
throw(DimensionMismatch(msg))
567590
end
568-
end
591+
end # u
569592

570593
if !isnothing(v)
571-
if size(v) != L.traits.sizes[2]
572-
msg = """$L expects input arrays of size $(L.traits.sizes[1]).
573-
Recievd array of size $(size(v))."""
574-
DimensionMismatch(msg) |> throw
594+
if size(v) (sizes[2], tuple(size(L, 1)),)
595+
msg = """$L recievd output array of size $(size(v)), but only
596+
accepts output arrays of size $(sizes[2]), or vectors like
597+
`vec(u)` of size $(tuple(prod(sizes[2])))"""
598+
throw(DimensionMismatch(msg))
575599
end
576-
end
600+
end # v
577601
end # batch
578602

579603
return
580604
end
581605

582-
# operator application
583-
function Base.:*(L::FunctionOperator{iip,true}, u::AbstractArray) where{iip}
584-
_sizecheck(L, u, nothing)
606+
function _unvec(L::FunctionOperator, u, v)
607+
if L.traits.batch
608+
return u, v, false
609+
else
610+
sizes = L.traits.sizes
585611

586-
L.op(u, L.p, L.t; L.traits.kwargs...)
587-
end
612+
# no need to vec since expected input/output are AbstractVectors
613+
if length(sizes[1]) == 1
614+
return u, v, false
615+
end
588616

589-
function Base.:\(L::FunctionOperator{iip,true}, u::AbstractArray) where{iip}
590-
_sizecheck(L, nothing, u)
617+
vec_u = isnothing(u) ? false : size(u) != sizes[1]
618+
vec_v = isnothing(v) ? false : size(v) != sizes[2]
591619

592-
L.op_inverse(u, L.p, L.t; L.traits.kwargs...)
593-
end
620+
if !isnothing(u) & !isnothing(v)
621+
if (vec_u & !vec_v) | (!vec_u & vec_v)
622+
msg = """Input/output to $L can either be of sizes
623+
$(sizes[1])/ $(sizes[2]), or
624+
$(tuple(prod(sizes[1])))/ $(tuple(prod(sizes[2]))). Got
625+
$(size(u)), $(size(v))."""
626+
throw(DimensionMismatch(msg))
627+
end
628+
end
594629

595-
# fallback *, \ for FunctionOperator with no OOP method
630+
U = vec_u ? reshape(u, sizes[1]) : u
631+
V = vec_v ? reshape(v, sizes[2]) : v
632+
vec_output = vec_u | vec_v
596633

597-
function Base.:*(L::FunctionOperator{true,false}, u::AbstractArray)
598-
_, co = L.cache
599-
v = zero(co)
634+
return U, V, vec_output
635+
end
636+
end
600637

601-
_sizecheck(L, u, v)
638+
# operator application
639+
function Base.:*(L::FunctionOperator{iip,true}, u::AbstractArray) where{iip}
640+
_sizecheck(L, u, nothing)
641+
U, _, vec_output = _unvec(L, u, nothing)
602642

603-
L.op(v, u, L.p, L.t; L.traits.kwargs...)
643+
V = L.op(U, L.p, L.t; L.traits.kwargs...)
644+
645+
vec_output ? vec(V) : V
604646
end
605647

606-
function Base.:\(L::FunctionOperator{true,false}, u::AbstractArray)
607-
ci, _ = L.cache
608-
v = zero(ci)
648+
function Base.:\(L::FunctionOperator{iip,true}, v::AbstractArray) where{iip}
649+
_sizecheck(L, nothing, v)
650+
_, V, vec_output = _unvec(L, nothing, v)
609651

610-
_sizecheck(L, v, u)
652+
U = L.op_inverse(V, L.p, L.t; L.traits.kwargs...)
611653

612-
L.op_inverse(v, u, L.p, L.t; L.traits.kwargs...)
654+
vec_output ? vec(U) : U
613655
end
614656

615657
function LinearAlgebra.mul!(v::AbstractArray, L::FunctionOperator{true}, u::AbstractArray)
616-
617658
_sizecheck(L, u, v)
659+
U, V, vec_output = _unvec(L, u, v)
660+
661+
L.op(V, U, L.p, L.t; L.traits.kwargs...)
618662

619-
L.op(v, u, L.p, L.t; L.traits.kwargs...)
663+
vec_output ? vec(V) : V
620664
end
621665

622-
function LinearAlgebra.mul!(v::AbstractArray, L::FunctionOperator{false}, u::AbstractArray, args...)
623-
@error "LinearAlgebra.mul! not defined for out-of-place FunctionOperators"
666+
function LinearAlgebra.mul!(::AbstractArray, L::FunctionOperator{false}, ::AbstractArray, args...)
667+
@error "LinearAlgebra.mul! not defined for out-of-place operator $L"
624668
end
625669

626670
function LinearAlgebra.mul!(v::AbstractArray, L::FunctionOperator{true, oop, false}, u::AbstractArray, α, β) where{oop}
627-
_, co = L.cache
671+
_, Co = L.cache
628672

629673
_sizecheck(L, u, v)
674+
U, V, _ = _unvec(L, u, v)
630675

631-
copy!(co, v)
632-
mul!(v, L, u)
633-
axpby!(β, co, α, v)
676+
copy!(Co, V)
677+
L.op(V, U, L.p, L.t; L.traits.kwargs...) # mul!(V, L, U)
678+
axpby!(β, Co, α, V)
679+
680+
v
634681
end
635682

636683
function LinearAlgebra.mul!(v::AbstractArray, L::FunctionOperator{true, oop, true}, u::AbstractArray, α, β) where{oop}
637-
638684
_sizecheck(L, u, v)
685+
U, V, _ = _unvec(L, u, v)
639686

640-
L.op(v, u, L.p, L.t, α, β; L.traits.kwargs...)
687+
L.op(V, U, L.p, L.t, α, β; L.traits.kwargs...)
688+
689+
v
641690
end
642691

643-
function LinearAlgebra.ldiv!(v::AbstractArray, L::FunctionOperator{true}, u::AbstractArray)
692+
function LinearAlgebra.ldiv!(u::AbstractArray, L::FunctionOperator{true}, v::AbstractArray)
693+
_sizecheck(L, u, v)
694+
U, V, _ = _unvec(L, u, v)
644695

645-
_sizecheck(L, v, u)
696+
L.op_inverse(U, V, L.p, L.t; L.traits.kwargs...)
646697

647-
L.op_inverse(v, u, L.p, L.t; L.traits.kwargs...)
698+
u
648699
end
649700

650701
function LinearAlgebra.ldiv!(L::FunctionOperator{true}, u::AbstractArray)
651-
ci, _ = L.cache
702+
V, _ = L.cache
703+
704+
_sizecheck(L, u, V)
705+
U, _, _ = _unvec(L, u, nothing)
706+
707+
copy!(V, U)
708+
L.op_inverse(U, V, L.p, L.t; L.traits.kwargs...) # ldiv!(U, L, V)
652709

653-
copy!(ci, u)
654-
ldiv!(u, L, ci)
710+
u
655711
end
656712

657713
function LinearAlgebra.ldiv!(v::AbstractArray, L::FunctionOperator{false}, u::AbstractArray)
658-
@error "LinearAlgebra.ldiv! not defined for out-of-place FunctionOperators"
714+
@error "LinearAlgebra.ldiv! not defined for out-of-place $L"
659715
end
660716

661717
function LinearAlgebra.ldiv!(L::FunctionOperator{false}, u::AbstractArray)
662-
@error "LinearAlgebra.ldiv! not defined for out-of-place FunctionOperators"
718+
@error "LinearAlgebra.ldiv! not defined for out-of-place $L"
663719
end
664720
#

0 commit comments

Comments
 (0)