@@ -26,10 +26,11 @@ import scala.collection.mutable.ArrayBuffer
 import com.google.common.io.ByteStreams
 
 import org.apache.spark._
+import org.apache.spark.api.shuffle.{ShuffleMapOutputWriter, ShufflePartitionWriter}
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.serializer._
-import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
+import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId}
 
 /**
  * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
@@ -670,11 +671,9 @@ private[spark] class ExternalSorter[K, V, C](
   }
 
   /**
-   * Write all the data added into this ExternalSorter into a file in the disk store. This is
-   * called by the SortShuffleWriter.
-   *
-   * @param blockId block ID to write to. The index file will be blockId.name + ".index".
-   * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
+   * TODO remove this, as this is only used by UnsafeRowSerializerSuite in the SQL project.
+   * We should figure out an alternative way to test that so that we can remove this otherwise
+   * unused code path.
    */
  def writePartitionedFile(
      blockId: BlockId,
@@ -718,6 +717,123 @@ private[spark] class ExternalSorter[K, V, C](
     lengths
   }
 
+  private def writeEmptyPartition(mapOutputWriter: ShuffleMapOutputWriter): Unit = {
+    var partitionWriter: ShufflePartitionWriter = null
+    try {
+      partitionWriter = mapOutputWriter.getNextPartitionWriter
+    } finally {
+      if (partitionWriter != null) {
+        partitionWriter.close()
+      }
+    }
+  }
+
+  /**
+   * Write all the data added into this ExternalSorter into a map output writer that pushes bytes
+   * to some arbitrary backing store. This is called by the SortShuffleWriter.
+   *
+   * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
+   */
+  def writePartitionedMapOutput(
+      shuffleId: Int, mapId: Int, mapOutputWriter: ShuffleMapOutputWriter): Array[Long] = {
+    // Track location of each range in the map output
+    val lengths = new Array[Long](numPartitions)
+    var nextPartitionId = 0
+    if (spills.isEmpty) {
+      // Case where we only have in-memory data
+      val collection = if (aggregator.isDefined) map else buffer
+      val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
+      while (it.hasNext()) {
+        val partitionId = it.nextPartition()
+        // The contract for the plugin is that we will ask for a writer for every partition
+        // even if it's empty. However, the external sorter will return non-contiguous
+        // partition ids. So this loop "backfills" the empty partitions that form the gaps.
+
+        // The algorithm as a whole is correct because the partition ids are returned by the
+        // iterator in ascending order.
+        for (emptyPartition <- nextPartitionId until partitionId) {
+          writeEmptyPartition(mapOutputWriter)
+        }
+        var partitionWriter: ShufflePartitionWriter = null
+        var partitionPairsWriter: ShufflePartitionPairsWriter = null
+        try {
+          partitionWriter = mapOutputWriter.getNextPartitionWriter
+          val blockId = ShuffleBlockId(shuffleId, mapId, partitionId)
+          partitionPairsWriter = new ShufflePartitionPairsWriter(
+            partitionWriter,
+            serializerManager,
+            serInstance,
+            blockId,
+            context.taskMetrics().shuffleWriteMetrics)
+          while (it.hasNext && it.nextPartition() == partitionId) {
+            it.writeNext(partitionPairsWriter)
+          }
+        } finally {
+          if (partitionPairsWriter != null) {
+            partitionPairsWriter.close()
+          }
+          if (partitionWriter != null) {
+            partitionWriter.close()
+          }
+        }
+        if (partitionWriter != null) {
+          lengths(partitionId) = partitionWriter.getNumBytesWritten
+        }
+        nextPartitionId = partitionId + 1
+      }
+    } else {
+      // We must perform merge-sort; get an iterator by partition and write everything directly.
+      for ((id, elements) <- this.partitionedIterator) {
+        // The contract for the plugin is that we will ask for a writer for every partition
+        // even if it's empty. However, the external sorter will return non-contiguous
+        // partition ids. So this loop "backfills" the empty partitions that form the gaps.
+
+        // The algorithm as a whole is correct because the partition ids are returned by the
+        // iterator in ascending order.
+        for (emptyPartition <- nextPartitionId until id) {
+          writeEmptyPartition(mapOutputWriter)
+        }
+        val blockId = ShuffleBlockId(shuffleId, mapId, id)
+        var partitionWriter: ShufflePartitionWriter = null
+        var partitionPairsWriter: ShufflePartitionPairsWriter = null
+        try {
+          partitionWriter = mapOutputWriter.getNextPartitionWriter
+          partitionPairsWriter = new ShufflePartitionPairsWriter(
+            partitionWriter,
+            serializerManager,
+            serInstance,
+            blockId,
+            context.taskMetrics().shuffleWriteMetrics)
+          if (elements.hasNext) {
+            for (elem <- elements) {
+              partitionPairsWriter.write(elem._1, elem._2)
+            }
+          }
+        } finally {
+          if (partitionPairsWriter != null) {
+            partitionPairsWriter.close()
+          }
+        }
+        if (partitionWriter != null) {
+          lengths(id) = partitionWriter.getNumBytesWritten
+        }
+        nextPartitionId = id + 1
+      }
+    }
+
+    // The iterator may have stopped short of opening a writer for every partition. So fill in the
+    // remaining empty partitions.
+    for (emptyPartition <- nextPartitionId until numPartitions) {
+      writeEmptyPartition(mapOutputWriter)
+    }
+
+    context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
+    context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
+    context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
+
+    lengths
+  }
+
   def stop(): Unit = {
     spills.foreach(s => s.file.delete())
     spills.clear()
@@ -781,7 +897,7 @@ private[spark] class ExternalSorter[K, V, C](
     val inMemoryIterator = new WritablePartitionedIterator {
       private[this] var cur = if (upstream.hasNext) upstream.next() else null
 
-      def writeNext(writer: DiskBlockObjectWriter): Unit = {
+      def writeNext(writer: PairsWriter): Unit = {
        writer.write(cur._1._2, cur._2)
        cur = if (upstream.hasNext) upstream.next() else null
      }
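
The comments in `writePartitionedMapOutput` describe a contract worth spelling out: the sorter's iterator yields partition ids in ascending but possibly non-contiguous order, while the map output writer must be handed exactly one writer per partition, in order, even for empty partitions. Below is a minimal, self-contained sketch of that "backfill" pattern. The `PartitionWriterLike` trait and `CountingMapOutputWriter` class are hypothetical stand-ins for illustration, not the Spark shuffle plugin API.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical stand-in for a per-partition writer handed out by the plugin.
trait PartitionWriterLike {
  def write(record: String): Unit
  def close(): Unit
}

// Hypothetical stand-in for a map output writer; it only records which
// partition each requested writer corresponds to, so we can check the contract.
class CountingMapOutputWriter {
  val openedPartitions = ArrayBuffer.empty[Int]
  private var nextId = 0

  def getNextPartitionWriter: PartitionWriterLike = {
    openedPartitions += nextId
    nextId += 1
    new PartitionWriterLike {
      def write(record: String): Unit = ()  // discard data; we only track ordering
      def close(): Unit = ()
    }
  }
}

object BackfillExample {
  def main(args: Array[String]): Unit = {
    val numPartitions = 6
    // Non-empty partitions produced by the sorter: ascending ids, with gaps.
    val populated = Seq(1 -> Seq("a", "b"), 4 -> Seq("c"))
    val out = new CountingMapOutputWriter

    var nextPartitionId = 0
    for ((partitionId, records) <- populated) {
      // Backfill the gap [nextPartitionId, partitionId) with empty writers.
      for (_ <- nextPartitionId until partitionId) {
        out.getNextPartitionWriter.close()
      }
      val writer = out.getNextPartitionWriter
      try records.foreach(r => writer.write(r)) finally writer.close()
      nextPartitionId = partitionId + 1
    }
    // Partitions after the last populated one are also empty: fill them in too.
    for (_ <- nextPartitionId until numPartitions) {
      out.getNextPartitionWriter.close()
    }

    // Every partition got exactly one writer, in order: 0, 1, 2, 3, 4, 5.
    assert(out.openedPartitions.toList == (0 until numPartitions).toList)
  }
}
```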
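The final hunk changes `WritablePartitionedIterator.writeNext` to accept a `PairsWriter` instead of a `DiskBlockObjectWriter`, so the same in-memory iterator can drain into either the existing disk-backed writer or the new `ShufflePartitionPairsWriter`. The sketch below assumes `PairsWriter` is a single-method `write(key, value)` abstraction (only the name appears in this diff); both implementations here are illustrative stand-ins, not the real Spark classes.

```scala
// Assumed shape of the abstraction implied by writeNext(writer: PairsWriter).
trait PairsWriter {
  def write(key: Any, value: Any): Unit
}

// Stand-in for a disk-backed writer such as DiskBlockObjectWriter.
class LoggingDiskWriter extends PairsWriter {
  def write(key: Any, value: Any): Unit = println(s"disk <- ($key, $value)")
}

// Stand-in for a plugin-backed writer such as ShufflePartitionPairsWriter.
class LoggingPluginWriter extends PairsWriter {
  def write(key: Any, value: Any): Unit = println(s"plugin <- ($key, $value)")
}

object PairsWriterExample {
  // The iterator over in-memory records no longer cares where the bytes end up.
  def drainTo(records: Iterator[(String, Int)], writer: PairsWriter): Unit =
    records.foreach { case (k, v) => writer.write(k, v) }

  def main(args: Array[String]): Unit = {
    drainTo(Iterator("a" -> 1, "b" -> 2), new LoggingDiskWriter)
    drainTo(Iterator("c" -> 3), new LoggingPluginWriter)
  }
}
```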