@@ -26,10 +26,11 @@ import scala.collection.mutable.ArrayBuffer
 import com.google.common.io.ByteStreams
 
 import org.apache.spark._
+import org.apache.spark.api.shuffle.{ShuffleMapOutputWriter, ShufflePartitionWriter}
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.serializer._
-import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
+import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId}
 
 /**
  * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
@@ -674,11 +675,9 @@ private[spark] class ExternalSorter[K, V, C](
   }
 
   /**
-   * Write all the data added into this ExternalSorter into a file in the disk store. This is
-   * called by the SortShuffleWriter.
-   *
-   * @param blockId block ID to write to. The index file will be blockId.name + ".index".
-   * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
+   * TODO remove this, as this is only used by UnsafeRowSerializerSuite in the SQL project.
+   * We should figure out an alternative way to test that so that we can remove this otherwise
+   * unused code path.
    */
   def writePartitionedFile(
       blockId: BlockId,
@@ -722,6 +721,123 @@ private[spark] class ExternalSorter[K, V, C](
     lengths
   }
 
+  private def writeEmptyPartition(mapOutputWriter: ShuffleMapOutputWriter): Unit = {
+    var partitionWriter: ShufflePartitionWriter = null
+    try {
+      partitionWriter = mapOutputWriter.getNextPartitionWriter
+    } finally {
+      if (partitionWriter != null) {
+        partitionWriter.close()
+      }
+    }
+  }
+
+  /**
+   * Write all the data added into this ExternalSorter into a map output writer that pushes bytes
+   * to some arbitrary backing store. This is called by the SortShuffleWriter.
+   *
+   * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
+   */
+  def writePartitionedMapOutput(
+      shuffleId: Int, mapId: Int, mapOutputWriter: ShuffleMapOutputWriter): Array[Long] = {
+    // Track location of each range in the map output
+    val lengths = new Array[Long](numPartitions)
+    var nextPartitionId = 0
+    if (spills.isEmpty) {
+      // Case where we only have in-memory data
+      val collection = if (aggregator.isDefined) map else buffer
+      val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
+      while (it.hasNext()) {
+        val partitionId = it.nextPartition()
+        // The contract for the plugin is that we will ask for a writer for every partition
+        // even if it's empty. However, the external sorter will return non-contiguous
+        // partition ids. So this loop "backfills" the empty partitions that form the gaps.
+
+        // The algorithm as a whole is correct because the partition ids are returned by the
+        // iterator in ascending order.
+        for (emptyPartition <- nextPartitionId until partitionId) {
+          writeEmptyPartition(mapOutputWriter)
+        }
+        var partitionWriter: ShufflePartitionWriter = null
+        var partitionPairsWriter: ShufflePartitionPairsWriter = null
+        try {
+          partitionWriter = mapOutputWriter.getNextPartitionWriter
+          val blockId = ShuffleBlockId(shuffleId, mapId, partitionId)
+          partitionPairsWriter = new ShufflePartitionPairsWriter(
+            partitionWriter,
+            serializerManager,
+            serInstance,
+            blockId,
+            context.taskMetrics().shuffleWriteMetrics)
+          while (it.hasNext && it.nextPartition() == partitionId) {
+            it.writeNext(partitionPairsWriter)
+          }
+        } finally {
+          if (partitionPairsWriter != null) {
+            partitionPairsWriter.close()
+          }
+          if (partitionWriter != null) {
+            partitionWriter.close()
+          }
+        }
+        if (partitionWriter != null) {
+          lengths(partitionId) = partitionWriter.getNumBytesWritten
+        }
+        nextPartitionId = partitionId + 1
+      }
+    } else {
+      // We must perform merge-sort; get an iterator by partition and write everything directly.
+      for ((id, elements) <- this.partitionedIterator) {
+        // The contract for the plugin is that we will ask for a writer for every partition
+        // even if it's empty. However, the external sorter will return non-contiguous
+        // partition ids. So this loop "backfills" the empty partitions that form the gaps.
+
+        // The algorithm as a whole is correct because the partition ids are returned by the
+        // iterator in ascending order.
+        for (emptyPartition <- nextPartitionId until id) {
+          writeEmptyPartition(mapOutputWriter)
+        }
+        val blockId = ShuffleBlockId(shuffleId, mapId, id)
+        var partitionWriter: ShufflePartitionWriter = null
+        var partitionPairsWriter: ShufflePartitionPairsWriter = null
+        try {
+          partitionWriter = mapOutputWriter.getNextPartitionWriter
+          partitionPairsWriter = new ShufflePartitionPairsWriter(
+            partitionWriter,
+            serializerManager,
+            serInstance,
+            blockId,
+            context.taskMetrics().shuffleWriteMetrics)
+          if (elements.hasNext) {
+            for (elem <- elements) {
+              partitionPairsWriter.write(elem._1, elem._2)
+            }
+          }
+        } finally {
+          if (partitionPairsWriter != null) {
+            partitionPairsWriter.close()
+          }
+        }
+        if (partitionWriter != null) {
+          lengths(id) = partitionWriter.getNumBytesWritten
+        }
+        nextPartitionId = id + 1
+      }
+    }
+
+    // The iterator may have stopped short of opening a writer for every partition. So fill in the
+    // remaining empty partitions.
+    for (emptyPartition <- nextPartitionId until numPartitions) {
+      writeEmptyPartition(mapOutputWriter)
+    }
+
+    context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
+    context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
+    context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
+
+    lengths
+  }
+
   def stop(): Unit = {
     spills.foreach(s => s.file.delete())
     spills.clear()
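The key invariant in the new writePartitionedMapOutput method is the "backfill" of empty partitions: the sorter's iterators return partition ids in ascending but possibly non-contiguous order, while the plugin contract requires a partition writer to be opened (and closed) for every partition, empty or not. The following is a minimal, self-contained sketch of that pattern, not Spark code: MapOutputWriter and PartitionWriter here are simplified stand-ins for the org.apache.spark.api.shuffle interfaces, modeling only the getNextPartitionWriter, getNumBytesWritten, and close calls the diff actually exercises.

object BackfillSketch {

  // Simplified stand-ins for the plugin interfaces (illustrative only).
  trait PartitionWriter {
    def write(bytes: Array[Byte]): Unit
    def getNumBytesWritten: Long
    def close(): Unit
  }

  trait MapOutputWriter {
    def getNextPartitionWriter: PartitionWriter
  }

  // Writes ascending, possibly non-contiguous (partitionId, payload) pairs while asking the
  // map output writer for a writer for *every* partition id in [0, numPartitions).
  def writeAllPartitions(
      data: Seq[(Int, Array[Byte])],
      numPartitions: Int,
      mapOutputWriter: MapOutputWriter): Array[Long] = {
    val lengths = new Array[Long](numPartitions)
    var nextPartitionId = 0
    for ((partitionId, bytes) <- data) {
      // Backfill the empty partitions between the previous id and this one.
      for (_ <- nextPartitionId until partitionId) {
        mapOutputWriter.getNextPartitionWriter.close()
      }
      val writer = mapOutputWriter.getNextPartitionWriter
      try {
        writer.write(bytes)
      } finally {
        writer.close()
      }
      lengths(partitionId) = writer.getNumBytesWritten
      nextPartitionId = partitionId + 1
    }
    // The data may stop before the last partition; fill in the remaining empty ones.
    for (_ <- nextPartitionId until numPartitions) {
      mapOutputWriter.getNextPartitionWriter.close()
    }
    lengths
  }

  def main(args: Array[String]): Unit = {
    // In-memory writer that just counts bytes, to show the one-writer-per-partition contract.
    var writersOpened = 0
    val mapOutputWriter = new MapOutputWriter {
      def getNextPartitionWriter: PartitionWriter = new PartitionWriter {
        writersOpened += 1
        private var written = 0L
        def write(bytes: Array[Byte]): Unit = written += bytes.length
        def getNumBytesWritten: Long = written
        def close(): Unit = ()
      }
    }
    val data = Seq(1 -> Array[Byte](1, 2, 3), 4 -> Array[Byte](9))
    val lengths = writeAllPartitions(data, 6, mapOutputWriter)
    println(s"writers opened: $writersOpened")    // 6, one per partition
    println(lengths.mkString("[", ", ", "]"))     // [0, 3, 0, 0, 1, 0]
  }
}

As in the diff, per-partition lengths come from asking the writer how many bytes it saw, which is what the map output tracker ultimately consumes.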
@@ -785,7 +901,7 @@ private[spark] class ExternalSorter[K, V, C](
         val inMemoryIterator = new WritablePartitionedIterator {
           private[this] var cur = if (upstream.hasNext) upstream.next() else null
 
-          def writeNext(writer: DiskBlockObjectWriter): Unit = {
+          def writeNext(writer: PairsWriter): Unit = {
             writer.write(cur._1._2, cur._2)
             cur = if (upstream.hasNext) upstream.next() else null
           }
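The last hunk widens writeNext from the concrete DiskBlockObjectWriter to the PairsWriter abstraction, so the same spill iterator can feed either the existing local-disk writer or the new ShufflePartitionPairsWriter. The PairsWriter trait itself is not shown in this diff; the sketch below assumes it amounts to a single write(key, value) method, which is the only call the iterator makes. All names here are illustrative stand-ins, not Spark's actual classes.

object PairsWriterSketch {
  // Hypothetical minimal interface in the spirit of PairsWriter.
  trait KeyValueSink {
    def write(key: Any, value: Any): Unit
  }

  // Toy stand-in for the local-disk path (DiskBlockObjectWriter in Spark).
  class LoggingDiskSink extends KeyValueSink {
    def write(key: Any, value: Any): Unit = println(s"disk: $key -> $value")
  }

  // Toy stand-in for the plugin path (ShufflePartitionPairsWriter in Spark).
  class LoggingPluginSink extends KeyValueSink {
    def write(key: Any, value: Any): Unit = println(s"plugin: $key -> $value")
  }

  // The spill iterator only ever calls write(key, value), so it no longer cares
  // whether records end up in a local file or in a plugin-provided stream.
  def drainTo(pairs: Iterator[((Int, String), Int)], sink: KeyValueSink): Unit =
    pairs.foreach { case ((_, key), value) => sink.write(key, value) }

  def main(args: Array[String]): Unit = {
    val records = Seq(((0, "a"), 1), ((1, "b"), 2))
    drainTo(records.iterator, new LoggingDiskSink)
    drainTo(records.iterator, new LoggingPluginSink)
  }
}

Decoupling the iterator from a concrete disk writer is what lets ExternalSorter push records to arbitrary shuffle storage backends through the new API.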