@@ -372,16 +372,27 @@ function LinearAlgebra._swap_rows!(B::ArrayPartition, i::Integer, j::Integer)
372372end
373373
374374# linalg mul! overloads for ArrayPartition
375- function LinearAlgebra. mul! (C:: T , A:: T , B:: AbstractArray ) where T<: ArrayPartition
376- @assert length (C. x) == length (A. x)
375+ function LinearAlgebra. mul! (C:: ArrayPartition , A:: ArrayPartition , B:: AbstractArray )
376+ if length (C. x) != length (A. x)
377+ throw (DimensionMismatch (" Length of C, $(length (C. x)) , does not match length of A, $(length (A. x)) " ))
378+ end
379+
377380 for index = 1 : length (C. x)
378381 mul! (C. x[index], A. x[index], B)
379382 end
383+ return C
380384end
381385
382- function LinearAlgebra. mul! (C:: T , A:: T , B:: T ) where T<: ArrayPartition
383- @assert length (C. x) == length (A. x) == length (B. x)
386+ function LinearAlgebra. mul! (C:: ArrayPartition , A:: ArrayPartition , B:: ArrayPartition )
387+ if length (C. x) != length (A. x)
388+ throw (DimensionMismatch (" Length of C, $(length (C. x)) , does not match length of A, $(length (B. x)) " ))
389+ end
390+ if length (A. x) != length (B. x)
391+ throw (DimensionMismatch (" Length of A, $(length (A. x)) , does not match length of B, $(length (B. x)) " ))
392+ end
393+
384394 for index = 1 : length (C. x)
385395 mul! (C. x[index], A. x[index], B. x[index])
386396 end
397+ return C
387398end
0 commit comments