@@ -40,47 +40,46 @@ function Base.deserialize{T<:DArray}(S::AbstractSerializer, t::Type{T})
40
40
end
41
41
42
42
# Serialize only those parts of the object as required by the destination worker.
43
- type PartitionedSerializer
44
- indexable_obj # An indexable object, Array, SparseMatrix, etc.
45
- # Complete object on the serializing side.
46
- # Part object on the deserialized side.
47
- pids:: Nullable{Array}
48
- idxs:: Nullable{Array}
49
- local_idxs:: Nullable{Tuple}
43
+ type DestinationSerializer
44
+ generate:: Nullable{Function} # Function to generate the part to be serialized
45
+ pids:: Nullable{Array} # MUST have the same shape as the distribution
50
46
51
- PartitionedSerializer (obj, local_idxs:: Tuple ) = new (obj, Nullable {Array} (), Nullable {Array} (), local_idxs)
52
- function PartitionedSerializer (obj, pids:: Array , idxs:: Array )
53
- pas = new (obj,pids,idxs,Nullable {Tuple} ())
47
+ deser_obj:: Nullable{Any} # Deserialized part
54
48
55
- if myid () in pids
56
- pas. local_idxs = idxs[findfirst (pids, myid ())]
57
- end
58
- return pas
59
- end
49
+ DestinationSerializer (f,p,d) = new (f,p,d)
60
50
end
61
51
62
- function Base. serialize (S:: AbstractSerializer , pas:: PartitionedSerializer )
52
+ DestinationSerializer (f:: Function , pids:: Array ) = DestinationSerializer (f, pids, Nullable {Any} ())
53
+
54
+ # contructs a DestinationSerializer after verifying that the shape of pids.
55
+ function verified_destination_serializer (f:: Function , pids:: Array , verify_size)
56
+ @assert size (pids) == verify_size
57
+ return DestinationSerializer (f, pids)
58
+ end
59
+
60
+ DestinationSerializer (deser_obj:: Any ) = DestinationSerializer (Nullable {Function} (), Nullable {Array} (), deser_obj)
61
+
62
+ function Base. serialize (S:: AbstractSerializer , s:: DestinationSerializer )
63
63
pid = Base. worker_id_from_socket (S. io)
64
- I = get (pas. idxs)[findfirst (get (pas. pids), pid)]
65
- Serializer. serialize_type (S, typeof (pas))
66
- serialize (S, pas. indexable_obj[I... ])
67
- serialize (S, I)
64
+ pididx = findfirst (get (s. pids), pid)
65
+ Serializer. serialize_type (S, typeof (s))
66
+ serialize (S, get (s. generate)(pididx))
68
67
end
69
68
70
- function Base. deserialize {T<:PartitionedSerializer} (S:: AbstractSerializer , t:: Type{T} )
71
- obj_part = deserialize (S)
72
- I = deserialize (S)
73
- return PartitionedSerializer (obj_part, I)
69
+ function Base. deserialize {T<:DestinationSerializer} (S:: AbstractSerializer , t:: Type{T} )
70
+ lpart = deserialize (S)
71
+ return DestinationSerializer (lpart)
74
72
end
75
73
76
- function verify_and_get (pas:: PartitionedSerializer , I)
77
- # Handle the special case where myid() is part of pas.pids.
78
- # For this case serialize/deserialize is not called as the remotecall is executed locally
79
- if myid () in get (pas. pids, [])
80
- @assert I == get (pas. idxs)[findfirst (get (pas. pids),myid ())]
81
- return pas. indexable_obj[I... ]
74
+
75
+ function localpart (s:: DestinationSerializer )
76
+ if ! isnull (s. deser_obj)
77
+ return get (s. deser_obj)
78
+ elseif ! isnull (s. generate) && (myid () in get (s. pids))
79
+ # Handle the special case where myid() is part of s.pids.
80
+ # In this case serialize/deserialize is not called as the remotecall is executed locally
81
+ return get (s. generate)(findfirst (get (s. pids), myid ()))
82
82
else
83
- @assert I == get (pas. local_idxs, ())
84
- return pas. indexable_obj
83
+ throw (ErrorException (string (" Invalid state in DestinationSerializer." )))
85
84
end
86
85
end
0 commit comments