 import java.io.File;
 import java.io.FileInputStream;
-import java.io.FileOutputStream;
 import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.channels.FileChannel;
+import java.nio.channels.WritableByteChannel;
 
 import javax.annotation.Nullable;
 
 import scala.None$;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.spark.api.shuffle.ShuffleMapOutputWriter;
+import org.apache.spark.api.shuffle.ShufflePartitionWriter;
+import org.apache.spark.api.shuffle.ShuffleWriteSupport;
 import org.apache.spark.internal.config.package$;
 import org.apache.spark.Partitioner;
 import org.apache.spark.ShuffleDependency;
@@ -82,6 +87,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private final int shuffleId;
   private final int mapId;
   private final Serializer serializer;
+  private final ShuffleWriteSupport shuffleWriteSupport;
   private final IndexShuffleBlockResolver shuffleBlockResolver;
 
   /** Array of file writers, one for each partition */
@@ -103,7 +109,8 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
       BypassMergeSortShuffleHandle<K, V> handle,
       int mapId,
       SparkConf conf,
-      ShuffleWriteMetricsReporter writeMetrics) {
+      ShuffleWriteMetricsReporter writeMetrics,
+      ShuffleWriteSupport shuffleWriteSupport) {
     // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
     this.fileBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024;
     this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
@@ -116,57 +123,61 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     this.writeMetrics = writeMetrics;
     this.serializer = dep.serializer();
     this.shuffleBlockResolver = shuffleBlockResolver;
+    this.shuffleWriteSupport = shuffleWriteSupport;
   }
 
   @Override
   public void write(Iterator<Product2<K, V>> records) throws IOException {
     assert (partitionWriters == null);
-    if (!records.hasNext()) {
-      partitionLengths = new long[numPartitions];
-      shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null);
-      mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
-      return;
-    }
-    final SerializerInstance serInstance = serializer.newInstance();
-    final long openStartTime = System.nanoTime();
-    partitionWriters = new DiskBlockObjectWriter[numPartitions];
-    partitionWriterSegments = new FileSegment[numPartitions];
-    for (int i = 0; i < numPartitions; i++) {
-      final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
-        blockManager.diskBlockManager().createTempShuffleBlock();
-      final File file = tempShuffleBlockIdPlusFile._2();
-      final BlockId blockId = tempShuffleBlockIdPlusFile._1();
-      partitionWriters[i] =
-        blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
-    }
-    // Creating the file to write to and creating a disk writer both involve interacting with
-    // the disk, and can take a long time in aggregate when we open many files, so should be
-    // included in the shuffle write time.
-    writeMetrics.incWriteTime(System.nanoTime() - openStartTime);
-
-    while (records.hasNext()) {
-      final Product2<K, V> record = records.next();
-      final K key = record._1();
-      partitionWriters[partitioner.getPartition(key)].write(key, record._2());
-    }
+    ShuffleMapOutputWriter mapOutputWriter = shuffleWriteSupport
+      .createMapOutputWriter(shuffleId, mapId, numPartitions);
+    try {
+      if (!records.hasNext()) {
+        partitionLengths = new long[numPartitions];
+        mapOutputWriter.commitAllPartitions();
+        mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+        return;
+      }
+      final SerializerInstance serInstance = serializer.newInstance();
+      final long openStartTime = System.nanoTime();
+      partitionWriters = new DiskBlockObjectWriter[numPartitions];
+      partitionWriterSegments = new FileSegment[numPartitions];
+      for (int i = 0; i < numPartitions; i++) {
+        final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
+          blockManager.diskBlockManager().createTempShuffleBlock();
+        final File file = tempShuffleBlockIdPlusFile._2();
+        final BlockId blockId = tempShuffleBlockIdPlusFile._1();
+        partitionWriters[i] =
+          blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
+      }
+      // Creating the file to write to and creating a disk writer both involve interacting with
+      // the disk, and can take a long time in aggregate when we open many files, so should be
+      // included in the shuffle write time.
+      writeMetrics.incWriteTime(System.nanoTime() - openStartTime);
 
-    for (int i = 0; i < numPartitions; i++) {
-      try (DiskBlockObjectWriter writer = partitionWriters[i]) {
-        partitionWriterSegments[i] = writer.commitAndGet();
+      while (records.hasNext()) {
+        final Product2<K, V> record = records.next();
+        final K key = record._1();
+        partitionWriters[partitioner.getPartition(key)].write(key, record._2());
       }
-    }
 
-    File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
-    File tmp = Utils.tempFileWith(output);
-    try {
-      partitionLengths = writePartitionedFile(tmp);
-      shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
-    } finally {
-      if (tmp.exists() && !tmp.delete()) {
-        logger.error("Error while deleting temp file {}", tmp.getAbsolutePath());
+      for (int i = 0; i < numPartitions; i++) {
+        try (DiskBlockObjectWriter writer = partitionWriters[i]) {
+          partitionWriterSegments[i] = writer.commitAndGet();
+        }
       }
+
+      partitionLengths = writePartitionedData(mapOutputWriter);
+      mapOutputWriter.commitAllPartitions();
+      mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+    } catch (Exception e) {
+      try {
+        mapOutputWriter.abort(e);
+      } catch (Exception e2) {
+        logger.error("Failed to abort the writer after failing to write map output.", e2);
+      }
+      throw e;
     }
-    mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
   }
 
   @VisibleForTesting
@@ -179,37 +190,54 @@ long[] getPartitionLengths() {
    *
    * @return array of lengths, in bytes, of each partition of the file (used by map output tracker).
    */
-  private long[] writePartitionedFile(File outputFile) throws IOException {
+  private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) throws IOException {
     // Track location of the partition starts in the output file
     final long[] lengths = new long[numPartitions];
     if (partitionWriters == null) {
       // We were passed an empty iterator
       return lengths;
     }
-
-    final FileOutputStream out = new FileOutputStream(outputFile, true);
     final long writeStartTime = System.nanoTime();
-    boolean threwException = true;
     try {
       for (int i = 0; i < numPartitions; i++) {
         final File file = partitionWriterSegments[i].file();
-        if (file.exists()) {
-          final FileInputStream in = new FileInputStream(file);
-          boolean copyThrewException = true;
-          try {
-            lengths[i] = Utils.copyStream(in, out, false, transferToEnabled);
+        boolean copyThrewException = true;
+        ShufflePartitionWriter writer = null;
+        try {
+          writer = mapOutputWriter.getNextPartitionWriter();
+          if (!file.exists()) {
             copyThrewException = false;
-          } finally {
-            Closeables.close(in, copyThrewException);
-          }
-          if (!file.delete()) {
-            logger.error("Unable to delete file for partition {}", i);
+          } else {
+            if (transferToEnabled) {
+              WritableByteChannel outputChannel = writer.toChannel();
+              FileInputStream in = new FileInputStream(file);
+              try (FileChannel inputChannel = in.getChannel()) {
+                Utils.copyFileStreamNIO(inputChannel, outputChannel, 0, inputChannel.size());
+                copyThrewException = false;
+              } finally {
+                Closeables.close(in, copyThrewException);
+              }
+            } else {
+              OutputStream tempOutputStream = writer.toStream();
+              FileInputStream in = new FileInputStream(file);
+              try {
+                Utils.copyStream(in, tempOutputStream, false, false);
+                copyThrewException = false;
+              } finally {
+                Closeables.close(in, copyThrewException);
+              }
+            }
+            if (!file.delete()) {
+              logger.error("Unable to delete file for partition {}", i);
+            }
           }
+        } finally {
+          Closeables.close(writer, copyThrewException);
         }
+
+        lengths[i] = writer.getNumBytesWritten();
       }
-      threwException = false;
     } finally {
-      Closeables.close(out, threwException);
       writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
     }
     partitionWriters = null;
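
Taken together, the new write path drives the pluggable shuffle API in a fixed lifecycle: create one ShuffleMapOutputWriter per map task, request a ShufflePartitionWriter for each reduce partition in order, copy that partition's bytes into it via its channel or stream, record the bytes written, and finally commit all partitions (or abort on failure). Below is a minimal caller-side sketch of that lifecycle as this diff exercises it; the org.apache.spark.api.shuffle types and their methods are taken from the diff, while the class name, method name, and the partitionFiles argument are hypothetical illustration only.

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.OutputStream;

import org.apache.spark.api.shuffle.ShuffleMapOutputWriter;
import org.apache.spark.api.shuffle.ShufflePartitionWriter;
import org.apache.spark.api.shuffle.ShuffleWriteSupport;

// Hypothetical driver illustrating the plugin lifecycle; not part of the patch itself.
final class MapOutputLifecycleSketch {

  static long[] writeAllPartitions(
      ShuffleWriteSupport support,   // supplied by the shuffle plugin
      int shuffleId,
      int mapId,
      File[] partitionFiles) throws IOException {
    ShuffleMapOutputWriter mapWriter =
        support.createMapOutputWriter(shuffleId, mapId, partitionFiles.length);
    long[] lengths = new long[partitionFiles.length];
    try {
      for (int i = 0; i < partitionFiles.length; i++) {
        // One partition writer per reduce partition, requested in partition order.
        ShufflePartitionWriter partWriter = mapWriter.getNextPartitionWriter();
        try (FileInputStream in = new FileInputStream(partitionFiles[i])) {
          OutputStream out = partWriter.toStream();
          byte[] buf = new byte[8192];
          int n;
          while ((n = in.read(buf)) != -1) {
            out.write(buf, 0, n);
          }
        } finally {
          partWriter.close();
        }
        lengths[i] = partWriter.getNumBytesWritten();
      }
      mapWriter.commitAllPartitions();  // finalize the map output on success
    } catch (Exception e) {
      mapWriter.abort(e);               // let the plugin clean up partial output
      throw e;
    }
    return lengths;
  }
}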