Skip to content

Commit d68911b

Browse files
authored
Merge pull request #101 from JuliaParallel/amitm/genpartser
reworked PartionedSerializer (now DestinationSerializer)
2 parents 3097509 + fc0c137 commit d68911b

File tree

3 files changed

+51
-42
lines changed

3 files changed

+51
-42
lines changed

src/core.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -443,10 +443,14 @@ Convert a local array to distributed.
443443
function distribute(A::AbstractArray;
444444
procs = workers()[1:min(nworkers(), maximum(size(A)))],
445445
dist = defaultdist(size(A), procs))
446+
np = prod(dist)
447+
procs_used = procs[1:np]
446448
idxs, _ = chunk_idxs([size(A)...], dist)
447449

448-
pas = PartitionedSerializer(A, procs, idxs)
449-
return DArray(I->verify_and_get(pas, I), size(A), procs, dist)
450+
s = verified_destination_serializer(reshape(procs_used, size(idxs)), size(idxs)) do pididx
451+
A[idxs[pididx]...]
452+
end
453+
return DArray(I->localpart(s), size(A), procs_used, dist)
450454
end
451455

452456
"""
@@ -458,8 +462,10 @@ Distribute a local array `A` like the distributed array `DA`.
458462
function distribute(A::AbstractArray, DA::DArray)
459463
size(DA) == size(A) || throw(DimensionMismatch("Distributed array has size $(size(DA)) but array has $(size(A))"))
460464

461-
pas = PartitionedSerializer(A, procs(DA), DA.indexes)
462-
return DArray(I->verify_and_get(pas, I), DA)
465+
s = verified_destination_serializer(procs(DA), size(DA.indexes)) do pididx
466+
A[DA.indexes[pididx]...]
467+
end
468+
return DArray(I->localpart(s), DA)
463469
end
464470

465471
Base.convert{T,N,S<:AbstractArray}(::Type{DArray{T,N,S}}, A::S) = distribute(convert(AbstractArray{T,N}, A))
@@ -589,7 +595,7 @@ indexin_mask(a, b) = [i in b for i in a]
589595
import Base: tail
590596
# Given a tuple of indices and a tuple of masks, restrict the indices to the
591597
# valid regions. This is, effectively, reversing Base.setindex_shape_check.
592-
# We can't just use indexing into MergedIndices here because getindex is much
598+
# We can't just use indexing into MergedIndices here because getindex is much
593599
# pickier about singleton dimensions than setindex! is.
594600
restrict_indices(::Tuple{}, ::Tuple{}) = ()
595601
function restrict_indices(a::Tuple{Any, Vararg{Any}}, b::Tuple{Any, Vararg{Any}})
@@ -639,7 +645,7 @@ end
639645
Base.size(M::MergedIndices) = M.sz
640646
Base.@propagate_inbounds Base.getindex{_,N}(M::MergedIndices{_,N}, I::Vararg{Int, N}) =
641647
CartesianIndex(map(propagate_getindex, M.indices, I))
642-
# Additionally, we optimize bounds checking when using MergedIndices as an
648+
# Additionally, we optimize bounds checking when using MergedIndices as an
643649
# array index since checking, e.g., A[1:500, 1:500] is *way* faster than
644650
# checking an array of 500^2 elements of CartesianIndex{2}. This optimization
645651
# also applies to reshapes of MergedIndices since the outer shape of the

src/mapreduce.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,16 +162,20 @@ map_localparts(f::Callable, d1::DArray, d2::DArray) = DArray(d1) do I
162162
end
163163

164164
function map_localparts(f::Callable, DA::DArray, A::Array)
165-
pas = PartitionedSerializer(A, procs(DA), DA.indexes)
165+
s = verified_destination_serializer(procs(DA), size(DA.indexes)) do pididx
166+
A[DA.indexes[pididx]...]
167+
end
166168
DArray(DA) do I
167-
f(localpart(DA), verify_and_get(pas, I))
169+
f(localpart(DA), localpart(s))
168170
end
169171
end
170172

171173
function map_localparts(f::Callable, A::Array, DA::DArray)
172-
pas = PartitionedSerializer(A, procs(DA), DA.indexes)
174+
s = verified_destination_serializer(procs(DA), size(DA.indexes)) do pididx
175+
A[DA.indexes[pididx]...]
176+
end
173177
DArray(DA) do I
174-
f(verify_and_get(pas, I), localpart(DA))
178+
f(localpart(s), localpart(DA))
175179
end
176180
end
177181

src/serialize.jl

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -40,47 +40,46 @@ function Base.deserialize{T<:DArray}(S::AbstractSerializer, t::Type{T})
4040
end
4141

4242
# Serialize only those parts of the object as required by the destination worker.
43-
type PartitionedSerializer
44-
indexable_obj # An indexable object, Array, SparseMatrix, etc.
45-
# Complete object on the serializing side.
46-
# Part object on the deserialized side.
47-
pids::Nullable{Array}
48-
idxs::Nullable{Array}
49-
local_idxs::Nullable{Tuple}
43+
type DestinationSerializer
44+
generate::Nullable{Function} # Function to generate the part to be serialized
45+
pids::Nullable{Array} # MUST have the same shape as the distribution
5046

51-
PartitionedSerializer(obj, local_idxs::Tuple) = new(obj, Nullable{Array}(), Nullable{Array}(), local_idxs)
52-
function PartitionedSerializer(obj, pids::Array, idxs::Array)
53-
pas = new(obj,pids,idxs,Nullable{Tuple}())
47+
deser_obj::Nullable{Any} # Deserialized part
5448

55-
if myid() in pids
56-
pas.local_idxs = idxs[findfirst(pids, myid())]
57-
end
58-
return pas
59-
end
49+
DestinationSerializer(f,p,d) = new(f,p,d)
6050
end
6151

62-
function Base.serialize(S::AbstractSerializer, pas::PartitionedSerializer)
52+
DestinationSerializer(f::Function, pids::Array) = DestinationSerializer(f, pids, Nullable{Any}())
53+
54+
# contructs a DestinationSerializer after verifying that the shape of pids.
55+
function verified_destination_serializer(f::Function, pids::Array, verify_size)
56+
@assert size(pids) == verify_size
57+
return DestinationSerializer(f, pids)
58+
end
59+
60+
DestinationSerializer(deser_obj::Any) = DestinationSerializer(Nullable{Function}(), Nullable{Array}(), deser_obj)
61+
62+
function Base.serialize(S::AbstractSerializer, s::DestinationSerializer)
6363
pid = Base.worker_id_from_socket(S.io)
64-
I = get(pas.idxs)[findfirst(get(pas.pids), pid)]
65-
Serializer.serialize_type(S, typeof(pas))
66-
serialize(S, pas.indexable_obj[I...])
67-
serialize(S, I)
64+
pididx = findfirst(get(s.pids), pid)
65+
Serializer.serialize_type(S, typeof(s))
66+
serialize(S, get(s.generate)(pididx))
6867
end
6968

70-
function Base.deserialize{T<:PartitionedSerializer}(S::AbstractSerializer, t::Type{T})
71-
obj_part = deserialize(S)
72-
I = deserialize(S)
73-
return PartitionedSerializer(obj_part, I)
69+
function Base.deserialize{T<:DestinationSerializer}(S::AbstractSerializer, t::Type{T})
70+
lpart = deserialize(S)
71+
return DestinationSerializer(lpart)
7472
end
7573

76-
function verify_and_get(pas::PartitionedSerializer, I)
77-
# Handle the special case where myid() is part of pas.pids.
78-
# For this case serialize/deserialize is not called as the remotecall is executed locally
79-
if myid() in get(pas.pids, [])
80-
@assert I == get(pas.idxs)[findfirst(get(pas.pids),myid())]
81-
return pas.indexable_obj[I...]
74+
75+
function localpart(s::DestinationSerializer)
76+
if !isnull(s.deser_obj)
77+
return get(s.deser_obj)
78+
elseif !isnull(s.generate) && (myid() in get(s.pids))
79+
# Handle the special case where myid() is part of s.pids.
80+
# In this case serialize/deserialize is not called as the remotecall is executed locally
81+
return get(s.generate)(findfirst(get(s.pids), myid()))
8282
else
83-
@assert I == get(pas.local_idxs, ())
84-
return pas.indexable_obj
83+
throw(ErrorException(string("Invalid state in DestinationSerializer.")))
8584
end
8685
end

0 commit comments

Comments
 (0)