@@ -173,7 +173,7 @@ function FunctionOperator(op,
173173 msg = """ `FunctionOperator` constructed with `batch = true` only
174174 accepts `AbstractVecOrMat` types with
175175 `size(L, 2) == size(u, 1)`."""
176- ArgumentError (msg) |> throw
176+ throw ( ArgumentError (msg))
177177 end
178178
179179 if input isa AbstractMatrix
@@ -184,7 +184,7 @@ function FunctionOperator(op,
184184 array, $(typeof (input)) , has size $(size (input)) , whereas
185185 output array, $(typeof (output)) , has size
186186 $(size (output)) ."""
187- ArgumentError (msg) |> throw
187+ throw ( ArgumentError (msg))
188188 end
189189 end
190190 end
@@ -340,14 +340,14 @@ function _cache_operator(L::FunctionOperator, u::AbstractArray)
340340 if ! isa (u, AbstractVecOrMat)
341341 msg = """ $L constructed with `batch = true` only accepts
342342 `AbstractVecOrMat` types with `size(L, 2) == size(u, 1)`."""
343- ArgumentError (msg) |> throw
343+ throw ( ArgumentError (msg))
344344 end
345345
346346 if size (L, 2 ) != size (u, 1 )
347347 msg = """ Second dimension of $L of size $(size (L))
348348 is not consistent with first dimension of input array `u`
349349 of size $(size (u)) ."""
350- DimensionMismatch (msg) |> throw
350+ throw ( DimensionMismatch (msg))
351351 end
352352
353353 M = size (L, 1 )
@@ -486,7 +486,7 @@ function Base.resize!(L::FunctionOperator, n::Integer)
486486 if length (L. traits. sizes[1 ]) != 1
487487 msg = """ `Base.resize!` is only supported by $L whose input/output
488488 arrays are `AbstractVector`s."""
489- MethodError (msg) |> throw
489+ throw ( MethodError (msg))
490490 end
491491
492492 for op in getops (L)
@@ -534,131 +534,187 @@ has_ldiv(L::FunctionOperator{iip}) where{iip} = !(L.op_inverse isa Nothing)
534534has_ldiv! (L:: FunctionOperator{iip} ) where {iip} = iip & ! (L. op_inverse isa Nothing)
535535
536536function _sizecheck (L:: FunctionOperator , u, v)
537-
537+ sizes = L . traits . sizes
538538 if L. traits. batch
539539 if ! isnothing (u)
540540 if ! isa (u, AbstractVecOrMat)
541541 msg = """ $L constructed with `batch = true` only
542542 accept input arrays that are `AbstractVecOrMat`s with
543- `size(L, 2) == size(u, 1)`."""
544- ArgumentError (msg) |> throw
543+ `size(L, 2) == size(u, 1)`. Recieved $( typeof (u)) . """
544+ throw ( ArgumentError (msg))
545545 end
546546
547- if size (u ) != L . traits . sizes[ 1 ]
548- msg = """ $L expects input arrays of size $(L . traits . sizes[ 1 ]) .
549- Recievd array of size $(size (u)) ."""
550- DimensionMismatch (msg) |> throw
547+ if size (L, 2 ) != size (u, 1 )
548+ msg = """ $L accepts input `AbstractVecOrMat`s of size
549+ ( $( size (L, 2 )) , K). Recievd array of size $(size (u)) ."""
550+ throw ( DimensionMismatch (msg))
551551 end
552- end
552+ end # u
553553
554554 if ! isnothing (v)
555- if size (v) != L. traits. sizes[2 ]
556- msg = """ $L expects input arrays of size $(L. traits. sizes[1 ]) .
557- Recievd array of size $(size (v)) ."""
558- DimensionMismatch (msg) |> throw
555+ if ! isa (v, AbstractVecOrMat)
556+ msg = """ $L constructed with `batch = true` only
557+ returns output arrays that are `AbstractVecOrMat`s with
558+ `size(L, 1) == size(v, 1)`. Recieved $(typeof (v)) ."""
559+ throw (ArgumentError (msg))
559560 end
560- end
561+
562+ if size (L, 1 ) != size (v, 1 )
563+ msg = """ $L accepts output `AbstractVecOrMat`s of size
564+ ($(size (L, 1 )) , K). Recievd array of size $(size (v)) ."""
565+ throw (DimensionMismatch (msg))
566+ end
567+ end # v
568+
569+ if ! isnothing (u) & ! isnothing (v)
570+ if size (u, 2 ) != size (v, 2 )
571+ msg = """ input array $u , and output array, $v , must have the
572+ same batch size (i.e. length of second dimension). Got
573+ $(size (u)) , $(size (v)) . If you encounter this error during
574+ an in-place evaluation (`LinearAlgebra.mul!`, `ldiv!`),
575+ ensure that the operator $L has been cached with an input
576+ array of the correct size. Do so by calling
577+ `L = cache_operator(L, u)`."""
578+ throw (DimensionMismatch (msg))
579+ end
580+ end # u, v
581+
561582 else # !batch
583+
562584 if ! isnothing (u)
563- if size (u) != L. traits. sizes[1 ]
564- msg = """ $L expects input arrays of size $(L. traits. sizes[1 ]) .
565- Recievd array of size $(size (u)) ."""
566- DimensionMismatch (msg) |> throw
585+ if size (u) ∉ (sizes[1 ], tuple (size (L, 2 )),)
586+ msg = """ $L recievd input array of size $(size (u)) , but only
587+ accepts input arrays of size $(sizes[1 ]) , or vectors like
588+ `vec(u)` of size $(tuple (prod (sizes[1 ]))) ."""
589+ throw (DimensionMismatch (msg))
567590 end
568- end
591+ end # u
569592
570593 if ! isnothing (v)
571- if size (v) != L. traits. sizes[2 ]
572- msg = """ $L expects input arrays of size $(L. traits. sizes[1 ]) .
573- Recievd array of size $(size (v)) ."""
574- DimensionMismatch (msg) |> throw
594+ if size (v) ∉ (sizes[2 ], tuple (size (L, 1 )),)
595+ msg = """ $L recievd output array of size $(size (v)) , but only
596+ accepts output arrays of size $(sizes[2 ]) , or vectors like
597+ `vec(u)` of size $(tuple (prod (sizes[2 ]))) """
598+ throw (DimensionMismatch (msg))
575599 end
576- end
600+ end # v
577601 end # batch
578602
579603 return
580604end
581605
582- # operator application
583- function Base.:* (L:: FunctionOperator{iip,true} , u:: AbstractArray ) where {iip}
584- _sizecheck (L, u, nothing )
606+ function _unvec (L:: FunctionOperator , u, v)
607+ if L. traits. batch
608+ return u, v, false
609+ else
610+ sizes = L. traits. sizes
585611
586- L. op (u, L. p, L. t; L. traits. kwargs... )
587- end
612+ # no need to vec since expected input/output are AbstractVectors
613+ if length (sizes[1 ]) == 1
614+ return u, v, false
615+ end
588616
589- function Base.: \ (L :: FunctionOperator{iip,true} , u :: AbstractArray ) where {iip}
590- _sizecheck (L, nothing , u)
617+ vec_u = isnothing (u) ? false : size (u) != sizes[ 1 ]
618+ vec_v = isnothing (v) ? false : size (v) != sizes[ 2 ]
591619
592- L. op_inverse (u, L. p, L. t; L. traits. kwargs... )
593- end
620+ if ! isnothing (u) & ! isnothing (v)
621+ if (vec_u & ! vec_v) | (! vec_u & vec_v)
622+ msg = """ Input / output to $L can either be of sizes
623+ $(sizes[1 ]) / $(sizes[2 ]) , or
624+ $(tuple (prod (sizes[1 ]))) / $(tuple (prod (sizes[2 ]))) . Got
625+ $(size (u)) , $(size (v)) ."""
626+ throw (DimensionMismatch (msg))
627+ end
628+ end
594629
595- # fallback *, \ for FunctionOperator with no OOP method
630+ U = vec_u ? reshape (u, sizes[1 ]) : u
631+ V = vec_v ? reshape (v, sizes[2 ]) : v
632+ vec_output = vec_u | vec_v
596633
597- function Base.: * (L :: FunctionOperator{true,false} , u :: AbstractArray )
598- _, co = L . cache
599- v = zero (co)
634+ return U, V, vec_output
635+ end
636+ end
600637
601- _sizecheck (L, u, v)
638+ # operator application
639+ function Base.:* (L:: FunctionOperator{iip,true} , u:: AbstractArray ) where {iip}
640+ _sizecheck (L, u, nothing )
641+ U, _, vec_output = _unvec (L, u, nothing )
602642
603- L. op (v, u, L. p, L. t; L. traits. kwargs... )
643+ V = L. op (U, L. p, L. t; L. traits. kwargs... )
644+
645+ vec_output ? vec (V) : V
604646end
605647
606- function Base.:\ (L:: FunctionOperator{true,false } , u :: AbstractArray )
607- ci, _ = L . cache
608- v = zero (ci )
648+ function Base.:\ (L:: FunctionOperator{iip,true } , v :: AbstractArray ) where {iip}
649+ _sizecheck (L, nothing , v)
650+ _, V, vec_output = _unvec (L, nothing , v )
609651
610- _sizecheck (L, v, u )
652+ U = L . op_inverse (V, L . p, L . t; L . traits . kwargs ... )
611653
612- L . op_inverse (v, u, L . p, L . t; L . traits . kwargs ... )
654+ vec_output ? vec (U) : U
613655end
614656
615657function LinearAlgebra. mul! (v:: AbstractArray , L:: FunctionOperator{true} , u:: AbstractArray )
616-
617658 _sizecheck (L, u, v)
659+ U, V, vec_output = _unvec (L, u, v)
660+
661+ L. op (V, U, L. p, L. t; L. traits. kwargs... )
618662
619- L . op (v, u, L . p, L . t; L . traits . kwargs ... )
663+ vec_output ? vec (V) : V
620664end
621665
622- function LinearAlgebra. mul! (v :: AbstractArray , L:: FunctionOperator{false} , u :: AbstractArray , args... )
623- @error " LinearAlgebra.mul! not defined for out-of-place FunctionOperators "
666+ function LinearAlgebra. mul! (:: AbstractArray , L:: FunctionOperator{false} , :: AbstractArray , args... )
667+ @error " LinearAlgebra.mul! not defined for out-of-place operator $L "
624668end
625669
626670function LinearAlgebra. mul! (v:: AbstractArray , L:: FunctionOperator{true, oop, false} , u:: AbstractArray , α, β) where {oop}
627- _, co = L. cache
671+ _, Co = L. cache
628672
629673 _sizecheck (L, u, v)
674+ U, V, _ = _unvec (L, u, v)
630675
631- copy! (co, v)
632- mul! (v, L, u)
633- axpby! (β, co, α, v)
676+ copy! (Co, V)
677+ L. op (V, U, L. p, L. t; L. traits. kwargs... ) # mul!(V, L, U)
678+ axpby! (β, Co, α, V)
679+
680+ v
634681end
635682
636683function LinearAlgebra. mul! (v:: AbstractArray , L:: FunctionOperator{true, oop, true} , u:: AbstractArray , α, β) where {oop}
637-
638684 _sizecheck (L, u, v)
685+ U, V, _ = _unvec (L, u, v)
639686
640- L. op (v, u, L. p, L. t, α, β; L. traits. kwargs... )
687+ L. op (V, U, L. p, L. t, α, β; L. traits. kwargs... )
688+
689+ v
641690end
642691
643- function LinearAlgebra. ldiv! (v:: AbstractArray , L:: FunctionOperator{true} , u:: AbstractArray )
692+ function LinearAlgebra. ldiv! (u:: AbstractArray , L:: FunctionOperator{true} , v:: AbstractArray )
693+ _sizecheck (L, u, v)
694+ U, V, _ = _unvec (L, u, v)
644695
645- _sizecheck (L, v, u )
696+ L . op_inverse (U, V, L . p, L . t; L . traits . kwargs ... )
646697
647- L . op_inverse (v, u, L . p, L . t; L . traits . kwargs ... )
698+ u
648699end
649700
650701function LinearAlgebra. ldiv! (L:: FunctionOperator{true} , u:: AbstractArray )
651- ci, _ = L. cache
702+ V, _ = L. cache
703+
704+ _sizecheck (L, u, V)
705+ U, _, _ = _unvec (L, u, nothing )
706+
707+ copy! (V, U)
708+ L. op_inverse (U, V, L. p, L. t; L. traits. kwargs... ) # ldiv!(U, L, V)
652709
653- copy! (ci, u)
654- ldiv! (u, L, ci)
710+ u
655711end
656712
657713function LinearAlgebra. ldiv! (v:: AbstractArray , L:: FunctionOperator{false} , u:: AbstractArray )
658- @error " LinearAlgebra.ldiv! not defined for out-of-place FunctionOperators "
714+ @error " LinearAlgebra.ldiv! not defined for out-of-place $L "
659715end
660716
661717function LinearAlgebra. ldiv! (L:: FunctionOperator{false} , u:: AbstractArray )
662- @error " LinearAlgebra.ldiv! not defined for out-of-place FunctionOperators "
718+ @error " LinearAlgebra.ldiv! not defined for out-of-place $L "
663719end
664720#
0 commit comments