Skip to content

Commit 50bdf1e

Browse files
committed
Allow arbitrary distribution of chunks in mapslices
1 parent 586e691 commit 50bdf1e

File tree

4 files changed

+22
-31
lines changed

4 files changed

+22
-31
lines changed

src/mapreduce.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -237,15 +237,15 @@ for f in (:-, :abs, :abs2, :acos, :acosd, :acosh, :acot, :acotd, :acoth,
237237
end
238238
end
239239

240-
function mapslices{T,N}(f::Function, D::DArray{T,N}, dims::AbstractVector)
241-
#Ensure that the complete DArray is available on the specified dims on all processors
242-
for d in dims
243-
for idxs in D.indexes
244-
if length(idxs[d]) != size(D, d)
245-
throw(DimensionMismatch(string("dimension $d is distributed. ",
246-
"mapslices requires dimension $d to be completely available on all processors.")))
247-
end
240+
function mapslices{T,N,A}(f::Function, D::DArray{T,N,A}, dims::AbstractVector)
241+
if !all(t -> t == 1, size(D.indexes)[dims])
242+
p = ones(Int, ndims(D))
243+
nondims = filter(t -> !(t in dims), 1:ndims(D))
244+
p[nondims] = defaultdist([size(D)...][[nondims...]], procs(D))
245+
DD = DArray(size(D), procs(D), p) do I
246+
return convert(A, D[I...])
248247
end
248+
return mapslices(f, DD, dims)
249249
end
250250

251251
refs = Future[remotecall((x,y,z)->mapslices(x,localpart(y),z), p, f, D, dims) for p in procs(D)]

test/REQUIRE

Lines changed: 0 additions & 1 deletion
This file was deleted.

test/darray.jl

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -655,30 +655,24 @@ end
655655

656656
check_leaks()
657657

658-
# The mapslices tests have been taken from Base.
659-
# Commented out tests that need to be enabled in due course when DArray support is more complete
660658
@testset "test mapslices" begin
661-
a = drand((5,5), workers(), [1, min(nworkers(), 5)])
662-
h = mapslices(v -> fit(Histogram,v,0:0.1:1).weights, a, 1)
663-
# H = mapslices(v -> hist(v,0:0.1:1)[2], a, 2)
664-
# s = mapslices(sort, a, [1])
665-
# S = mapslices(sort, a, [2])
666-
for i = 1:5
667-
@test h[:,i] == fit(Histogram, a[:,i],0:0.1:1).weights
668-
# @test vec(H[i,:]) => hist(vec(a[i,:]),0:0.1:1)[2]
669-
# @test s[:,i] => sort(a[:,i])
670-
# @test vec(S[i,:]) => sort(vec(a[i,:]))
671-
end
659+
A = randn(5,5,5)
660+
D = distribute(A, procs = workers(), dist = [1, 1, min(nworkers(), 5)])
661+
@test mapslices(svdvals, D, (1,2)) mapslices(svdvals, A, (1,2))
662+
@test mapslices(svdvals, D, (1,3)) mapslices(svdvals, A, (1,3))
663+
@test mapslices(svdvals, D, (2,3)) mapslices(svdvals, A, (2,3))
664+
@test mapslices(sort, D, (1,)) mapslices(sort, A, (1,))
665+
@test mapslices(sort, D, (2,)) mapslices(sort, A, (2,))
666+
@test mapslices(sort, D, (3,)) mapslices(sort, A, (3,))
672667

673668
# issue #3613
674-
b = mapslices(sum, dones(Float64, (2,3,4), workers(), [1,1,min(nworkers(),4)]), [1,2])
675-
@test size(b) == (1,1,4)
676-
@test all(b.==6)
669+
B = mapslices(sum, dones(Float64, (2,3,4), workers(), [1,1,min(nworkers(),4)]), [1,2])
670+
@test size(B) == (1,1,4)
671+
@test all(B.==6)
677672

678673
# issue #5141
679-
## Update Removed the version that removes the dimensions when dims==1:ndims(A)
680-
c1 = mapslices(x-> maximum(-x), a, [])
681-
# @test c1 => -a
674+
C1 = mapslices(x-> maximum(-x), D, [])
675+
@test C1 == -D
682676

683677
# issue #5177
684678
c = dones(Float64, (2,3,4,5), workers(), [1,1,1,min(nworkers(),5)])
@@ -695,7 +689,7 @@ check_leaks()
695689
n3a = mapslices(x-> ones(1,6), c, [2,3])
696690
@test (size(n1a) == (1,6,4,5) && size(n2a) == (1,3,6,5) && size(n3a) == (2,1,6,5))
697691
@test (size(n1) == (6,1,4,5) && size(n2) == (6,3,1,5) && size(n3) == (2,6,1,5))
698-
close(a)
692+
close(D)
699693
close(c)
700694
darray_closeall() # close the temporaries created above
701695
end

test/runtests.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ end
99
@assert nworkers() >= 3
1010

1111
using DistributedArrays
12-
using StatsBase # for fit(Histogram, ...)
13-
@everywhere using StatsBase # because exported functions are not exported on workers with using
1412

1513
@everywhere srand(1234 + myid())
1614

0 commit comments

Comments
 (0)