Skip to content
This repository was archived by the owner on Jan 9, 2020. It is now read-only.

Commit 6adf67d

Browse files
aray authored and HyukjinKwon committed
[SPARK-21985][PYSPARK] PairDeserializer is broken for double-zipped RDDs
## What changes were proposed in this pull request? (edited) Fixes a bug introduced in apache#16121 In PairDeserializer convert each batch of keys and values to lists (if they do not have `__len__` already) so that we can check that they are the same size. Normally they already are lists so this should not have a performance impact, but this is needed when repeated `zip`'s are done. ## How was this patch tested? Additional unit test Author: Andrew Ray <[email protected]> Closes apache#19226 from aray/SPARK-21985.
1 parent f407302 commit 6adf67d

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

python/pyspark/serializers.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def load_stream(self, stream):
9797

9898
def _load_stream_without_unbatching(self, stream):
9999
"""
100-
Return an iterator of deserialized batches (lists) of objects from the input stream.
100+
Return an iterator of deserialized batches (iterable) of objects from the input stream.
101101
if the serializer does not operate on batches the default implementation returns an
102102
iterator of single element lists.
103103
"""
@@ -343,6 +343,10 @@ def _load_stream_without_unbatching(self, stream):
343343
key_batch_stream = self.key_ser._load_stream_without_unbatching(stream)
344344
val_batch_stream = self.val_ser._load_stream_without_unbatching(stream)
345345
for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream):
346+
# For double-zipped RDDs, the batches can be iterators from other PairDeserializer,
347+
# instead of lists. We need to convert them to lists if needed.
348+
key_batch = key_batch if hasattr(key_batch, '__len__') else list(key_batch)
349+
val_batch = val_batch if hasattr(val_batch, '__len__') else list(val_batch)
346350
if len(key_batch) != len(val_batch):
347351
raise ValueError("Can not deserialize PairRDD with different number of items"
348352
" in batches: (%d, %d)" % (len(key_batch), len(val_batch)))

python/pyspark/tests.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,18 @@ def test_cartesian_chaining(self):
644644
set([(x, (y, y)) for x in range(10) for y in range(10)])
645645
)
646646

647+
def test_zip_chaining(self):
648+
# Tests for SPARK-21985
649+
rdd = self.sc.parallelize('abc', 2)
650+
self.assertSetEqual(
651+
set(rdd.zip(rdd).zip(rdd).collect()),
652+
set([((x, x), x) for x in 'abc'])
653+
)
654+
self.assertSetEqual(
655+
set(rdd.zip(rdd.zip(rdd)).collect()),
656+
set([(x, (x, x)) for x in 'abc'])
657+
)
658+
647659
def test_deleting_input_files(self):
648660
# Regression test for SPARK-1025
649661
tempFile = tempfile.NamedTemporaryFile(delete=False)

0 commit comments

Comments (0)