Skip to content

Commit d2e9e19

Browse files
committed
Rename PartitionedSerializer to DestinationSerializer
Make the destination-aware serializer more generic. Fix a bug w.r.t. pids having the correct shape.
1 parent 3097509 commit d2e9e19

File tree

3 files changed

+49
-42
lines changed

3 files changed

+49
-42
lines changed

src/core.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -445,8 +445,10 @@ function distribute(A::AbstractArray;
445445
dist = defaultdist(size(A), procs))
446446
idxs, _ = chunk_idxs([size(A)...], dist)
447447

448-
pas = PartitionedSerializer(A, procs, idxs)
449-
return DArray(I->verify_and_get(pas, I), size(A), procs, dist)
448+
s = verified_destination_serializer(reshape(procs, size(idxs)), size(idxs)) do pididx
449+
A[idxs[pididx]...]
450+
end
451+
return DArray(I->localpart(s), size(A), procs, dist)
450452
end
451453

452454
"""
@@ -458,8 +460,10 @@ Distribute a local array `A` like the distributed array `DA`.
458460
function distribute(A::AbstractArray, DA::DArray)
459461
size(DA) == size(A) || throw(DimensionMismatch("Distributed array has size $(size(DA)) but array has $(size(A))"))
460462

461-
pas = PartitionedSerializer(A, procs(DA), DA.indexes)
462-
return DArray(I->verify_and_get(pas, I), DA)
463+
s = verified_destination_serializer(procs(DA), size(DA.indexes)) do pididx
464+
A[DA.indexes[pididx]...]
465+
end
466+
return DArray(I->localpart(s), DA)
463467
end
464468

465469
Base.convert{T,N,S<:AbstractArray}(::Type{DArray{T,N,S}}, A::S) = distribute(convert(AbstractArray{T,N}, A))
@@ -589,7 +593,7 @@ indexin_mask(a, b) = [i in b for i in a]
589593
import Base: tail
590594
# Given a tuple of indices and a tuple of masks, restrict the indices to the
591595
# valid regions. This is, effectively, reversing Base.setindex_shape_check.
592-
# We can't just use indexing into MergedIndices here because getindex is much
596+
# We can't just use indexing into MergedIndices here because getindex is much
593597
# pickier about singleton dimensions than setindex! is.
594598
restrict_indices(::Tuple{}, ::Tuple{}) = ()
595599
function restrict_indices(a::Tuple{Any, Vararg{Any}}, b::Tuple{Any, Vararg{Any}})
@@ -639,7 +643,7 @@ end
639643
Base.size(M::MergedIndices) = M.sz
640644
Base.@propagate_inbounds Base.getindex{_,N}(M::MergedIndices{_,N}, I::Vararg{Int, N}) =
641645
CartesianIndex(map(propagate_getindex, M.indices, I))
642-
# Additionally, we optimize bounds checking when using MergedIndices as an
646+
# Additionally, we optimize bounds checking when using MergedIndices as an
643647
# array index since checking, e.g., A[1:500, 1:500] is *way* faster than
644648
# checking an array of 500^2 elements of CartesianIndex{2}. This optimization
645649
# also applies to reshapes of MergedIndices since the outer shape of the

src/mapreduce.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,16 +162,20 @@ map_localparts(f::Callable, d1::DArray, d2::DArray) = DArray(d1) do I
162162
end
163163

164164
function map_localparts(f::Callable, DA::DArray, A::Array)
165-
pas = PartitionedSerializer(A, procs(DA), DA.indexes)
165+
s = verified_destination_serializer(procs(DA), size(DA.indexes)) do pididx
166+
A[DA.indexes[pididx]...]
167+
end
166168
DArray(DA) do I
167-
f(localpart(DA), verify_and_get(pas, I))
169+
f(localpart(DA), localpart(s))
168170
end
169171
end
170172

171173
function map_localparts(f::Callable, A::Array, DA::DArray)
172-
pas = PartitionedSerializer(A, procs(DA), DA.indexes)
174+
s = verified_destination_serializer(procs(DA), size(DA.indexes)) do pididx
175+
A[DA.indexes[pididx]...]
176+
end
173177
DArray(DA) do I
174-
f(verify_and_get(pas, I), localpart(DA))
178+
f(localpart(s), localpart(DA))
175179
end
176180
end
177181

src/serialize.jl

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -40,47 +40,46 @@ function Base.deserialize{T<:DArray}(S::AbstractSerializer, t::Type{T})
4040
end
4141

4242
# Serialize only those parts of the object as required by the destination worker.
43-
type PartitionedSerializer
44-
indexable_obj # An indexable object, Array, SparseMatrix, etc.
45-
# Complete object on the serializing side.
46-
# Part object on the deserialized side.
47-
pids::Nullable{Array}
48-
idxs::Nullable{Array}
49-
local_idxs::Nullable{Tuple}
43+
type DestinationSerializer
44+
generate::Nullable{Function} # Function to generate the part to be serialized
45+
pids::Nullable{Array} # MUST have the same shape as the distribution
5046

51-
PartitionedSerializer(obj, local_idxs::Tuple) = new(obj, Nullable{Array}(), Nullable{Array}(), local_idxs)
52-
function PartitionedSerializer(obj, pids::Array, idxs::Array)
53-
pas = new(obj,pids,idxs,Nullable{Tuple}())
47+
deser_obj::Nullable{Any} # Deserialized part
5448

55-
if myid() in pids
56-
pas.local_idxs = idxs[findfirst(pids, myid())]
57-
end
58-
return pas
59-
end
49+
DestinationSerializer(f,p,d) = new(f,p,d)
6050
end
6151

62-
function Base.serialize(S::AbstractSerializer, pas::PartitionedSerializer)
52+
DestinationSerializer(f::Function, pids::Array) = DestinationSerializer(f, pids, Nullable{Any}())
53+
54+
# contructs a DestinationSerializer after verifying that the shape of pids.
55+
function verified_destination_serializer(f::Function, pids::Array, verify_size)
56+
@assert size(pids) == verify_size
57+
return DestinationSerializer(f, pids)
58+
end
59+
60+
DestinationSerializer(deser_obj::Any) = DestinationSerializer(Nullable{Function}(), Nullable{Array}(), deser_obj)
61+
62+
function Base.serialize(S::AbstractSerializer, s::DestinationSerializer)
6363
pid = Base.worker_id_from_socket(S.io)
64-
I = get(pas.idxs)[findfirst(get(pas.pids), pid)]
65-
Serializer.serialize_type(S, typeof(pas))
66-
serialize(S, pas.indexable_obj[I...])
67-
serialize(S, I)
64+
pididx = findfirst(get(s.pids), pid)
65+
Serializer.serialize_type(S, typeof(s))
66+
serialize(S, get(s.generate)(pididx))
6867
end
6968

70-
function Base.deserialize{T<:PartitionedSerializer}(S::AbstractSerializer, t::Type{T})
71-
obj_part = deserialize(S)
72-
I = deserialize(S)
73-
return PartitionedSerializer(obj_part, I)
69+
function Base.deserialize{T<:DestinationSerializer}(S::AbstractSerializer, t::Type{T})
70+
lpart = deserialize(S)
71+
return DestinationSerializer(lpart)
7472
end
7573

76-
function verify_and_get(pas::PartitionedSerializer, I)
77-
# Handle the special case where myid() is part of pas.pids.
78-
# For this case serialize/deserialize is not called as the remotecall is executed locally
79-
if myid() in get(pas.pids, [])
80-
@assert I == get(pas.idxs)[findfirst(get(pas.pids),myid())]
81-
return pas.indexable_obj[I...]
74+
75+
function localpart(s::DestinationSerializer)
76+
if !isnull(s.deser_obj)
77+
return get(s.deser_obj)
78+
elseif !isnull(s.generate) && (myid() in get(s.pids))
79+
# Handle the special case where myid() is part of s.pids.
80+
# In this case serialize/deserialize is not called as the remotecall is executed locally
81+
return get(s.generate)(findfirst(get(s.pids), myid()))
8282
else
83-
@assert I == get(pas.local_idxs, ())
84-
return pas.indexable_obj
83+
throw(ErrorException(string("Invalid state in DestinationSerializer.")))
8584
end
8685
end

0 commit comments

Comments
 (0)