Skip to content

Commit 59c030a

Browse files
Merge pull request #113 from Lewih/ArrayPartition-New-LA-overloads
Added linalg mul! overloads for ArrayPartition with tests
2 parents c24e54f + 35248a6 commit 59c030a

File tree

3 files changed

+50
-3
lines changed

3 files changed

+50
-3
lines changed

src/array_partition.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,3 +370,29 @@ function LinearAlgebra._swap_rows!(B::ArrayPartition, i::Integer, j::Integer)
370370
B[i], B[j] = B[j], B[i]
371371
return B
372372
end
373+
374+
# linalg mul! overloads for ArrayPartition
375+
function LinearAlgebra.mul!(C::ArrayPartition, A::ArrayPartition, B::AbstractArray)
376+
if length(C.x) != length(A.x)
377+
throw(DimensionMismatch("Length of C, $(length(C.x)), does not match length of A, $(length(A.x))"))
378+
end
379+
380+
for index = 1:length(C.x)
381+
mul!(C.x[index], A.x[index], B)
382+
end
383+
return C
384+
end
385+
386+
function LinearAlgebra.mul!(C::ArrayPartition, A::ArrayPartition, B::ArrayPartition)
387+
if length(C.x) != length(A.x)
388+
throw(DimensionMismatch("Length of C, $(length(C.x)), does not match length of A, $(length(B.x))"))
389+
end
390+
if length(A.x) != length(B.x)
391+
throw(DimensionMismatch("Length of A, $(length(A.x)), does not match length of B, $(length(B.x))"))
392+
end
393+
394+
for index = 1:length(C.x)
395+
mul!(C.x[index], A.x[index], B.x[index])
396+
end
397+
return C
398+
end

test/linalg.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ b = ArrayPartition(bb)
88
A = randn(MersenneTwister(123), n+m, n+m)
99

1010
for T in (UpperTriangular, UnitUpperTriangular, LowerTriangular, UnitLowerTriangular)
11-
B = T(A)
11+
local B = T(A)
1212
@test B*Array(B \ b) b
1313
bbb = copy(b)
1414
@test ldiv!(bbb, B, b) === bbb
@@ -26,3 +26,24 @@ for ff in (lu, svd, qr)
2626
@test ldiv!(FF, bbb) === bbb
2727
@test A*bbb b
2828
end
29+
30+
# linalg mul! overloads
31+
n, m, l = 5, 6, 7
32+
bb = rand(n, n), rand(m, n), rand(l, n)
33+
cc = rand(n), rand(n), rand(n)
34+
dd = rand(n), rand(m), rand(l)
35+
b = ArrayPartition(bb)
36+
c = ArrayPartition(cc)
37+
d = ArrayPartition(dd)
38+
A = rand(n)
39+
for T in (Array{Float64}, Array{ComplexF64},)
40+
local B = T(A)
41+
mul!(d, b, A)
42+
for i = 1:length(c.x)
43+
@test d.x[i] == b.x[i] * A
44+
end
45+
mul!(d, b, c)
46+
for i = 1:length(d.x)
47+
@test d.x[i] == b.x[i] * c.x[i]
48+
end
49+
end

test/partitions_test.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ S = [
111111
]
112112

113113
for sizes in S
114-
x = ArrayPartition( randn.(sizes[1]) )
115-
y = ArrayPartition( zeros.(sizes[2]) )
114+
local x = ArrayPartition( randn.(sizes[1]) )
115+
local y = ArrayPartition( zeros.(sizes[2]) )
116116
y_array = zeros(length(x))
117117
copyto!(y,x) #testing Base.copyto!(dest::ArrayPartition,A::ArrayPartition)
118118
copyto!(y_array,x) #testing Base.copyto!(dest::Array,A::ArrayPartition)

0 commit comments

Comments
 (0)