Skip to content

Commit 43241ea

Browse files
authored
Accomodate for rectangular matrices in copytrito! (#538)
1 parent b97643c commit 43241ea

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

src/host/linalg.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,12 @@ if isdefined(LinearAlgebra, :copytrito!)
111111
LinearAlgebra.BLAS.chkuplo(uplo)
112112
m,n = size(A)
113113
m1,n1 = size(B)
114-
(m1 < m || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least the same number of rows and columns than A of size ($m,$n)"))
115114
if uplo == 'U'
115+
if n < m
116+
(m1 < n || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($n,$n)"))
117+
else
118+
(m1 < m || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($m,$n)"))
119+
end
116120
@kernel function U_kernel!(_A, _B)
117121
I = @index(Global, Cartesian)
118122
i, j = Tuple(I)
@@ -122,6 +126,11 @@ if isdefined(LinearAlgebra, :copytrito!)
122126
end
123127
U_kernel!(get_backend(B))(A, B; ndrange = size(A))
124128
else # uplo == 'L'
129+
if m < n
130+
(m1 < m || n1 < m) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($m,$m)"))
131+
else
132+
(m1 < m || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($m,$n)"))
133+
end
125134
@kernel function L_kernel!(_A, _B)
126135
I = @index(Global, Cartesian)
127136
i, j = Tuple(I)

test/testsuite/linalg.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,13 @@
8585
B = zeros(T,n,n)
8686
@test compare(copytrito!, AT, B, A, uplo)
8787
end
88+
@testset for T in eltypes, uplo in ('L', 'U')
89+
n = 16
90+
m = 32
91+
A = uplo == 'U' ? rand(T,m,n) : rand(T,n,m)
92+
B = zeros(T,n,n)
93+
@test compare(copytrito!, AT, B, A, uplo)
94+
end
8895
end
8996
end
9097

0 commit comments

Comments
 (0)