Skip to content

Commit 6c17cce

Browse files
Rabab53jpsamaroo
andcommitted
DArray: Add proper transpose/adjoint copy/collect
Co-authored-by: Julian P Samaroo <[email protected]>
1 parent 961768c commit 6c17cce

File tree

2 files changed

+29
-61
lines changed

2 files changed

+29
-61
lines changed

src/array/matrix.jl

Lines changed: 16 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,22 @@
1-
struct Transpose{T,N} <: ArrayOp{T,N}
2-
f::Function
3-
input::ArrayOp
4-
end
5-
6-
function Transpose(f,x::ArrayOp)
7-
@assert 1 <= ndims(x) && ndims(x) <= 2
8-
Transpose{eltype(x), 2}(f,x)
9-
end
10-
function size(x::Transpose)
11-
sz = size(x.input)
12-
if length(sz) == 1
13-
(1, sz[1])
14-
else
15-
(sz[2], sz[1])
1+
# Transpose/Adjoint
2+
3+
function copydiag(f, A::DArray{T, 2}) where T
4+
Ac = A.chunks
5+
Ac_copy = Matrix{Any}(undef, size(Ac, 2), size(Ac, 1))
6+
_copytile(f, Ac) = copy(f(Ac))
7+
for i in 1:size(Ac, 1), j in 1:size(Ac, 2)
8+
Ac_copy[j, i] = Dagger.@spawn _copytile(f, Ac[i, j])
169
end
10+
return DArray(T, ArrayDomain(1:size(A,2), 1:size(A,1)), domainchunks(A)', Ac_copy, A.partitioning)
1711
end
12+
Base.fetch(A::Adjoint{T, <:DArray{T, 2}}) where T = copydiag(Adjoint, parent(A))
13+
Base.fetch(A::Transpose{T, <:DArray{T, 2}}) where T = copydiag(Transpose, parent(A))
14+
Base.copy(A::Adjoint{T, <:DArray{T, 2}}) where T = fetch(A)
15+
Base.copy(A::Transpose{T, <:DArray{T, 2}}) where T = fetch(A)
16+
Base.collect(A::Adjoint{T, <:DArray{T, 2}}) where T = collect(copy(A))
17+
Base.collect(A::Transpose{T, <:DArray{T, 2}}) where T = collect(copy(A))
1818

19-
transpose(x::ArrayOp) = _to_darray(Transpose(transpose, x))
20-
transpose(x::Union{Chunk, EagerThunk}) = @spawn transpose(x)
21-
22-
adjoint(x::ArrayOp) = _to_darray(Transpose(adjoint, x))
23-
adjoint(x::Union{Chunk, EagerThunk}) = @spawn adjoint(x)
24-
25-
function adjoint(x::ArrayDomain{2})
26-
d = indexes(x)
27-
ArrayDomain(d[2], d[1])
28-
end
29-
function adjoint(x::ArrayDomain{1})
30-
d = indexes(x)
31-
ArrayDomain(1:1, d[1])
32-
end
33-
34-
function adjoint(x::Blocks{2})
35-
d = x.blocksize
36-
Blocks(d[2], d[1])
37-
end
38-
function adjoint(x::Blocks{1})
39-
d = x.blocksize
40-
Blocks(1, d[1])
41-
end
42-
43-
function _ctranspose(x::AbstractArray)
44-
Any[Dagger.@spawn(adjoint(x[j,i])) for i=1:size(x,2), j=1:size(x,1)]
45-
end
46-
47-
function stage(ctx::Context, node::Transpose)
48-
inp = stage(ctx, node.input)
49-
thunks = _ctranspose(chunks(inp))
50-
return DArray(eltype(inp), domain(inp)', domainchunks(inp)', thunks, inp.partitioning', inp.concat)
51-
end
19+
# Matrix-(Matrix/Vector) multiply
5220

5321
import Base: *, +
5422

test/array.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using LinearAlgebra, SparseArrays, Random, SharedArrays
2-
import Dagger: chunks, DArray, domainchunks, treereduce_nd
2+
import Dagger: DArray, chunks, domainchunks, treereduce_nd
33
import Distributed: myid, procs
44
import Statistics: mean, var, std
55
import OnlineStats
@@ -138,7 +138,7 @@ end
138138

139139
@testset "distributing an array" begin
140140
function test_dist(X)
141-
X1 = Distribute(Blocks(10, 20), X)
141+
X1 = distribute(X, Blocks(10, 20))
142142
Xc = fetch(X1)
143143
@test Xc isa DArray{eltype(X),ndims(X)}
144144
@test Xc == X
@@ -147,7 +147,7 @@ end
147147
@test map(x->size(x) == (10, 20), domainchunks(Xc)) |> all
148148
end
149149
x = [1 2; 3 4]
150-
@test Distribute(Blocks(1,1), x) == x
150+
@test distribute(x, Blocks(1,1)) == x
151151
test_dist(rand(100, 100))
152152
test_dist(sprand(100, 100, 0.1))
153153

@@ -158,7 +158,7 @@ end
158158
@testset "transpose" begin
159159
function test_transpose(X)
160160
x, y = size(X)
161-
X1 = Distribute(Blocks(10, 20), X)
161+
X1 = distribute(X, Blocks(10, 20))
162162
@test X1' == X'
163163
Xc = fetch(X1')
164164
@test chunks(Xc) |> size == (div(y, 20), div(x,10))
@@ -174,7 +174,7 @@ end
174174
@testset "matrix-matrix multiply" begin
175175
function test_mul(X)
176176
tol = 1e-12
177-
X1 = Distribute(Blocks(10, 20), X)
177+
X1 = distribute(X, Blocks(10, 20))
178178
@test_throws DimensionMismatch X1*X1
179179
X2 = X1'*X1
180180
X3 = X1*X1'
@@ -188,7 +188,7 @@ end
188188
test_mul(rand(40, 40))
189189

190190
x = rand(10,10)
191-
X = Distribute(Blocks(3,3), x)
191+
X = distribute(x, Blocks(3,3))
192192
y = rand(10)
193193
@test norm(collect(X*y) - x*y) < 1e-13
194194
end
@@ -202,24 +202,24 @@ end
202202

203203
@testset "concat" begin
204204
m = rand(75,75)
205-
x = Distribute(Blocks(10,20), m)
206-
y = Distribute(Blocks(10,10), m)
205+
x = distribute(m, Blocks(10,20))
206+
y = distribute(m, Blocks(10,10))
207207
@test hcat(m,m) == collect(hcat(x,x)) == collect(hcat(x,y))
208208
@test vcat(m,m) == collect(vcat(x,x))
209209
@test_throws DimensionMismatch vcat(x,y)
210210
end
211211

212212
@testset "scale" begin
213213
x = rand(10,10)
214-
X = Distribute(Blocks(3,3), x)
214+
X = distribute(x, Blocks(3,3))
215215
y = rand(10)
216216

217217
@test Diagonal(y)*x == collect(Diagonal(y)*X)
218218
end
219219

220220
@testset "Getindex" begin
221221
function test_getindex(x)
222-
X = Distribute(Blocks(3,3), x)
222+
X = distribute(x, Blocks(3,3))
223223
@test collect(X[3:8, 2:7]) == x[3:8, 2:7]
224224
ragged_idx = [1,2,9,7,6,2,4,5]
225225
@test collect(X[ragged_idx, 2:7]) == x[ragged_idx, 2:7]
@@ -248,7 +248,7 @@ end
248248

249249

250250
@testset "cleanup" begin
251-
X = Distribute(Blocks(10,10), rand(10,10))
251+
X = distribute(rand(10,10), Blocks(10,10))
252252
@test collect(sin.(X)) == collect(sin.(X))
253253
end
254254

@@ -269,7 +269,7 @@ end
269269
x=rand(10,10)
270270
y=copy(x)
271271
y[3:8, 2:7] .= 1.0
272-
X = Distribute(Blocks(3,3), x)
272+
X = distribute(x, Blocks(3,3))
273273
@test collect(setindex(X,1.0, 3:8, 2:7)) == y
274274
@test collect(X) == x
275275
end
@@ -292,7 +292,7 @@ end
292292
@test collect(sort(y)) == x
293293

294294
x = ones(10)
295-
y = Distribute(Blocks(3), x)
295+
y = distribute(x, Blocks(3))
296296
@test_broken map(x->length(collect(x)), sort(y).chunks) == [3,3,3,1]
297297
end
298298

0 commit comments

Comments
 (0)