@@ -175,7 +175,7 @@ function FunctionOperator(op,
175175 msg = """ `FunctionOperator` constructed with `batch = true` only
176176 accepts `AbstractVecOrMat` types with
177177 `size(L, 2) == size(u, 1)`."""
178- ArgumentError (msg) |> throw
178+ throw ( ArgumentError (msg))
179179 end
180180
181181 if input isa AbstractMatrix
@@ -186,7 +186,7 @@ function FunctionOperator(op,
186186 array, $(typeof (input)) , has size $(size (input)) , whereas
187187 output array, $(typeof (output)) , has size
188188 $(size (output)) ."""
189- ArgumentError (msg) |> throw
189+ throw ( ArgumentError (msg))
190190 end
191191 end
192192 end
@@ -343,14 +343,14 @@ function _cache_operator(L::FunctionOperator, u::AbstractArray)
343343 if ! isa (u, AbstractVecOrMat)
344344 msg = """ $L constructed with `batch = true` only accepts
345345 `AbstractVecOrMat` types with `size(L, 2) == size(u, 1)`."""
346- ArgumentError (msg) |> throw
346+ throw ( ArgumentError (msg))
347347 end
348348
349349 if size (L, 2 ) != size (u, 1 )
350350 msg = """ Second dimension of $L of size $(size (L))
351351 is not consistent with first dimension of input array `u`
352352 of size $(size (u)) ."""
353- DimensionMismatch (msg) |> throw
353+ throw ( DimensionMismatch (msg))
354354 end
355355
356356 M = size (L, 1 )
@@ -491,7 +491,7 @@ function Base.resize!(L::FunctionOperator, n::Integer)
491491 if length (L. traits. sizes[1 ]) != 1
492492 msg = """ `Base.resize!` is only supported by $L whose input/output
493493 arrays are `AbstractVector`s."""
494- MethodError (msg) |> throw
494+ throw ( MethodError (msg))
495495 end
496496
497497 for op in getops (L)
@@ -540,131 +540,187 @@ has_ldiv(L::FunctionOperator{iip}) where{iip} = !(L.op_inverse isa Nothing)
540540has_ldiv! (L:: FunctionOperator{iip} ) where {iip} = iip & ! (L. op_inverse isa Nothing)
541541
542542function _sizecheck (L:: FunctionOperator , u, v)
543-
543+ sizes = L . traits . sizes
544544 if L. traits. batch
545545 if ! isnothing (u)
546546 if ! isa (u, AbstractVecOrMat)
547547 msg = """ $L constructed with `batch = true` only
548548 accept input arrays that are `AbstractVecOrMat`s with
549- `size(L, 2) == size(u, 1)`."""
550- ArgumentError (msg) |> throw
549+ `size(L, 2) == size(u, 1)`. Recieved $( typeof (u)) . """
550+ throw ( ArgumentError (msg))
551551 end
552552
553- if size (u ) != L . traits . sizes[ 1 ]
554- msg = """ $L expects input arrays of size $(L . traits . sizes[ 1 ]) .
555- Recievd array of size $(size (u)) ."""
556- DimensionMismatch (msg) |> throw
553+ if size (L, 2 ) != size (u, 1 )
554+ msg = """ $L accepts input `AbstractVecOrMat`s of size
555+ ( $( size (L, 2 )) , K). Recievd array of size $(size (u)) ."""
556+ throw ( DimensionMismatch (msg))
557557 end
558- end
558+ end # u
559559
560560 if ! isnothing (v)
561- if size (v) != L. traits. sizes[2 ]
562- msg = """ $L expects input arrays of size $(L. traits. sizes[1 ]) .
563- Recievd array of size $(size (v)) ."""
564- DimensionMismatch (msg) |> throw
561+ if ! isa (v, AbstractVecOrMat)
562+ msg = """ $L constructed with `batch = true` only
563+ returns output arrays that are `AbstractVecOrMat`s with
564+ `size(L, 1) == size(v, 1)`. Recieved $(typeof (v)) ."""
565+ throw (ArgumentError (msg))
565566 end
566- end
567+
568+ if size (L, 1 ) != size (v, 1 )
569+ msg = """ $L accepts output `AbstractVecOrMat`s of size
570+ ($(size (L, 1 )) , K). Recievd array of size $(size (v)) ."""
571+ throw (DimensionMismatch (msg))
572+ end
573+ end # v
574+
575+ if ! isnothing (u) & ! isnothing (v)
576+ if size (u, 2 ) != size (v, 2 )
577+ msg = """ input array $u , and output array, $v , must have the
578+ same batch size (i.e. length of second dimension). Got
579+ $(size (u)) , $(size (v)) . If you encounter this error during
580+ an in-place evaluation (`LinearAlgebra.mul!`, `ldiv!`),
581+ ensure that the operator $L has been cached with an input
582+ array of the correct size. Do so by calling
583+ `L = cache_operator(L, u)`."""
584+ throw (DimensionMismatch (msg))
585+ end
586+ end # u, v
587+
567588 else # !batch
589+
568590 if ! isnothing (u)
569- if size (u) != L. traits. sizes[1 ]
570- msg = """ $L expects input arrays of size $(L. traits. sizes[1 ]) .
571- Recievd array of size $(size (u)) ."""
572- DimensionMismatch (msg) |> throw
591+ if size (u) ∉ (sizes[1 ], tuple (size (L, 2 )),)
592+ msg = """ $L recievd input array of size $(size (u)) , but only
593+ accepts input arrays of size $(sizes[1 ]) , or vectors like
594+ `vec(u)` of size $(tuple (prod (sizes[1 ]))) ."""
595+ throw (DimensionMismatch (msg))
573596 end
574- end
597+ end # u
575598
576599 if ! isnothing (v)
577- if size (v) != L. traits. sizes[2 ]
578- msg = """ $L expects input arrays of size $(L. traits. sizes[1 ]) .
579- Recievd array of size $(size (v)) ."""
580- DimensionMismatch (msg) |> throw
600+ if size (v) ∉ (sizes[2 ], tuple (size (L, 1 )),)
601+ msg = """ $L recievd output array of size $(size (v)) , but only
602+ accepts output arrays of size $(sizes[2 ]) , or vectors like
603+ `vec(u)` of size $(tuple (prod (sizes[2 ]))) """
604+ throw (DimensionMismatch (msg))
581605 end
582- end
606+ end # v
583607 end # batch
584608
585609 return
586610end
587611
588- # operator application
589- function Base.:* (L:: FunctionOperator{iip,true} , u:: AbstractArray ) where {iip}
590- _sizecheck (L, u, nothing )
612+ function _unvec (L:: FunctionOperator , u, v)
613+ if L. traits. batch
614+ return u, v, false
615+ else
616+ sizes = L. traits. sizes
591617
592- L. op (u, L. p, L. t; L. traits. kwargs... )
593- end
618+ # no need to vec since expected input/output are AbstractVectors
619+ if length (sizes[1 ]) == 1
620+ return u, v, false
621+ end
594622
595- function Base.: \ (L :: FunctionOperator{iip,true} , u :: AbstractArray ) where {iip}
596- _sizecheck (L, nothing , u)
623+ vec_u = isnothing (u) ? false : size (u) != sizes[ 1 ]
624+ vec_v = isnothing (v) ? false : size (v) != sizes[ 2 ]
597625
598- L. op_inverse (u, L. p, L. t; L. traits. kwargs... )
599- end
626+ if ! isnothing (u) & ! isnothing (v)
627+ if (vec_u & ! vec_v) | (! vec_u & vec_v)
628+ msg = """ Input / output to $L can either be of sizes
629+ $(sizes[1 ]) / $(sizes[2 ]) , or
630+ $(tuple (prod (sizes[1 ]))) / $(tuple (prod (sizes[2 ]))) . Got
631+ $(size (u)) , $(size (v)) ."""
632+ throw (DimensionMismatch (msg))
633+ end
634+ end
600635
601- # fallback *, \ for FunctionOperator with no OOP method
636+ U = vec_u ? reshape (u, sizes[1 ]) : u
637+ V = vec_v ? reshape (v, sizes[2 ]) : v
638+ vec_output = vec_u | vec_v
602639
603- function Base.: * (L :: FunctionOperator{true,false} , u :: AbstractArray )
604- _, co = L . cache
605- v = zero (co)
640+ return U, V, vec_output
641+ end
642+ end
606643
607- _sizecheck (L, u, v)
644+ # operator application
645+ function Base.:* (L:: FunctionOperator{iip,true} , u:: AbstractArray ) where {iip}
646+ _sizecheck (L, u, nothing )
647+ U, _, vec_output = _unvec (L, u, nothing )
608648
609- L. op (v, u, L. p, L. t; L. traits. kwargs... )
649+ V = L. op (U, L. p, L. t; L. traits. kwargs... )
650+
651+ vec_output ? vec (V) : V
610652end
611653
612- function Base.:\ (L:: FunctionOperator{true,false } , u :: AbstractArray )
613- ci, _ = L . cache
614- v = zero (ci )
654+ function Base.:\ (L:: FunctionOperator{iip,true } , v :: AbstractArray ) where {iip}
655+ _sizecheck (L, nothing , v)
656+ _, V, vec_output = _unvec (L, nothing , v )
615657
616- _sizecheck (L, v, u )
658+ U = L . op_inverse (V, L . p, L . t; L . traits . kwargs ... )
617659
618- L . op_inverse (v, u, L . p, L . t; L . traits . kwargs ... )
660+ vec_output ? vec (U) : U
619661end
620662
621663function LinearAlgebra. mul! (v:: AbstractArray , L:: FunctionOperator{true} , u:: AbstractArray )
622-
623664 _sizecheck (L, u, v)
665+ U, V, vec_output = _unvec (L, u, v)
666+
667+ L. op (V, U, L. p, L. t; L. traits. kwargs... )
624668
625- L . op (v, u, L . p, L . t; L . traits . kwargs ... )
669+ vec_output ? vec (V) : V
626670end
627671
628- function LinearAlgebra. mul! (v :: AbstractArray , L:: FunctionOperator{false} , u :: AbstractArray , args... )
629- @error " LinearAlgebra.mul! not defined for out-of-place FunctionOperators "
672+ function LinearAlgebra. mul! (:: AbstractArray , L:: FunctionOperator{false} , :: AbstractArray , args... )
673+ @error " LinearAlgebra.mul! not defined for out-of-place operator $L "
630674end
631675
632676function LinearAlgebra. mul! (v:: AbstractArray , L:: FunctionOperator{true, oop, false} , u:: AbstractArray , α, β) where {oop}
633- _, co = L. cache
677+ _, Co = L. cache
634678
635679 _sizecheck (L, u, v)
680+ U, V, _ = _unvec (L, u, v)
636681
637- copy! (co, v)
638- mul! (v, L, u)
639- axpby! (β, co, α, v)
682+ copy! (Co, V)
683+ L. op (V, U, L. p, L. t; L. traits. kwargs... ) # mul!(V, L, U)
684+ axpby! (β, Co, α, V)
685+
686+ v
640687end
641688
642689function LinearAlgebra. mul! (v:: AbstractArray , L:: FunctionOperator{true, oop, true} , u:: AbstractArray , α, β) where {oop}
643-
644690 _sizecheck (L, u, v)
691+ U, V, _ = _unvec (L, u, v)
645692
646- L. op (v, u, L. p, L. t, α, β; L. traits. kwargs... )
693+ L. op (V, U, L. p, L. t, α, β; L. traits. kwargs... )
694+
695+ v
647696end
648697
649- function LinearAlgebra. ldiv! (v:: AbstractArray , L:: FunctionOperator{true} , u:: AbstractArray )
698+ function LinearAlgebra. ldiv! (u:: AbstractArray , L:: FunctionOperator{true} , v:: AbstractArray )
699+ _sizecheck (L, u, v)
700+ U, V, _ = _unvec (L, u, v)
650701
651- _sizecheck (L, v, u )
702+ L . op_inverse (U, V, L . p, L . t; L . traits . kwargs ... )
652703
653- L . op_inverse (v, u, L . p, L . t; L . traits . kwargs ... )
704+ u
654705end
655706
656707function LinearAlgebra. ldiv! (L:: FunctionOperator{true} , u:: AbstractArray )
657- ci, _ = L. cache
708+ V, _ = L. cache
709+
710+ _sizecheck (L, u, V)
711+ U, _, _ = _unvec (L, u, nothing )
712+
713+ copy! (V, U)
714+ L. op_inverse (U, V, L. p, L. t; L. traits. kwargs... ) # ldiv!(U, L, V)
658715
659- copy! (ci, u)
660- ldiv! (u, L, ci)
716+ u
661717end
662718
663719function LinearAlgebra. ldiv! (v:: AbstractArray , L:: FunctionOperator{false} , u:: AbstractArray )
664- @error " LinearAlgebra.ldiv! not defined for out-of-place FunctionOperators "
720+ @error " LinearAlgebra.ldiv! not defined for out-of-place $L "
665721end
666722
667723function LinearAlgebra. ldiv! (L:: FunctionOperator{false} , u:: AbstractArray )
668- @error " LinearAlgebra.ldiv! not defined for out-of-place FunctionOperators "
724+ @error " LinearAlgebra.ldiv! not defined for out-of-place $L "
669725end
670726#
0 commit comments