From 3bdfdf1c724372f364fc7f4d2a09ba24ff6a13b8 Mon Sep 17 00:00:00 2001
From: Ilan Filonenko
Date: Tue, 13 Nov 2018 16:11:31 -0800
Subject: [PATCH 1/4] adding traits defining the API

---
 .../shuffle/external/ShuffleDataIO.scala      | 25 ++++++++++++++++
 .../external/ShufflePartitionReader.scala     | 27 +++++++++++++++++
 .../external/ShufflePartitionWriter.scala     | 29 +++++++++++++++++++
 .../shuffle/external/ShuffleReadSupport.scala | 24 +++++++++++++++
 .../external/ShuffleWriteSupport.scala        | 24 +++++++++++++++
 5 files changed, 129 insertions(+)
 create mode 100644 core/src/main/scala/org/apache/spark/shuffle/external/ShuffleDataIO.scala
 create mode 100644 core/src/main/scala/org/apache/spark/shuffle/external/ShufflePartitionReader.scala
 create mode 100644 core/src/main/scala/org/apache/spark/shuffle/external/ShufflePartitionWriter.scala
 create mode 100644 core/src/main/scala/org/apache/spark/shuffle/external/ShuffleReadSupport.scala
 create mode 100644 core/src/main/scala/org/apache/spark/shuffle/external/ShuffleWriteSupport.scala

diff --git a/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleDataIO.scala b/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleDataIO.scala
new file mode 100644
index 0000000000000..68fb80c4e8010
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleDataIO.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.external
+
+private[spark]
+
+trait ShuffleDataIO {
+  def writeSupport(): ShuffleWriteSupport
+  def readSupport(): ShuffleReadSupport
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/external/ShufflePartitionReader.scala b/core/src/main/scala/org/apache/spark/shuffle/external/ShufflePartitionReader.scala
new file mode 100644
index 0000000000000..354f452090237
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/external/ShufflePartitionReader.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.external
+
+import java.io.InputStream
+
+private[spark]
+
+// TODO: Support batch-fetch
+trait ShufflePartitionReader {
+  def fetchPartition(reduceId: Int): InputStream
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/external/ShufflePartitionWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/external/ShufflePartitionWriter.scala
new file mode 100644
index 0000000000000..33b9f2f8428d2
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/external/ShufflePartitionWriter.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.external
+
+import java.io.{Closeable, InputStream}
+
+private[spark]
+
+trait ShufflePartitionWriter extends Closeable {
+  // Reduce ID == PartitionID ?
+  def appendPartition(partitionId: Int, partitionInput: InputStream): Unit
+
+  def abort(exception: Throwable): Unit
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleReadSupport.scala
new file mode 100644
index 0000000000000..01d54c9953aab
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleReadSupport.scala
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.external
+
+private[spark]
+
+trait ShuffleReadSupport {
+  def newPartitionReader(appId: String, shuffleId: Int, mapId: Int): ShufflePartitionReader
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleWriteSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleWriteSupport.scala
new file mode 100644
index 0000000000000..f985cc1853142
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleWriteSupport.scala
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.external + +private[spark] + +trait ShuffleWriteSupport { + def newPartitionWriter(appId: String, shuffleId: Int, mapId: Int): ShufflePartitionWriter +} From 3cece7788e0b643097ad3ff1f90a977828f0c000 Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Wed, 21 Nov 2018 19:05:26 -0800 Subject: [PATCH 2/4] merge experiment work --- .../shuffle/ExternalShuffleBlockHandler.java | 43 ++++- .../shuffle/ExternalShuffleBlockResolver.java | 103 +++++++++++ .../shuffle/ExternalShuffleClient.java | 18 +- .../shuffle/FileWriterStreamCallback.java | 149 +++++++++++++++ .../spark/network/shuffle/ShuffleClient.java | 1 + .../mesos/MesosExternalShuffleClient.java | 6 +- .../protocol/BlockTransferMessage.java | 10 +- .../protocol/ExternalServiceHeartbeat.java | 49 +++++ .../protocol/{mesos => }/RegisterDriver.java | 70 +++---- .../RegisterExecutorForBackupsOnly.java | 71 +++++++ .../protocol/UploadShuffleFileStream.java | 88 +++++++++ .../UploadShuffleIndexFileStream.java | 88 +++++++++ .../mesos/ShuffleServiceHeartbeat.java | 53 ------ .../ExternalShuffleIntegrationSuite.java | 2 +- .../org/apache/spark/MapOutputTracker.scala | 173 ++++++++++++------ .../scala/org/apache/spark/SparkEnv.scala | 23 ++- .../spark/internal/config/package.scala | 5 + .../spark/network/BlockTransferService.scala | 4 +- .../netty/NettyBlockTransferService.scala | 1 + .../apache/spark/scheduler/MapStatus.scala | 28 +++ .../shuffle/BackingUpShuffleWriter.scala | 114 ++++++++++++ .../shuffle/BlockStoreShuffleReader.scala | 3 +- .../ExternalFallbackShuffleClient.scala | 67 +++++++ .../shuffle/IndexShuffleBlockResolver.scala | 2 +- .../ShuffleServiceAddressProvider.scala | 24 +-- ...ShuffleServiceAddressProviderFactory.scala | 25 +++ .../shuffle/sort/SortShuffleManager.scala | 48 ++++- .../apache/spark/storage/BlockManager.scala | 103 +++++++++-- .../apache/spark/storage/BlockManagerId.scala | 14 +- .../storage/ShuffleBlockFetcherIterator.scala | 19 +- .../org/apache/spark/DistributedSuite.scala | 9 +- .../apache/spark/MapOutputTrackerSuite.scala | 15 +- .../NettyBlockTransferSecuritySuite.scala | 2 +- .../spark/scheduler/DAGSchedulerSuite.scala | 6 +- .../BlockStoreShuffleReaderSuite.scala | 3 +- .../BlockManagerReplicationSuite.scala | 5 +- .../spark/storage/BlockManagerSuite.scala | 9 +- .../ShuffleBlockFetcherIteratorSuite.scala | 16 +- ...uffle.ShuffleServiceAddressProviderFactory | 1 + .../org/apache/spark/deploy/k8s/Config.scala | 22 +++ .../k8s/SparkKubernetesClientFactory.scala | 32 ++++ .../cluster/k8s/ExecutorPodsSnapshot.scala | 10 +- .../k8s/KubernetesClusterManager.scala | 37 +--- .../cluster/k8s/SparkPodStates.scala | 65 +++++++ ...ernetesShuffleServiceAddressProvider.scala | 146 +++++++++++++++ ...ShuffleServiceAddressProviderFactory.scala | 50 +++++ .../mesos/MesosExternalShuffleService.scala | 2 
+- .../cluster/mesos/MesosClusterManager.scala | 5 + .../cluster/YarnClusterManager.scala | 4 + .../streaming/ReceivedBlockHandlerSuite.scala | 5 +- 50 files changed, 1585 insertions(+), 263 deletions(-) create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExternalServiceHeartbeat.java rename common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/{mesos => }/RegisterDriver.java (50%) create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutorForBackupsOnly.java create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleFileStream.java create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndexFileStream.java delete mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java create mode 100644 core/src/main/scala/org/apache/spark/shuffle/BackingUpShuffleWriter.scala create mode 100644 core/src/main/scala/org/apache/spark/shuffle/ExternalFallbackShuffleClient.scala rename resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala => core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProvider.scala (60%) create mode 100644 core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProviderFactory.scala create mode 100644 resource-managers/kubernetes/core/src/main/resources/META-INF/services/org.apache.spark.shuffle.ShuffleServiceAddressProviderFactory create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/SparkPodStates.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProviderFactory.scala diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 098fa7974b87b..8f1107d8a796f 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -35,6 +35,7 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; @@ -106,7 +107,7 @@ protected void handleMessage( } else if (msgObj instanceof RegisterExecutor) { final Timer.Context responseDelayContext = - metrics.registerExecutorRequestLatencyMillis.time(); + metrics.registerExecutorRequestLatencyMillis.time(); try { RegisterExecutor msg = (RegisterExecutor) msgObj; checkAuth(client, msg.appId); @@ -116,9 +117,49 @@ protected void handleMessage( responseDelayContext.stop(); } + } else if (msgObj instanceof RegisterExecutorForBackupsOnly) { + final Timer.Context 
responseDelayContext = + metrics.registerExecutorRequestLatencyMillis.time(); + try { + RegisterExecutorForBackupsOnly msg = (RegisterExecutorForBackupsOnly) msgObj; + checkAuth(client, msg.appId); + blockManager.registerExecutorForBackups(msg.appId, msg.execId, msg.shuffleManager); + callback.onSuccess(ByteBuffer.wrap(new byte[0])); + } finally { + responseDelayContext.stop(); + } + } else { throw new UnsupportedOperationException("Unexpected message: " + msgObj); } + + } + + @Override + public StreamCallbackWithID receiveStream( + TransportClient client, + ByteBuffer messageHeader, + RpcResponseCallback callback) { + BlockTransferMessage header = BlockTransferMessage.Decoder.fromByteBuffer(messageHeader); + if (header instanceof UploadShuffleFileStream) { + UploadShuffleFileStream msg = (UploadShuffleFileStream) header; + checkAuth(client, msg.appId); + return blockManager.openShuffleFileForBackup( + msg.appId, + msg.execId, + msg.shuffleId, + msg.mapId); + } else if (header instanceof UploadShuffleIndexFileStream) { + UploadShuffleIndexFileStream msg = (UploadShuffleIndexFileStream) header; + checkAuth(client, msg.appId); + return blockManager.openShuffleIndexFileForBackup( + msg.appId, + msg.execId, + msg.shuffleId, + msg.mapId); + } else { + throw new UnsupportedOperationException("Unexpected message header: " + header); + } } public MetricSet getAllMetrics() { diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 0b7a27402369d..6ee40b5dc0a6e 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -18,7 +18,12 @@ package org.apache.spark.network.shuffle; import java.io.*; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; import java.util.*; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutionException; @@ -44,6 +49,7 @@ import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.util.LevelDBProvider; import org.apache.spark.network.util.LevelDBProvider.StoreVersion; @@ -75,6 +81,8 @@ public class ExternalShuffleBlockResolver { @VisibleForTesting final ConcurrentMap executors; + private final ConcurrentMap backupExecutors; + /** * Caches index file information so that we can avoid open/close the index files * for each block fetch. 
@@ -95,6 +103,8 @@ public class ExternalShuffleBlockResolver { "org.apache.spark.shuffle.sort.SortShuffleManager", "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager"); + private final File shuffleBackupsDir; + public ExternalShuffleBlockResolver(TransportConf conf, File registeredExecutorFile) throws IOException { this(conf, registeredExecutorFile, Executors.newSingleThreadExecutor( @@ -131,6 +141,8 @@ public int weigh(File file, ShuffleIndexInformation indexInfo) { } else { executors = Maps.newConcurrentMap(); } + this.backupExecutors = Maps.newConcurrentMap(); + this.shuffleBackupsDir = Files.createTempDirectory("spark-shuffle-backups").toFile(); this.directoryCleaner = directoryCleaner; } @@ -144,6 +156,12 @@ public void registerExecutor( String execId, ExecutorShuffleInfo executorInfo) { AppExecId fullId = new AppExecId(appId, execId); + if (backupExecutors.containsKey(fullId)) { + throw new UnsupportedOperationException( + String.format( + "Executor %s cannot be registered for both primary shuffle management and backup" + + " shuffle management.", fullId)); + } logger.info("Registered executor {} with {}", fullId, executorInfo); if (!knownManagers.contains(executorInfo.shuffleManager)) { throw new UnsupportedOperationException( @@ -161,6 +179,84 @@ public void registerExecutor( executors.put(fullId, executorInfo); } + public void registerExecutorForBackups(String appId, String execId, String shuffleManager) { + AppExecId fullId = new AppExecId(appId, execId); + if (executors.containsKey(fullId)) { + throw new UnsupportedOperationException( + String.format( + "Executor %s cannot be registered for both primary shuffle management and backup" + + " shuffle management.", fullId)); + } + File executorBackupDir = Paths.get( + shuffleBackupsDir.getAbsolutePath(), appId, execId).toFile(); + if (!executorBackupDir.mkdirs()) { + throw new RuntimeException( + String.format( + "Failed to create directories for executor backup shuffle files at %s.", + executorBackupDir.getAbsolutePath())); + } + if (!knownManagers.contains(shuffleManager)) { + throw new UnsupportedOperationException( + String.format( + "Unsupported shuffle manager of executor: %s.", fullId)); + } + + ExecutorShuffleInfo backupShuffleInfo = new ExecutorShuffleInfo( + new String[] { executorBackupDir.getAbsolutePath() }, + 1, + shuffleManager); + logger.info("Registering executor {} with {} for backups.", fullId, backupShuffleInfo); + backupExecutors.put(fullId, backupShuffleInfo); + } + + public StreamCallbackWithID openShuffleFileForBackup( + String appId, String execId, int shuffleId, int mapId) { + return getFileWriterStreamCallback( + appId, + execId, + shuffleId, + mapId, + "data", + FileWriterStreamCallback.BackupFileType.DATA); + } + + public StreamCallbackWithID openShuffleIndexFileForBackup( + String appId, String execId, int shuffleId, int mapId) { + return getFileWriterStreamCallback( + appId, + execId, + shuffleId, + mapId, + "index", + FileWriterStreamCallback.BackupFileType.INDEX); + } + + private StreamCallbackWithID getFileWriterStreamCallback( + String appId, + String execId, + int shuffleId, + int mapId, + String extension, + FileWriterStreamCallback.BackupFileType backupFileType) { + AppExecId fullId = new AppExecId(appId, execId); + ExecutorShuffleInfo executor = backupExecutors.get(fullId); + if (executor == null) { + throw new RuntimeException( + String.format("Executor is not registered for shuffle file backups" + + " (appId=%s, execId=%s)", appId, execId)); + } + File backedUpFile = 
getFile(executor.localDirs, executor.subDirsPerLocalDir, + "shuffle_" + shuffleId + "_" + mapId + "_0." + extension); + FileWriterStreamCallback streamCallback = new FileWriterStreamCallback( + fullId, + shuffleId, + mapId, + backedUpFile, + backupFileType); + streamCallback.open(); + return streamCallback; + } + /** * Obtains a FileSegmentManagedBuffer from (shuffleId, mapId, reduceId). We make assumptions * about how the hash and sort based shuffles store their data. @@ -173,6 +269,13 @@ public ManagedBuffer getBlockData( int reduceId) { ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId)); if (executor == null) { + logger.info("application's shuffle data isn't in main file system, checking backups..." + + "app id: {}, executor id: {}, shuffle id: {}, map id: {}, reduce id: {}", + appId, execId, shuffleId, mapId, reduceId); + executor = backupExecutors.get(new AppExecId(appId, execId)); + } + if (executor == null) { + logger.warn("Executor is not registered (appId: {}, execId: {}", appId, execId); throw new RuntimeException( String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId)); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index e49e27ab5aa79..531a6c146e6a7 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -34,7 +34,9 @@ import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.shuffle.protocol.RegisterDriver; import org.apache.spark.network.shuffle.protocol.RegisterExecutor; +import org.apache.spark.network.shuffle.protocol.RegisterExecutorForBackupsOnly; import org.apache.spark.network.util.TransportConf; /** @@ -43,7 +45,7 @@ * BlockTransferService), which has the downside of losing the shuffle data if we lose the * executors. 
*/ -public class ExternalShuffleClient extends ShuffleClient { +public class ExternalShuffleClient extends ShuffleClient{ private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleClient.class); private final TransportConf conf; @@ -90,6 +92,7 @@ public void fetchBlocks( int port, String execId, String[] blockIds, + boolean isBackup, BlockFetchingListener listener, DownloadFileManager downloadFileManager) { checkInit(); @@ -145,6 +148,19 @@ public void registerWithShuffleServer( } } + public void registerWithShuffleServerForBackups( + String host, + int port, + String execId, + String shuffleManager) throws IOException, InterruptedException{ + checkInit(); + try (TransportClient client = clientFactory.createUnmanagedClient(host, port)) { + ByteBuffer registerMessage = new RegisterExecutorForBackupsOnly( + appId, execId, shuffleManager).toByteBuffer(); + client.sendRpcSync(registerMessage, registrationTimeoutMs); + } + } + @Override public void close() { checkInit(); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java new file mode 100644 index 0000000000000..56ec8283ce735 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java @@ -0,0 +1,149 @@ +package org.apache.spark.network.shuffle; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; + +import org.apache.spark.network.client.StreamCallbackWithID; + +final class FileWriterStreamCallback implements StreamCallbackWithID { + + private static final Logger logger = LoggerFactory.getLogger(FileWriterStreamCallback.class); + + public enum BackupFileType { + DATA("shuffle-data"), + INDEX("shuffle-index"); + + private final String typeString; + + BackupFileType(String typeString) { + this.typeString = typeString; + } + + @Override + public String toString() { + return typeString; + } + } + private final ExternalShuffleBlockResolver.AppExecId fullExecId; + private final int shuffleId; + private final int mapId; + private final File file; + private final BackupFileType fileType; + private WritableByteChannel fileOutputChannel = null; + + FileWriterStreamCallback( + ExternalShuffleBlockResolver.AppExecId fullExecId, + int shuffleId, + int mapId, + File file, + BackupFileType fileType) { + this.fullExecId = fullExecId; + this.shuffleId = shuffleId; + this.mapId = mapId; + this.file = file; + this.fileType = fileType; + } + + public void open() { + logger.info( + "Opening {} for backup writing. 
File type: {}", file.getAbsolutePath(), fileType); + if (fileOutputChannel != null) { + throw new IllegalStateException( + String.format( + "File %s for is already open for writing (type: %s).", + file.getAbsolutePath(), + fileType)); + } + if (!file.exists()) { + try { + if (!file.getParentFile().isDirectory() && !file.getParentFile().mkdirs()) { + throw new IOException( + String.format( + "Failed to create shuffle file directory at" + + file.getParentFile().getAbsolutePath() + "(type: %s).", fileType)); + } + + if (!file.createNewFile()) { + throw new IOException( + String.format( + "Failed to create shuffle file (type: %s).", fileType)); + } + } catch (IOException e) { + throw new RuntimeException( + String.format( + "Failed to create shuffle file at %s for backup (type: %s).", + file.getAbsolutePath(), + fileType), + e); + } + } + try { + // TODO encryption + fileOutputChannel = Channels.newChannel(new FileOutputStream(file)); + } catch (FileNotFoundException e) { + throw new RuntimeException( + String.format( + "Failed to find file for writing at %s (type: %s).", + file.getAbsolutePath(), + fileType), + e); + } + } + + @Override + public String getID() { + return String.format("%s-%s-%d-%d-%s", + fullExecId.appId, + fullExecId.execId, + shuffleId, + mapId, + fileType); + } + + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + verifyShuffleFileOpenForWriting(); + while (buf.hasRemaining()) { + fileOutputChannel.write(buf); + } + } + + @Override + public void onComplete(String streamId) throws IOException { + fileOutputChannel.close(); + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + logger.warn("Failed to back up shuffle file at {} (type: %s).", + file.getAbsolutePath(), + fileType, + cause); + fileOutputChannel.close(); + // TODO delete parent dirs too + if (!file.delete()) { + logger.warn( + "Failed to delete incomplete backup shuffle file at %s (type: %s)", + file.getAbsolutePath(), + fileType); + } + } + + private void verifyShuffleFileOpenForWriting() { + if (fileOutputChannel == null) { + throw new RuntimeException( + String.format( + "Shuffle file at %s not open for writing (type: %s).", + file.getAbsolutePath(), + fileType)); + } + } +} \ No newline at end of file diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java index 62b99c40f61f9..5263e38d32a8e 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -53,6 +53,7 @@ public abstract void fetchBlocks( int port, String execId, String[] blockIds, + boolean isBackup, BlockFetchingListener listener, DownloadFileManager downloadFileManager); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java index 60179f126bc44..8d3d86698974f 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java @@ -24,7 +24,6 @@ import java.util.concurrent.TimeUnit; import com.google.common.util.concurrent.ThreadFactoryBuilder; -import 
org.apache.spark.network.shuffle.protocol.mesos.ShuffleServiceHeartbeat; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -32,7 +31,8 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.shuffle.ExternalShuffleClient; -import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver; +import org.apache.spark.network.shuffle.protocol.RegisterDriver; +import org.apache.spark.network.shuffle.protocol.ExternalServiceHeartbeat; import org.apache.spark.network.util.TransportConf; /** @@ -117,7 +117,7 @@ private Heartbeater(TransportClient client) { @Override public void run() { // TODO: Stop sending heartbeats if the shuffle service has lost the app due to timeout - client.send(new ShuffleServiceHeartbeat(appId).toByteBuffer()); + client.send(new ExternalServiceHeartbeat(appId).toByteBuffer()); } } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index a68a297519b66..dfdbde0eac1c6 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -23,8 +23,6 @@ import io.netty.buffer.Unpooled; import org.apache.spark.network.protocol.Encodable; -import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver; -import org.apache.spark.network.shuffle.protocol.mesos.ShuffleServiceHeartbeat; /** * Messages handled by the {@link org.apache.spark.network.shuffle.ExternalShuffleBlockHandler}, or @@ -42,7 +40,8 @@ public abstract class BlockTransferMessage implements Encodable { /** Preceding every serialized message is its type, which allows us to deserialize it. */ public enum Type { OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4), - HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6); + HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), UPLOAD_SHUFFLE_FILE_STREAM(7), UPLOAD_SHUFFLE_INDEX_STREAM(8), + REGISTER_EXECUTOR_FOR_BACKUPS(9); private final byte id; @@ -66,8 +65,11 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) { case 2: return RegisterExecutor.decode(buf); case 3: return StreamHandle.decode(buf); case 4: return RegisterDriver.decode(buf); - case 5: return ShuffleServiceHeartbeat.decode(buf); + case 5: return ExternalServiceHeartbeat.decode(buf); case 6: return UploadBlockStream.decode(buf); + case 7: return UploadShuffleFileStream.decode(buf); + case 8: return UploadShuffleIndexFileStream.decode(buf); + case 9: return RegisterExecutorForBackupsOnly.decode(buf); default: throw new IllegalArgumentException("Unknown message type: " + type); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExternalServiceHeartbeat.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExternalServiceHeartbeat.java new file mode 100644 index 0000000000000..d4403e8b94aa9 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExternalServiceHeartbeat.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import io.netty.buffer.ByteBuf; +import org.apache.spark.network.protocol.Encoders; + +/** + * A heartbeat sent from the driver to some ExternalService + */ +public class ExternalServiceHeartbeat extends BlockTransferMessage { + private final String appId; + + public ExternalServiceHeartbeat(String appId) { + this.appId = appId; + } + + public String getAppId() { return appId; } + + @Override + protected Type type() { return Type.HEARTBEAT; } + + @Override + public int encodedLength() { return Encoders.Strings.encodedLength(appId); } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + } + + public static ExternalServiceHeartbeat decode(ByteBuf buf) { + return new ExternalServiceHeartbeat(Encoders.Strings.decode(buf)); + } +} \ No newline at end of file diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterDriver.java similarity index 50% rename from common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterDriver.java index d5f53ccb7f741..1a3918a9b36aa 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterDriver.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.network.shuffle.protocol.mesos; +package org.apache.spark.network.shuffle.protocol; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; @@ -30,48 +30,48 @@ * A message sent from the driver to register with the MesosExternalShuffleService. 
*/ public class RegisterDriver extends BlockTransferMessage { - private final String appId; - private final long heartbeatTimeoutMs; + private final String appId; + private final long heartbeatTimeoutMs; - public RegisterDriver(String appId, long heartbeatTimeoutMs) { - this.appId = appId; - this.heartbeatTimeoutMs = heartbeatTimeoutMs; - } + public RegisterDriver(String appId, long heartbeatTimeoutMs) { + this.appId = appId; + this.heartbeatTimeoutMs = heartbeatTimeoutMs; + } - public String getAppId() { return appId; } + public String getAppId() { return appId; } - public long getHeartbeatTimeoutMs() { return heartbeatTimeoutMs; } + public long getHeartbeatTimeoutMs() { return heartbeatTimeoutMs; } - @Override - protected Type type() { return Type.REGISTER_DRIVER; } + @Override + protected Type type() { return Type.REGISTER_DRIVER; } - @Override - public int encodedLength() { - return Encoders.Strings.encodedLength(appId) + Long.SIZE / Byte.SIZE; - } + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + Long.SIZE / Byte.SIZE; + } - @Override - public void encode(ByteBuf buf) { - Encoders.Strings.encode(buf, appId); - buf.writeLong(heartbeatTimeoutMs); - } + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + buf.writeLong(heartbeatTimeoutMs); + } - @Override - public int hashCode() { - return Objects.hashCode(appId, heartbeatTimeoutMs); - } + @Override + public int hashCode() { + return Objects.hashCode(appId, heartbeatTimeoutMs); + } - @Override - public boolean equals(Object o) { - if (!(o instanceof RegisterDriver)) { - return false; + @Override + public boolean equals(Object o) { + if (!(o instanceof RegisterDriver)) { + return false; + } + return Objects.equal(appId, ((RegisterDriver) o).appId); } - return Objects.equal(appId, ((RegisterDriver) o).appId); - } - public static RegisterDriver decode(ByteBuf buf) { - String appId = Encoders.Strings.decode(buf); - long heartbeatTimeout = buf.readLong(); - return new RegisterDriver(appId, heartbeatTimeout); - } + public static RegisterDriver decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + long heartbeatTimeout = buf.readLong(); + return new RegisterDriver(appId, heartbeatTimeout); + } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutorForBackupsOnly.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutorForBackupsOnly.java new file mode 100644 index 0000000000000..3986f9519f81d --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutorForBackupsOnly.java @@ -0,0 +1,71 @@ +package org.apache.spark.network.shuffle.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +public class RegisterExecutorForBackupsOnly extends BlockTransferMessage { + + public final String appId; + public final String execId; + public final String shuffleManager; + + public RegisterExecutorForBackupsOnly( + String appId, String execId, String shuffleManager) { + this.appId = appId; + this.execId = execId; + this.shuffleManager = shuffleManager; + } + + @Override + protected Type type() { + return Type.REGISTER_EXECUTOR_FOR_BACKUPS; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + Encoders.Strings.encodedLength(shuffleManager); + } + 
+ @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + Encoders.Strings.encode(buf, shuffleManager); + } + + @Override + public boolean equals(Object other) { + if (other instanceof RegisterExecutorForBackupsOnly) { + RegisterExecutorForBackupsOnly o = (RegisterExecutorForBackupsOnly) other; + return Objects.equal(appId, o.appId) + && Objects.equal(execId, o.execId) + && Objects.equal(shuffleManager, o.shuffleManager); + } + return false; + } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId, shuffleManager); + } + + @Override + public String toString() { + return Objects.toStringHelper(RegisterExecutorForBackupsOnly.class) + .add("appId", appId) + .add("execId", execId) + .add("shuffleManager", shuffleManager) + .toString(); + } + + public static RegisterExecutorForBackupsOnly decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + String shuffleManager = Encoders.Strings.decode(buf); + return new RegisterExecutorForBackupsOnly(appId, execId, shuffleManager); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleFileStream.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleFileStream.java new file mode 100644 index 0000000000000..409a00c1d89ac --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleFileStream.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +public class UploadShuffleFileStream extends BlockTransferMessage { + public final String appId; + public final String execId; + public final int shuffleId; + public final int mapId; + + public UploadShuffleFileStream( + String appId, + String execId, + int shuffleId, + int mapId) { + this.appId = appId; + this.execId = execId; + this.shuffleId = shuffleId; + this.mapId = mapId; + } + + @Override + protected Type type() { + return Type.UPLOAD_SHUFFLE_FILE_STREAM; + } + + @Override + public int hashCode() { + return Objects.hashCode( + appId, + execId, + shuffleId, + mapId); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("shuffleId", shuffleId) + .add("mapId", mapId) + .toString(); + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + 8; + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + buf.writeInt(shuffleId); + buf.writeInt(mapId); + } + + public static UploadShuffleFileStream decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + int shuffleId = buf.readInt(); + int mapId = buf.readInt(); + return new UploadShuffleFileStream(appId, execId, shuffleId, mapId); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndexFileStream.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndexFileStream.java new file mode 100644 index 0000000000000..0bd6301517716 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndexFileStream.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +public class UploadShuffleIndexFileStream extends BlockTransferMessage { + public final String appId; + public final String execId; + public final int shuffleId; + public final int mapId; + + public UploadShuffleIndexFileStream( + String appId, + String execId, + int shuffleId, + int mapId) { + this.appId = appId; + this.execId = execId; + this.shuffleId = shuffleId; + this.mapId = mapId; + } + + @Override + protected Type type() { + return Type.UPLOAD_SHUFFLE_INDEX_STREAM; + } + + @Override + public int hashCode() { + return Objects.hashCode( + appId, + execId, + shuffleId, + mapId); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("shuffleId", shuffleId) + .add("mapId", mapId) + .toString(); + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + 8; + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + buf.writeInt(shuffleId); + buf.writeInt(mapId); + } + + public static UploadShuffleIndexFileStream decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + int shuffleId = buf.readInt(); + int mapId = buf.readInt(); + return new UploadShuffleIndexFileStream(appId, execId, shuffleId, mapId); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java deleted file mode 100644 index b30bb9aed55b6..0000000000000 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle.protocol.mesos; - -import io.netty.buffer.ByteBuf; -import org.apache.spark.network.protocol.Encoders; -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; - -// Needed by ScalaDoc. See SPARK-7726 -import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; - -/** - * A heartbeat sent from the driver to the MesosExternalShuffleService. 
- */ -public class ShuffleServiceHeartbeat extends BlockTransferMessage { - private final String appId; - - public ShuffleServiceHeartbeat(String appId) { - this.appId = appId; - } - - public String getAppId() { return appId; } - - @Override - protected Type type() { return Type.HEARTBEAT; } - - @Override - public int encodedLength() { return Encoders.Strings.encodedLength(appId); } - - @Override - public void encode(ByteBuf buf) { - Encoders.Strings.encode(buf, appId); - } - - public static ShuffleServiceHeartbeat decode(ByteBuf buf) { - return new ShuffleServiceHeartbeat(Encoders.Strings.decode(buf)); - } -} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 526b96b364473..3e7bc8107924e 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -135,7 +135,7 @@ private FetchResult fetchBlocks( try (ExternalShuffleClient client = new ExternalShuffleClient(clientConf, null, false, 5000)) { client.init(APP_ID); - client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, + client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, false, new BlockFetchingListener() { @Override public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 1c4fa4bc6541f..a33c7cadea5fb 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -22,6 +22,7 @@ import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolE import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.collection.mutable.{HashMap, HashSet, ListBuffer, Map} import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration @@ -33,7 +34,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.MapStatus -import org.apache.spark.shuffle.MetadataFetchFailedException +import org.apache.spark.shuffle.{MetadataFetchFailedException, ShuffleServiceAddressProvider} import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util._ @@ -210,12 +211,18 @@ private class ShuffleStatus(numPartitions: Int) { } private[spark] sealed trait MapOutputTrackerMessage -private[spark] case class GetMapOutputStatuses(shuffleId: Int) +private[spark] case class GetMapOutputStatuses(shuffleId: Int, getBackup: Boolean) + extends MapOutputTrackerMessage +private[spark] case class ReportBackedUpMapOutput( + shuffleId: Int, mapId: Int, backedUpStatus: MapStatus) extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage +private[spark] case object GetBackupShuffleServiceAddresses extends MapOutputTrackerMessage +private[spark] sealed trait BackupMessage +private[spark] case class HeartbeaterMessage(appId: String) extends BackupMessage private[spark] case class GetMapOutputMessage(shuffleId: Int, context: 
RpcCallContext) - + extends BackupMessage /** RpcEndpoint class for MapOutputTrackerMaster */ private[spark] class MapOutputTrackerMasterEndpoint( override val rpcEnv: RpcEnv, tracker: MapOutputTrackerMaster, conf: SparkConf) @@ -224,16 +231,29 @@ private[spark] class MapOutputTrackerMasterEndpoint( logDebug("init") // force eager creation of logger override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case GetMapOutputStatuses(shuffleId: Int) => + case GetMapOutputStatuses(shuffleId: Int, getBackup: Boolean) => val hostPort = context.senderAddress.hostPort logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) - val mapOutputStatuses = tracker.post(new GetMapOutputMessage(shuffleId, context)) + val message = GetMapOutputMessage(shuffleId, context) + if (getBackup) { + tracker.postToBackup[GetMapOutputMessage](message) + } else { + tracker.post[GetMapOutputMessage](message) + } + + case GetBackupShuffleServiceAddresses => + context.reply(tracker.getBackupShuffleServiceAddresses) case StopMapOutputTracker => logInfo("MapOutputTrackerMasterEndpoint stopped!") context.reply(true) stop() } + + override def receive(): PartialFunction[Any, Unit] = { + case ReportBackedUpMapOutput(shuffleId, mapId, backedUpStatus) => + tracker.registerBackupMapOutput(shuffleId, mapId, backedUpStatus) + } } /** @@ -283,7 +303,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging // For testing def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { - getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1) + getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1, false) } /** @@ -295,7 +315,8 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * and the second item is a sequence of (shuffle block id, shuffle block size) tuples * describing the shuffle blocks that are stored at that block manager. 
*/ - def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) + def getMapSizesByExecutorId( + shuffleId: Int, startPartition: Int, endPartition: Int, getBackup: Boolean) : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] /** @@ -318,9 +339,20 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging private[spark] class MapOutputTrackerMaster( conf: SparkConf, broadcastManager: BroadcastManager, - isLocal: Boolean) + isLocal: Boolean, + backupMaster: Option[MapOutputTrackerMaster], + shuffleServiceAddressProvider: ShuffleServiceAddressProvider) extends MapOutputTracker(conf) { + def this( + conf: SparkConf, + broadcastManager: BroadcastManager, + isLocal: Boolean, + shuffleServiceAddressProvider: ShuffleServiceAddressProvider) = this( + conf, broadcastManager, isLocal, Some(new MapOutputTrackerMaster( + conf, broadcastManager, isLocal, None, shuffleServiceAddressProvider)), + shuffleServiceAddressProvider) + // The size at which we use Broadcast to send the map output statuses to the executors private val minSizeForBroadcast = conf.getSizeAsBytes("spark.shuffle.mapOutput.minSizeForBroadcast", "512k").toInt @@ -348,7 +380,7 @@ private[spark] class MapOutputTrackerMaster( private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) // requests for map output statuses - private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage] + private val mapOutputRequests = new LinkedBlockingQueue[BackupMessage] // Thread pool used for handling map output status requests. This is a separate thread pool // to ensure we don't block the normal dispatcher threads. @@ -371,10 +403,15 @@ private[spark] class MapOutputTrackerMaster( throw new IllegalArgumentException(msg) } - def post(message: GetMapOutputMessage): Unit = { + def post[T <: BackupMessage](message: T): Unit = { mapOutputRequests.offer(message) } + def postToBackup[T <: BackupMessage](message: T): Unit = { + require(backupMaster.isDefined, "No backup master available") + backupMaster.foreach(_.post(message)) + } + /** Message loop used for dispatching messages. */ private class MessageLoop extends Runnable { override def run(): Unit = { @@ -382,19 +419,20 @@ private[spark] class MapOutputTrackerMaster( while (true) { try { val data = mapOutputRequests.take() - if (data == PoisonPill) { - // Put PoisonPill back so that other MessageLoops can see it. - mapOutputRequests.offer(PoisonPill) - return + data match { + case PoisonPill => + // Put PoisonPill back so that other MessageLoops can see it. 
+ mapOutputRequests.offer(PoisonPill) + return + case GetMapOutputMessage(shuffleId, context) => + val hostPort = context.senderAddress.hostPort + // TODO: Change back to debug + logDebug("Handling request to send map output locations for shuffle " + shuffleId + + " to " + hostPort) + val shuffleStatus = shuffleStatuses.get(shuffleId).head + context.reply( + shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast)) } - val context = data.context - val shuffleId = data.shuffleId - val hostPort = context.senderAddress.hostPort - logDebug("Handling request to send map output locations for shuffle " + shuffleId + - " to " + hostPort) - val shuffleStatus = shuffleStatuses.get(shuffleId).head - context.reply( - shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast)) } catch { case NonFatal(e) => logError(e.getMessage, e) } @@ -417,12 +455,17 @@ private[spark] class MapOutputTrackerMaster( if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps)).isDefined) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } + backupMaster.foreach(_.registerShuffle(shuffleId, numMaps)) } def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { shuffleStatuses(shuffleId).addMapOutput(mapId, status) } + def registerBackupMapOutput(shuffleId: Int, mapId: Int, status: MapStatus): Unit = { + backupMaster.foreach(_.registerMapOutput(shuffleId, mapId, status)) + } + /** Unregister map output information of the given shuffle, mapper and block manager */ def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { shuffleStatuses.get(shuffleId) match { @@ -644,18 +687,27 @@ private[spark] class MapOutputTrackerMaster( } } + def getBackupShuffleServiceAddresses: List[(String, Int)] = + shuffleServiceAddressProvider.getShuffleServiceAddresses() + // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. // This method is only called in local-mode. - def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) + override def getMapSizesByExecutorId( + shuffleId: Int, startPartition: Int, endPartition: Int, getBackup: Boolean) : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { - logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") - shuffleStatuses.get(shuffleId) match { - case Some (shuffleStatus) => - shuffleStatus.withMapStatuses { statuses => - MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) - } - case None => - Iterator.empty + if (getBackup) { + require(backupMaster.isDefined, "Backup master not defined") + backupMaster.get.getMapSizesByExecutorId(shuffleId, startPartition, endPartition, false) + } else { + logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") + shuffleStatuses.get(shuffleId) match { + case Some(shuffleStatus) => + shuffleStatus.withMapStatuses { statuses => + MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) + } + case None => + Iterator.empty + } } } @@ -671,7 +723,7 @@ private[spark] class MapOutputTrackerMaster( /** * Executor-side client for fetching map output info from the driver's MapOutputTrackerMaster. 
* Note that this is not used in local-mode; instead, local-mode Executors access the - * MapOutputTrackerMaster directly (which is possible because the master and worker share a comon + * MapOutputTrackerMaster directly (which is possible because the master and worker share a common * superclass). */ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) { @@ -679,41 +731,58 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr val mapStatuses: Map[Int, Array[MapStatus]] = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala + val backupMapStatuses: Map[Int, Array[MapStatus]] = + new ConcurrentHashMap[Int, Array[MapStatus]]().asScala + /** Remembers which map output locations are currently being fetched on an executor. */ private val fetching = new HashSet[Int] + /** Remembers which backup map output locations are currently being fetched on an executor. */ + private val fetchingBackup = new HashSet[Int] + // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. - override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) + override def getMapSizesByExecutorId( + shuffleId: Int, startPartition: Int, endPartition: Int, getBackup: Boolean) : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") - val statuses = getStatuses(shuffleId) + val statuses = if (getBackup) { + getStatuses(shuffleId, backupMapStatuses, fetchingBackup, getBackup) + } else { + getStatuses(shuffleId, mapStatuses, fetching, getBackup) + } try { MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) } catch { case e: MetadataFetchFailedException => // We experienced a fetch failure so our mapStatuses cache is outdated; clear it: mapStatuses.clear() + backupMapStatuses.clear() throw e } } /** - * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize - * on this array when reading it, because on the driver, we may be changing it in place. - * - * (It would be nice to remove this restriction in the future.) - */ - private def getStatuses(shuffleId: Int): Array[MapStatus] = { - val statuses = mapStatuses.get(shuffleId).orNull + * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize + * on this array when reading it, because on the driver, we may be changing it in place. + * + * (It would be nice to remove this restriction in the future.) 
+ */ + private def getStatuses( + shuffleId: Int, + statusesToInspect: Map[Int, Array[MapStatus]], + statusesBeingFetched: mutable.HashSet[Int], + getBackup: Boolean) + : Array[MapStatus] = { + val statuses = statusesToInspect.get(shuffleId).orNull if (statuses == null) { logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") val startTime = System.currentTimeMillis var fetchedStatuses: Array[MapStatus] = null - fetching.synchronized { + statusesBeingFetched.synchronized { // Someone else is fetching it; wait for them to be done - while (fetching.contains(shuffleId)) { + while (statusesBeingFetched.contains(shuffleId)) { try { - fetching.wait() + statusesBeingFetched.wait() } catch { case e: InterruptedException => } @@ -721,10 +790,10 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr // Either while we waited the fetch happened successfully, or // someone fetched it in between the get and the fetching.synchronized. - fetchedStatuses = mapStatuses.get(shuffleId).orNull + fetchedStatuses = statusesToInspect.get(shuffleId).orNull if (fetchedStatuses == null) { // We have to do the fetch, get others to wait for us. - fetching += shuffleId + statusesBeingFetched += shuffleId } } @@ -733,14 +802,14 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) // This try-finally prevents hangs due to timeouts: try { - val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) + val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId, getBackup)) fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) logInfo("Got the output locations") - mapStatuses.put(shuffleId, fetchedStatuses) + statusesToInspect.put(shuffleId, fetchedStatuses) } finally { - fetching.synchronized { - fetching -= shuffleId - fetching.notifyAll() + statusesBeingFetched.synchronized { + statusesBeingFetched -= shuffleId + statusesBeingFetched.notifyAll() } } } @@ -759,7 +828,6 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr } } - /** Unregister shuffle data. 
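Note on the tracker changes above: the driver-side MapOutputTrackerMaster now nests a second, backup-only master (constructed with backupMaster = None, so the recursion stops after one level), while the executor-side MapOutputTrackerWorker keeps a separate backupMapStatuses cache guarded by its own fetchingBackup set. Callers choose between the two views with the new getBackup flag on getMapSizesByExecutorId. A minimal caller-side sketch of that call shape, assuming the wiring in this patch (the attempt-number heuristic is the one used by BlockStoreShuffleReader further down; the helper name is invented for illustration, and MapOutputTracker is a spark-internal type, so this only shows the shape of the call):

    import org.apache.spark.MapOutputTracker

    // Retried reduce attempts read the backup view; first attempts read the primary one.
    def mapSizesForAttempt(
        tracker: MapOutputTracker,
        shuffleId: Int,
        startPartition: Int,
        endPartition: Int,
        attemptNumber: Int) = {
      val useBackup = attemptNumber > 0
      tracker.getMapSizesByExecutorId(shuffleId, startPartition, endPartition, useBackup)
    }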
*/ def unregisterShuffle(shuffleId: Int): Unit = { mapStatuses.remove(shuffleId) @@ -835,7 +903,6 @@ private[spark] object MapOutputTracker extends Logging { objIn.close() } } - bytes(0) match { case DIRECT => deserializeObject(bytes, 1, bytes.length - 1).asInstanceOf[Array[MapStatus]] diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 66038eeaea54f..50a7c80d73474 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -19,13 +19,13 @@ package org.apache.spark import java.io.File import java.net.Socket -import java.util.Locale +import java.util.{Locale, ServiceLoader} +import com.google.common.collect.MapMaker +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Properties -import com.google.common.collect.MapMaker - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager @@ -39,7 +39,7 @@ import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator} import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager} -import org.apache.spark.shuffle.ShuffleManager +import org.apache.spark.shuffle.{DefaultShuffleServiceAddressProvider, ShuffleManager, ShuffleServiceAddressProviderFactory} import org.apache.spark.storage._ import org.apache.spark.util.{RpcUtils, Utils} @@ -302,7 +302,20 @@ object SparkEnv extends Logging { val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) val mapOutputTracker = if (isDriver) { - new MapOutputTrackerMaster(conf, broadcastManager, isLocal) + val loader = Utils.getContextOrSparkClassLoader + val master = conf.get("spark.master") + val serviceLoaders = + ServiceLoader.load(classOf[ShuffleServiceAddressProviderFactory], loader) + .asScala.filter(_.canCreate(conf.get("spark.master"))) + if (serviceLoaders.size > 1) { + throw new SparkException( + s"Multiple external cluster managers registered for the url $master: $serviceLoaders") + } + val shuffleServiceAddressProvider = serviceLoaders.headOption + .map(_.create(conf)) + .getOrElse(DefaultShuffleServiceAddressProvider) + shuffleServiceAddressProvider.start() + new MapOutputTrackerMaster(conf, broadcastManager, isLocal, shuffleServiceAddressProvider) } else { new MapOutputTrackerWorker(conf) } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index d34601358d896..d3fae7e4eb4f7 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -103,6 +103,11 @@ package object config { private[spark] val EXECUTOR_HEARTBEAT_MAX_FAILURES = ConfigBuilder("spark.executor.heartbeat.maxFailures").internal().intConf.createWithDefault(60) + private[spark] val SHUFFLE_BACKUP_HEARTBEAT_INTERVAL = + ConfigBuilder("spark.driver.externalShuffleBackup.heartbeatInterval") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("10s") + private[spark] val EXECUTOR_JAVA_OPTIONS = ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_JAVA_OPTIONS).stringConf.createOptional diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala 
b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index a58c8fa2e763f..ac59119a36230 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -67,6 +67,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo port: Int, execId: String, blockIds: Array[String], + isBackup: Boolean, listener: BlockFetchingListener, tempFileManager: DownloadFileManager): Unit @@ -92,10 +93,11 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo port: Int, execId: String, blockId: String, + isBackup: Boolean, tempFileManager: DownloadFileManager): ManagedBuffer = { // A monitor for the thread to wait on. val result = Promise[ManagedBuffer]() - fetchBlocks(host, port, execId, Array(blockId), + fetchBlocks(host, port, execId, Array(blockId), isBackup, new BlockFetchingListener { override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { result.failure(exception) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index dc55685b1e7bd..cc68d825fc169 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -105,6 +105,7 @@ private[spark] class NettyBlockTransferService( port: Int, execId: String, blockIds: Array[String], + isBackup: Boolean, listener: BlockFetchingListener, tempFileManager: DownloadFileManager): Unit = { logTrace(s"Fetch blocks from $host:$port (executor id $execId)") diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 64f0a060a247c..6ae8554e7bad5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -243,3 +243,31 @@ private[spark] object HighlyCompressedMapStatus { hugeBlockSizes) } } + +private[spark] class RelocatedMapStatus private( + private[this] var original: MapStatus, + private[this] var newLocation: BlockManagerId) + extends MapStatus with Externalizable { + + protected def this() = this(null, null) + + override def location: BlockManagerId = newLocation + + override def getSizeForBlock(reduceId: Int): Long = original.getSizeForBlock(reduceId) + + override def writeExternal(out: ObjectOutput): Unit = { + out.writeObject(original) + out.writeObject(newLocation) + } + + override def readExternal(in: ObjectInput): Unit = { + this.original = in.readObject().asInstanceOf[MapStatus] + this.newLocation = in.readObject().asInstanceOf[BlockManagerId] + } +} + +private[spark] object RelocatedMapStatus { + def apply(original: MapStatus, newLocation: BlockManagerId): MapStatus = { + new RelocatedMapStatus(original, newLocation) + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/BackingUpShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/BackingUpShuffleWriter.scala new file mode 100644 index 0000000000000..4b8427d78b1f1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/BackingUpShuffleWriter.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import java.io.File +import java.nio.ByteBuffer +import java.util.concurrent.ExecutorService + +import com.google.common.util.concurrent.SettableFuture +import scala.concurrent.{ExecutionContext, Future} +import scala.util.{Failure, Success} + +import org.apache.spark.{MapOutputTracker, ReportBackedUpMapOutput} +import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} +import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, UploadShuffleFileStream, UploadShuffleIndexFileStream} +import org.apache.spark.network.util.TransportConf +import org.apache.spark.scheduler.{MapStatus, RelocatedMapStatus} +import org.apache.spark.storage.BlockManagerId + +class BackingUpShuffleWriter[K, V]( + shuffleBlockResolver: IndexShuffleBlockResolver, + delegateWriter: ShuffleWriter[K, V], + backupShuffleServiceClient: TransportClient, + transportConf: TransportConf, + mapOutputTracker: MapOutputTracker, + backupExecutor: ExecutorService, + backupHost: String, + backupPort: Int, + appId: String, + execId: String, + shuffleId: Int, + mapId: Int) + extends ShuffleWriter[K, V] with Logging { + + private implicit val backupExecutorContext = ExecutionContext.fromExecutorService(backupExecutor) + + /** Write a sequence of records to this task's output */ + override def write(records: Iterator[Product2[K, V]]): Unit = { + delegateWriter.write(records) + } + + /** Close this writer, passing along whether the map completed */ + override def stop(success: Boolean): Option[MapStatus] = { + val delegateMapStatus = delegateWriter.stop(success) + delegateMapStatus.foreach { _ => + val outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId) + val indexFile = shuffleBlockResolver.getIndexFile(shuffleId, mapId) + if (outputFile.isFile && indexFile.isFile) { + val uploadBackupFileRequest = new UploadShuffleFileStream( + appId, execId, shuffleId, mapId) + val uploadIndexFileRequest = new UploadShuffleIndexFileStream( + appId, execId, shuffleId, mapId) + + val backupFileTask: Future[Unit] = Future { + backupFile(outputFile, uploadBackupFileRequest) + backupFile(indexFile, uploadIndexFileRequest) + } + + backupFileTask.onComplete { + case Success(_) => + val backedUpMapStatus = RelocatedMapStatus( + delegateMapStatus.get, + BlockManagerId(execId, backupHost, backupPort, None, isBackup = true)) + mapOutputTracker.trackerEndpoint.send( + ReportBackedUpMapOutput(shuffleId, mapId, backedUpMapStatus)) + case Failure(_) => logError("An error has occured in backing up") + } + } + } + delegateMapStatus + } + + private def backupFile( + fileToBackUp: File, + backupFileRequest: BlockTransferMessage) { + val dataFileBuffer = new 
FileSegmentManagedBuffer( + transportConf, fileToBackUp, 0, fileToBackUp.length()) + val uploadBackupRequestBuffer = new NioManagedBuffer(backupFileRequest.toByteBuffer) + val awaitCompletion = SettableFuture.create[Boolean] + backupShuffleServiceClient.uploadStream( + uploadBackupRequestBuffer, dataFileBuffer, new RpcResponseCallback { + override def onSuccess(response: ByteBuffer): Unit = { + logInfo("Successfully backed up shuffle map data file" + + s" (shuffle id: $shuffleId, map id: $mapId, executor id: $execId)") + awaitCompletion.set(true) + } + + /** Exception either propagated from server or raised on client side. */ + override def onFailure(e: Throwable): Unit = { + logError("Failed to back up shuffle map data file" + + s" (shuffle id: $shuffleId, map id: $mapId, executor id: $execId)") + awaitCompletion.setException(e) + } + }) + awaitCompletion.get() + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 74b0e0b3a741a..b8c20cdceade9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -46,7 +46,8 @@ private[spark] class BlockStoreShuffleReader[K, C]( context, blockManager.shuffleClient, blockManager, - mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition), + mapOutputTracker.getMapSizesByExecutorId( + handle.shuffleId, startPartition, endPartition, context.attemptNumber() > 0), serializerManager.wrapStream, // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, diff --git a/core/src/main/scala/org/apache/spark/shuffle/ExternalFallbackShuffleClient.scala b/core/src/main/scala/org/apache/spark/shuffle/ExternalFallbackShuffleClient.scala new file mode 100644 index 0000000000000..1b0a6941b99d0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ExternalFallbackShuffleClient.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
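Note on the write path above: BackingUpShuffleWriter wraps the regular writer and, only after a successful stop(), uploads the map output's data and index files to the chosen backup shuffle service; when the upload future completes it reports a RelocatedMapStatus whose location points at the backup service. On the read side, BlockStoreShuffleReader only switches to the backup view when context.attemptNumber() > 0, i.e. after a first attempt has already failed and the task is re-run. A minimal sketch of the re-advertising step, assuming the types introduced in this patch (the argument values are placeholders):

    import org.apache.spark.scheduler.{MapStatus, RelocatedMapStatus}
    import org.apache.spark.storage.BlockManagerId

    // Block sizes still come from the original status; only the location is rewritten
    // to point at the backup shuffle service and flagged with isBackup = true.
    def relocateToBackup(
        original: MapStatus,
        execId: String,
        backupHost: String,
        backupPort: Int): MapStatus =
      RelocatedMapStatus(
        original,
        BlockManagerId(execId, backupHost, backupPort, None, isBackup = true))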
+ */ + +package org.apache.spark.shuffle + +import org.apache.spark.network.BlockTransferService +import org.apache.spark.network.shuffle._ + +private[spark] class ExternalFallbackShuffleClient( + externalShuffleClient: ExternalShuffleClient, + baseBlockTransferService: BlockTransferService) extends ShuffleClient { + + override def init(appId: String): Unit = { + externalShuffleClient.init(appId) + baseBlockTransferService.init(appId) + } + + override def fetchBlocks( + host: String, + port: Int, + execId: String, + blockIds: Array[String], + isBackup: Boolean, + listener: BlockFetchingListener, + downloadFileManager: DownloadFileManager): Unit = { + if (isBackup) { + externalShuffleClient.fetchBlocks( + host, + port, + execId, + blockIds, + isBackup, + listener, + downloadFileManager) + } else { + baseBlockTransferService.fetchBlocks( + host, port, execId, blockIds, isBackup, listener, downloadFileManager) + } + } + + override def close(): Unit = { + baseBlockTransferService.close() + externalShuffleClient.close() + } + + def registerWithShuffleServerForBackups( + host: String, + port: Int, + execId: String, + shuffleManager: String) : Unit = { + externalShuffleClient.registerWithShuffleServerForBackups(host, port, execId, shuffleManager) + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index d3f1c7ec1bbee..e5aad891541f6 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -55,7 +55,7 @@ private[spark] class IndexShuffleBlockResolver( blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) } - private def getIndexFile(shuffleId: Int, mapId: Int): File = { + def getIndexFile(shuffleId: Int, mapId: Int): File = { blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProvider.scala similarity index 60% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala rename to core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProvider.scala index 83daddf714489..eb86eca47e538 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProvider.scala @@ -14,24 +14,18 @@ * See the License for the specific language governing permissions and * limitations under the License. 
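Note on ExternalFallbackShuffleClient above: it multiplexes each fetch between the ExternalShuffleClient (for blocks living on a backup shuffle service) and the ordinary executor-to-executor BlockTransferService, keyed off the new per-request isBackup flag, which callers derive from BlockManagerId.isBackup. A sketch of the expected call shape, assuming the signatures in this patch (the helper name is invented; passing null for the download-file manager mirrors the small-request path in ShuffleBlockFetcherIterator, and these are spark-internal types):

    import org.apache.spark.network.shuffle.BlockFetchingListener
    import org.apache.spark.shuffle.ExternalFallbackShuffleClient
    import org.apache.spark.storage.BlockManagerId

    // Fetches addressed at a backup location go through the external shuffle client,
    // everything else through the regular block transfer service.
    def fetchFrom(
        client: ExternalFallbackShuffleClient,
        location: BlockManagerId,
        blockIds: Array[String],
        listener: BlockFetchingListener): Unit =
      client.fetchBlocks(
        location.host, location.port, location.executorId, blockIds,
        location.isBackup, listener, null)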
*/ -package org.apache.spark.scheduler.cluster.k8s -import io.fabric8.kubernetes.api.model.Pod +package org.apache.spark.shuffle -sealed trait ExecutorPodState { - def pod: Pod -} - -case class PodRunning(pod: Pod) extends ExecutorPodState - -case class PodPending(pod: Pod) extends ExecutorPodState +trait ShuffleServiceAddressProvider { -sealed trait FinalPodState extends ExecutorPodState + def start(): Unit = {} -case class PodSucceeded(pod: Pod) extends FinalPodState + def getShuffleServiceAddresses(): List[(String, Int)] -case class PodFailed(pod: Pod) extends FinalPodState - -case class PodDeleted(pod: Pod) extends FinalPodState + def stop(): Unit = {} +} -case class PodUnknown(pod: Pod) extends ExecutorPodState +private[spark] object DefaultShuffleServiceAddressProvider extends ShuffleServiceAddressProvider { + override def getShuffleServiceAddresses(): List[(String, Int)] = List.empty[(String, Int)] +} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProviderFactory.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProviderFactory.scala new file mode 100644 index 0000000000000..abfe7dc156b25 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProviderFactory.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
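The ShuffleServiceAddressProvider trait above, together with the ShuffleServiceAddressProviderFactory defined next, is the plug-in point that SparkEnv resolves through ServiceLoader (see the SparkEnv change earlier in this patch). A hypothetical third-party provider, not part of this patch, could look like the sketch below; it would be registered by listing the factory class in META-INF/services/org.apache.spark.shuffle.ShuffleServiceAddressProviderFactory, exactly as the Kubernetes module does later in this diff. The config key and class names here are invented for illustration:

    package org.apache.spark.shuffle.example

    import org.apache.spark.SparkConf
    import org.apache.spark.shuffle.{ShuffleServiceAddressProvider, ShuffleServiceAddressProviderFactory}

    // Serves a fixed list of host:port pairs taken from an illustrative config key.
    class StaticShuffleServiceAddressProvider(hosts: List[(String, Int)])
      extends ShuffleServiceAddressProvider {
      override def getShuffleServiceAddresses(): List[(String, Int)] = hosts
    }

    class StaticShuffleServiceAddressProviderFactory extends ShuffleServiceAddressProviderFactory {
      // Only offer this provider for non-Kubernetes masters in this example; SparkEnv
      // throws if more than one discovered factory claims the same master URL.
      override def canCreate(masterUrl: String): Boolean = !masterUrl.startsWith("k8s://")

      override def create(conf: SparkConf): ShuffleServiceAddressProvider = {
        val hosts = conf.get("spark.shuffle.backup.staticAddresses", "")
          .split(",").filter(_.nonEmpty).map { hostPort =>
            val Array(host, port) = hostPort.split(":")
            (host, port.toInt)
          }.toList
        new StaticShuffleServiceAddressProvider(hosts)
      }
    }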
+ */ +package org.apache.spark.shuffle + +import org.apache.spark.SparkConf + +trait ShuffleServiceAddressProviderFactory { + def canCreate(masterUrl: String): Boolean + + def create(conf: SparkConf): ShuffleServiceAddressProvider +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 0caf84c6050a8..f50f9b18e36e2 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -17,11 +17,18 @@ package org.apache.spark.shuffle.sort +import java.net.URI import java.util.concurrent.ConcurrentHashMap +import scala.util.Random + import org.apache.spark._ import org.apache.spark.internal.Logging +import org.apache.spark.network.TransportContext +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.network.server.NoOpRpcHandler import org.apache.spark.shuffle._ +import org.apache.spark.util.ThreadUtils /** * In sort-based shuffle, incoming records are sorted according to their target partition ids, then @@ -81,6 +88,22 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) + private val backupShuffleTransportConf = SparkTransportConf.fromSparkConf( + conf, "shuffle", 2) + + private lazy val backupShuffleTransportClients = SparkEnv + .get + .blockManager + .getBackupShuffleServiceAddresses() + .map(address => { + val addressAsUri = URI.create(s"spark://${address._1}:${address._2}") + val transportContext = new TransportContext( + backupShuffleTransportConf, + new NoOpRpcHandler(), + false) + (address, addressAsUri, transportContext.createClientFactory()) + }) + /** * Obtains a [[ShuffleHandle]] to pass to tasks. */ @@ -127,7 +150,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager numMapsForShuffle.putIfAbsent( handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps) val env = SparkEnv.get - handle match { + val baseWriter = handle match { case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => new UnsafeShuffleWriter( env.blockManager, @@ -148,6 +171,26 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => new SortShuffleWriter(shuffleBlockResolver, other, mapId, context) } + Random.shuffle(backupShuffleTransportClients) + .headOption + .map(addressAndClient => { + val transportClient = + addressAndClient._3.createClient( + addressAndClient._2.getHost, addressAndClient._2.getPort) + new BackingUpShuffleWriter( + shuffleBlockResolver, + baseWriter, + transportClient, + backupShuffleTransportConf, + env.mapOutputTracker, + ThreadUtils.newDaemonCachedThreadPool("backup-shuffle-files"), + addressAndClient._1._1, + addressAndClient._1._2, + conf.getAppId, + env.blockManager.blockManagerId.executorId, + handle.shuffleId, + mapId) + }).getOrElse(baseWriter) } /** Remove a shuffle's metadata from the ShuffleManager. */ @@ -163,6 +206,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager /** Shut down this ShuffleManager. */ override def stop(): Unit = { shuffleBlockResolver.stop() + backupShuffleTransportClients.foreach(_._3.close()) } } @@ -174,7 +218,7 @@ private[spark] object SortShuffleManager extends Logging { * buffering map outputs in a serialized form. 
This is an extreme defensive programming measure, * since it's extremely unlikely that a single shuffle produces over 16 million output partitions. * */ - val MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE = + val MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE: Int = PackedRecordPointer.MAXIMUM_PARTITION_ID + 1 /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index edae2f95fce33..2994911cdcef3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -24,6 +24,7 @@ import java.nio.channels.Channels import java.util.Collections import java.util.concurrent.ConcurrentHashMap +import com.codahale.metrics.{MetricRegistry, MetricSet} import scala.collection.mutable import scala.collection.mutable.HashMap import scala.concurrent.{ExecutionContext, Future} @@ -32,9 +33,6 @@ import scala.reflect.ClassTag import scala.util.Random import scala.util.control.NonFatal -import com.codahale.metrics.{MetricRegistry, MetricSet} -import com.google.common.io.CountingOutputStream - import org.apache.spark._ import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} import org.apache.spark.internal.{config, Logging} @@ -45,12 +43,12 @@ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.StreamCallbackWithID import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle._ -import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo +import org.apache.spark.network.shuffle.protocol.{ExecutorShuffleInfo, ExternalServiceHeartbeat, RegisterDriver} import org.apache.spark.network.util.TransportConf import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.serializer.{SerializerInstance, SerializerManager} -import org.apache.spark.shuffle.ShuffleManager +import org.apache.spark.shuffle.{ExternalFallbackShuffleClient, ShuffleManager} import org.apache.spark.storage.memory._ import org.apache.spark.unsafe.Platform import org.apache.spark.util._ @@ -183,15 +181,11 @@ private[spark] class BlockManager( // service, or just our own Executor's BlockManager. private[spark] var shuffleServerId: BlockManagerId = _ + private var backupShuffleServiceAddresses: List[(String, Int)] = _ + // Client to read other executors' shuffle files. This is either an external service, or just the // standard BlockTransferService to directly connect to other Executors. - private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { - val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) - new ExternalShuffleClient(transConf, securityManager, - securityManager.isAuthenticationEnabled(), conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT)) - } else { - blockTransferService - } + private[spark] var shuffleClient: ShuffleClient = _ // Max number of failures before this block manager refreshes the block locations from the driver private val maxFailuresBeforeLocationRefresh = @@ -220,6 +214,7 @@ private[spark] class BlockManager( new BlockManager.RemoteBlockDownloadFileManager(this) private val maxRemoteBlockToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) + /** * Initializes the BlockManager with the given appId. 
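A readability note on the SortShuffleManager.getWriter change above: the backup wrapping keys off positional tuple accessors (addressAndClient._1._1 and friends), which is easy to misread. An equivalent formulation that pattern-matches on the (address, uri, clientFactory) triple, using the same names that are in scope inside getWriter, would be:

    // Equivalent to the ._1/._2 accessors used above; purely a readability suggestion.
    Random.shuffle(backupShuffleTransportClients).headOption.map {
      case ((host, port), uri, clientFactory) =>
        val transportClient = clientFactory.createClient(uri.getHost, uri.getPort)
        new BackingUpShuffleWriter(
          shuffleBlockResolver,
          baseWriter,
          transportClient,
          backupShuffleTransportConf,
          env.mapOutputTracker,
          ThreadUtils.newDaemonCachedThreadPool("backup-shuffle-files"),
          host,
          port,
          conf.getAppId,
          env.blockManager.blockManagerId.executorId,
          handle.shuffleId,
          mapId)
    }.getOrElse(baseWriter)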
This is not performed in the constructor as * the appId may not be known at BlockManager instantiation time (in particular for the driver, @@ -241,7 +236,6 @@ private[spark] class BlockManager( logInfo(s"Using $priorityClass for block replication policy") ret } - val id = BlockManagerId(executorId, blockTransferService.hostName, blockTransferService.port, None) @@ -253,6 +247,48 @@ private[spark] class BlockManager( blockManagerId = if (idFromMaster != null) idFromMaster else id + backupShuffleServiceAddresses = if (blockManagerId.isDriver) { + List.empty[(String, Int)] + } else { + Random.shuffle(mapOutputTracker + .trackerEndpoint + .askSync[List[(String, Int)]](GetBackupShuffleServiceAddresses)) + .take(3) + } + + if (backupShuffleServiceAddresses.nonEmpty) { + require(!externalShuffleServiceEnabled, "Cannot use external shuffle service with backup" + + " shuffle services.") + } + shuffleClient = if (externalShuffleServiceEnabled) { + require(backupShuffleServiceAddresses.isEmpty, + "Cannot use the external shuffle service while using backup shuffle services.") + val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) + new ExternalShuffleClient(transConf, + securityManager, + securityManager.isAuthenticationEnabled(), + conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT)) + } else if (backupShuffleServiceAddresses.nonEmpty) { + logInfo("Using BackupShuffleService") + val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) + val externalShuffleClient = new ExternalShuffleClient( + transConf, + securityManager, + securityManager.isAuthenticationEnabled(), + conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT)) + new ExternalFallbackShuffleClient(externalShuffleClient, blockTransferService) + } else blockTransferService + shuffleClient.init(appId) + + blockReplicationPolicy = { + val priorityClass = conf.get( + "spark.storage.replication.policy", classOf[RandomBlockReplicationPolicy].getName) + val clazz = Utils.classForName(priorityClass) + val ret = clazz.newInstance.asInstanceOf[BlockReplicationPolicy] + logInfo(s"Using $priorityClass for block replication policy") + ret + } + shuffleServerId = if (externalShuffleServiceEnabled) { logInfo(s"external shuffle service port = $externalShuffleServicePort") BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort) @@ -265,6 +301,13 @@ private[spark] class BlockManager( registerWithExternalShuffleServer() } + backupShuffleServiceAddresses.foreach(address => { + registerWithBackupShuffleServer( + address._1, + address._2, + appId) + }) + logInfo(s"Initialized BlockManager: $blockManagerId") } @@ -306,6 +349,36 @@ private[spark] class BlockManager( } } + private def registerWithBackupShuffleServer( + shuffleServerHost: String, + shuffleServerPort: Int, + appId: String) + : Unit = { + val MAX_ATTEMPTS = conf.get(config.SHUFFLE_REGISTRATION_MAX_ATTEMPTS) + val SLEEP_TIME_SECS = 5 + for (i <- 1 to MAX_ATTEMPTS) { + try { + // Synchronous and will throw an exception if we cannot connect. 
+ shuffleClient + .asInstanceOf[ExternalFallbackShuffleClient] + .registerWithShuffleServerForBackups( + shuffleServerHost, + shuffleServerPort, + shuffleServerId.executorId, + shuffleManager.getClass.getName) + return + } catch { + case e: Exception if i < MAX_ATTEMPTS => + logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}" + + s" more times after waiting $SLEEP_TIME_SECS seconds...", e) + Thread.sleep(SLEEP_TIME_SECS * 1000L) + case NonFatal(e) => + throw new SparkException("Unable to register with external shuffle server due to : " + + e.getMessage, e) + } + } + } + /** * Report all blocks to the BlockManager again. This may be necessary if we are dropped * by the BlockManager and come back or if we become capable of recovering blocks on disk after @@ -755,7 +828,7 @@ private[spark] class BlockManager( logDebug(s"Getting remote block $blockId from $loc") val data = try { blockTransferService.fetchBlockSync( - loc.host, loc.port, loc.executorId, blockId.toString, tempFileManager) + loc.host, loc.port, loc.executorId, blockId.toString, loc.isBackup, tempFileManager) } catch { case NonFatal(e) => runningFailureCount += 1 @@ -1603,6 +1676,8 @@ private[spark] class BlockManager( } } + def getBackupShuffleServiceAddresses(): List[(String, Int)] = backupShuffleServiceAddresses + def releaseLockAndDispose( blockId: BlockId, data: BlockData, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index d4a59c33b974c..3b8a03bb23713 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -39,7 +39,8 @@ class BlockManagerId private ( private var executorId_ : String, private var host_ : String, private var port_ : Int, - private var topologyInfo_ : Option[String]) + private var topologyInfo_ : Option[String], + private var isBackup_ : Boolean = false) extends Externalizable { private def this() = this(null, null, 0, None) // For deserialization only @@ -62,6 +63,8 @@ class BlockManagerId private ( def port: Int = port_ + def isBackup: Boolean = isBackup_ + def topologyInfo: Option[String] = topologyInfo_ def isDriver: Boolean = { @@ -69,10 +72,12 @@ class BlockManagerId private ( executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER } + override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { out.writeUTF(executorId_) out.writeUTF(host_) out.writeInt(port_) + out.writeBoolean(isBackup) out.writeBoolean(topologyInfo_.isDefined) // we only write topologyInfo if we have it topologyInfo.foreach(out.writeUTF(_)) @@ -82,6 +87,7 @@ class BlockManagerId private ( executorId_ = in.readUTF() host_ = in.readUTF() port_ = in.readInt() + isBackup_ = in.readBoolean() val isTopologyInfoAvailable = in.readBoolean() topologyInfo_ = if (isTopologyInfoAvailable) Option(in.readUTF()) else None } @@ -124,8 +130,10 @@ private[spark] object BlockManagerId { execId: String, host: String, port: Int, - topologyInfo: Option[String] = None): BlockManagerId = - getCachedBlockManagerId(new BlockManagerId(execId, host, port, topologyInfo)) + topologyInfo: Option[String] = None, + isBackup: Boolean = false): BlockManagerId = + getCachedBlockManagerId(new BlockManagerId( + execId, host, port, topologyInfo, isBackup)) def apply(in: ObjectInput): BlockManagerId = { val obj = new BlockManagerId() diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala 
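Note on the BlockManagerId change above: writeExternal/readExternal now emit an extra boolean between the port and the topology flag, so BlockManagerIds serialized by this build are not compatible with ones written by older versions (worth flagging, since BlockManagerId travels inside MapStatus and several RPC messages). The round trip itself stays consistent because both sides agree on the new field order, as the sketch below illustrates; it drives the Externalizable methods directly, the way Spark's own MapStatus serialization does, and uses illustrative values:

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

    import org.apache.spark.storage.BlockManagerId

    val original = BlockManagerId("exec-1", "host-a", 7337, None, isBackup = true)

    val bytesOut = new ByteArrayOutputStream()
    val objOut = new ObjectOutputStream(bytesOut)
    original.writeExternal(objOut)   // executorId, host, port, isBackup, topology flag
    objOut.close()

    val objIn = new ObjectInputStream(new ByteArrayInputStream(bytesOut.toByteArray))
    val roundTripped = BlockManagerId(objIn)   // companion apply reads the same field order
    assert(roundTripped.isBackup && roundTripped.host == "host-a")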
b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index aecc2284a9588..f7b7b68e53a7e 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -254,11 +254,22 @@ final class ShuffleBlockFetcherIterator( // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch // the data and write it to file directly. if (req.size > maxReqSizeShuffleToMem) { - shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, - blockFetchingListener, this) + shuffleClient.fetchBlocks( + address.host, + address.port, + address.executorId, + blockIds.toArray, + address.isBackup, + blockFetchingListener, this) } else { - shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, - blockFetchingListener, null) + shuffleClient.fetchBlocks( + address.host, + address.port, + address.executorId, + blockIds.toArray, + address.isBackup, + blockFetchingListener, + null) } } diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 629a323042ff2..a96ea420e7f0a 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -189,8 +189,13 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex assert(locations.size === storageLevel.replication, s"; got ${locations.size} replicas instead of ${storageLevel.replication}") locations.foreach { cmId => - val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId, - blockId.toString, null) + val bytes = blockTransfer.fetchBlockSync( + cmId.host, + cmId.port, + cmId.executorId, + blockId.toString, + cmId.isBackup, + null) val deserialized = serializerManager.dataDeserializeStream(blockId, new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream())(data.elementClassTag).toList assert(deserialized === (1 to 100).toList) diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 21f481d477242..e5e97a496b944 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark import scala.collection.mutable.ArrayBuffer - import org.mockito.Matchers.any import org.mockito.Mockito._ @@ -26,7 +25,7 @@ import org.apache.spark.LocalSparkContext._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} -import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.shuffle.{DefaultShuffleServiceAddressProvider, FetchFailedException} import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} class MapOutputTrackerSuite extends SparkFunSuite { @@ -35,7 +34,11 @@ class MapOutputTrackerSuite extends SparkFunSuite { private def newTrackerMaster(sparkConf: SparkConf = conf) = { val broadcastManager = new BroadcastManager(true, sparkConf, new SecurityManager(sparkConf)) - new MapOutputTrackerMaster(sparkConf, broadcastManager, true) + new MapOutputTrackerMaster( + sparkConf, + broadcastManager, + true, + DefaultShuffleServiceAddressProvider) } def createRpcEnv(name: String, host: 
String = "localhost", port: Int = 0, @@ -186,7 +189,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) when(rpcCallContext.senderAddress).thenReturn(senderAddress) - masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(10)) + masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(10, false)) // Default size for broadcast in this testsuite is set to -1 so should not cause broadcast // to be used. verify(rpcCallContext, timeout(30000)).reply(any()) @@ -265,7 +268,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) when(rpcCallContext.senderAddress).thenReturn(senderAddress) - masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20)) + masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20, false)) // should succeed since majority of data is broadcast and actual serialized // message size is small verify(rpcCallContext, timeout(30000)).reply(any()) @@ -313,7 +316,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(size10000, size0, size1000, size0))) assert(tracker.containsShuffle(10)) - assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq === + assert(tracker.getMapSizesByExecutorId(10, 0, 4, false).toSeq === Seq( (BlockManagerId("a", "hostA", 1000), Seq((ShuffleBlockId(10, 0, 1), size1000), (ShuffleBlockId(10, 0, 3), size10000))), diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 21138bd4a16ba..8c7065d5728bd 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -156,7 +156,7 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi val promise = Promise[ManagedBuffer]() - self.fetchBlocks(from.hostName, from.port, execId, Array(blockId.toString), + self.fetchBlocks(from.hostName, from.port, execId, Array(blockId.toString), false, new BlockFetchingListener { override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { promise.failure(exception) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 5f4ffa151d19b..adad66e788ce7 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -24,7 +24,6 @@ import scala.annotation.meta.param import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import scala.language.reflectiveCalls import scala.util.control.NonFatal - import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.SpanSugar._ @@ -34,7 +33,7 @@ import org.apache.spark.executor.ExecutorMetrics import org.apache.spark.internal.config import org.apache.spark.rdd.{DeterministicLevel, RDD} import org.apache.spark.scheduler.SchedulingMode.SchedulingMode -import org.apache.spark.shuffle.{FetchFailedException, MetadataFetchFailedException} +import org.apache.spark.shuffle.{DefaultShuffleServiceAddressProvider, FetchFailedException, 
MetadataFetchFailedException} import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, CallSite, LongAccumulator, Utils} @@ -250,7 +249,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi results.clear() securityMgr = new SecurityManager(conf) broadcastManager = new BroadcastManager(true, conf, securityMgr) - mapOutputTracker = new MapOutputTrackerMaster(conf, broadcastManager, true) { + mapOutputTracker = new MapOutputTrackerMaster( + conf, broadcastManager, true, DefaultShuffleServiceAddressProvider) { override def sendTracker(message: Any): Unit = { // no-op, just so we can stop this to avoid leaking threads } diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 2d8a83c6fabed..f0096e60841e6 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -101,7 +101,8 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // Make a mocked MapOutputTracker for the shuffle reader to use to determine what // shuffle data to read. val mapOutputTracker = mock(classOf[MapOutputTracker]) - when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)).thenReturn { + when(mapOutputTracker.getMapSizesByExecutorId( + shuffleId, reduceId, reduceId + 1, false)).thenReturn { // Test a scenario where all data is local, to avoid creating a bunch of additional mocks // for the code to read data over the network. val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId => diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 3962bdc27d22c..a176ca100807f 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -23,7 +23,6 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps - import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ @@ -38,6 +37,7 @@ import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{KryoSerializer, SerializerManager} +import org.apache.spark.shuffle.DefaultShuffleServiceAddressProvider import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage.StorageLevel._ import org.apache.spark.util.Utils @@ -53,7 +53,8 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite protected var master: BlockManagerMaster = null protected lazy val securityMgr = new SecurityManager(conf) protected lazy val bcastManager = new BroadcastManager(true, conf, securityMgr) - protected lazy val mapOutputTracker = new MapOutputTrackerMaster(conf, bcastManager, true) + protected lazy val mapOutputTracker = new MapOutputTrackerMaster( + conf, bcastManager, true, DefaultShuffleServiceAddressProvider) protected lazy val shuffleManager = new SortShuffleManager(conf) // List of block manager created during an unit test, so 
that all of the them can be stopped diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 32d6e8b94e1a2..692563cc8192f 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -25,7 +25,6 @@ import scala.concurrent.Future import scala.concurrent.duration._ import scala.language.{implicitConversions, postfixOps} import scala.reflect.ClassTag - import org.apache.commons.lang3.RandomUtils import org.mockito.{Matchers => mc} import org.mockito.Mockito.{mock, times, verify, when} @@ -50,6 +49,7 @@ import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager} +import org.apache.spark.shuffle.DefaultShuffleServiceAddressProvider import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util._ @@ -72,7 +72,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE var master: BlockManagerMaster = null val securityMgr = new SecurityManager(new SparkConf(false)) val bcastManager = new BroadcastManager(true, new SparkConf(false), securityMgr) - val mapOutputTracker = new MapOutputTrackerMaster(new SparkConf(false), bcastManager, true) + val mapOutputTracker = new MapOutputTrackerMaster( + new SparkConf(false), bcastManager, true, DefaultShuffleServiceAddressProvider) val shuffleManager = new SortShuffleManager(new SparkConf(false)) // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test @@ -1446,6 +1447,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE port: Int, execId: String, blockIds: Array[String], + isBackup: Boolean, listener: BlockFetchingListener, tempFileManager: DownloadFileManager): Unit = { listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1))) @@ -1474,13 +1476,14 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE port: Int, execId: String, blockId: String, + isBackup: Boolean, tempFileManager: DownloadFileManager): ManagedBuffer = { numCalls += 1 this.tempFileManager = tempFileManager if (numCalls <= maxFailures) { throw new RuntimeException("Failing block fetch in the mock block transfer service") } - super.fetchBlockSync(host, port, execId, blockId, tempFileManager) + super.fetchBlockSync(host, port, execId, blockId, isBackup, tempFileManager) } } } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index b268195e09a5b..2e694d725441e 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -46,7 +46,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT /** Creates a mock [[BlockTransferService]] that returns data from the given map. 
*/ private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = { val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]] @@ -140,7 +140,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // 3 local blocks, and 2 remote blocks // (but from the same block manager so one call to fetchBlocks) verify(blockManager, times(3)).getBlockData(any()) - verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any()) + verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any(), any()) } test("release current unexhausted buffer in case the task completes early") { @@ -159,7 +159,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -227,7 +227,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -297,7 +297,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -337,7 +337,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val (id1, _) = iterator.next() assert(id1 === ShuffleBlockId(0, 0, 0)) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -415,7 +415,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -479,7 +479,7 @@ class 
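The Kubernetes configuration keys added above gate the backup shuffle service integration on the driver. A hypothetical configuration using them is sketched below, assuming a shuffle-service deployment labelled app=spark-shuffle-service; the label key/value and namespace are invented for illustration, and the spark.kubernetes.shuffle.service.backups.label. prefix is only defined as a constant here, with its exact usage established elsewhere in this patch series:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .set("spark.kubernetes.shuffle.service.backups.enabled", "true")
      .set("spark.kubernetes.shuffle.service.backups.pods.namespace", "spark-shuffle")
      .set("spark.kubernetes.shuffle.service.backups.port", "7337")   // default shown above
      .set("spark.kubernetes.shuffle.service.backups.label.app", "spark-shuffle-service")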
ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()) val transfer = mock(classOf[BlockTransferService]) var tempFileManager: DownloadFileManager = null - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] diff --git a/resource-managers/kubernetes/core/src/main/resources/META-INF/services/org.apache.spark.shuffle.ShuffleServiceAddressProviderFactory b/resource-managers/kubernetes/core/src/main/resources/META-INF/services/org.apache.spark.shuffle.ShuffleServiceAddressProviderFactory new file mode 100644 index 0000000000000..c39f2ec633ea3 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/resources/META-INF/services/org.apache.spark.shuffle.ShuffleServiceAddressProviderFactory @@ -0,0 +1 @@ +org.apache.spark.shuffle.k8s.KubernetesShuffleServiceAddressProviderFactory \ No newline at end of file diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index a32bd93bb65bc..48aff1aa5bb8a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -276,6 +276,25 @@ private[spark] object Config extends Logging { .stringConf .createOptional + val KUBERNETES_BACKUP_SHUFFLE_SERVICE_ENABLED = + ConfigBuilder("spark.kubernetes.shuffle.service.backups.enabled") + .doc("Use shuffle service to back up shuffle data in Kubernetes applications.") + .booleanConf + .createWithDefault(false) + + val KUBERNETES_BACKUP_SHUFFLE_SERVICE_PODS_NAMESPACE = + ConfigBuilder("spark.kubernetes.shuffle.service.backups.pods.namespace") + .doc("Namespace of the pods that are running the shuffle service instances for backing up" + + " shuffle data.") + .stringConf + .createOptional + + val KUBERNETES_BACKUP_SHUFFLE_SERVICE_PORT = + ConfigBuilder("spark.kubernetes.shuffle.service.backups.port") + .doc("Port of the shuffle services that will back up the application's shuffle data.") + .intConf + .createWithDefault(7337) + val KUBERNETES_AUTH_SUBMISSION_CONF_PREFIX = "spark.kubernetes.authenticate.submission" @@ -304,4 +323,7 @@ private[spark] object Config extends Logging { val KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY = "options.sizeLimit" val KUBERNETES_DRIVER_ENV_PREFIX = "spark.kubernetes.driverEnv." + + val KUBERNETES_BACKUP_SHUFFLE_SERVICE_LABELS = + "spark.kubernetes.shuffle.service.backups.label." 
} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala index 77bd66b608e7c..86212f9cce2af 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala @@ -21,11 +21,13 @@ import java.io.File import com.google.common.base.Charsets import com.google.common.io.Files import io.fabric8.kubernetes.client.{ConfigBuilder, DefaultKubernetesClient, KubernetesClient} +import io.fabric8.kubernetes.client.Config._ import io.fabric8.kubernetes.client.utils.HttpClientUtils import okhttp3.Dispatcher import org.apache.spark.SparkConf import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.util.ThreadUtils /** @@ -35,6 +37,36 @@ import org.apache.spark.util.ThreadUtils */ private[spark] object SparkKubernetesClientFactory { + def getDriverKubernetesClient(conf: SparkConf, masterURL: String): KubernetesClient = { + val wasSparkSubmittedInClusterMode = conf.get(KUBERNETES_DRIVER_SUBMIT_CHECK) + val (authConfPrefix, + apiServerUri, + defaultServiceAccountToken, + defaultServiceAccountCaCrt) = if (wasSparkSubmittedInClusterMode) { + require(conf.get(KUBERNETES_DRIVER_POD_NAME).isDefined, + "If the application is deployed using spark-submit in cluster mode, the driver pod name " + + "must be provided.") + (KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX, + KUBERNETES_MASTER_INTERNAL_URL, + Some(new File(KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), + Some(new File(KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) + } else { + (KUBERNETES_AUTH_CLIENT_MODE_PREFIX, + KubernetesUtils.parseMasterUrl(masterURL), + None, + None) + } + + val kubernetesClient = createKubernetesClient( + apiServerUri, + Some(conf.get(KUBERNETES_NAMESPACE)), + authConfPrefix, + conf, + defaultServiceAccountToken, + defaultServiceAccountCaCrt) + kubernetesClient + } + def createKubernetesClient( master: String, namespace: Option[String], diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala index 435a5f1461c92..afd97240255bc 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala @@ -26,7 +26,7 @@ import org.apache.spark.internal.Logging /** * An immutable view of the current executor pods that are running in the cluster. 
*/ -private[spark] case class ExecutorPodsSnapshot(executorPods: Map[Long, ExecutorPodState]) { +private[spark] case class ExecutorPodsSnapshot(executorPods: Map[Long, SparkPodState]) { import ExecutorPodsSnapshot._ @@ -42,15 +42,15 @@ object ExecutorPodsSnapshot extends Logging { ExecutorPodsSnapshot(toStatesByExecutorId(executorPods)) } - def apply(): ExecutorPodsSnapshot = ExecutorPodsSnapshot(Map.empty[Long, ExecutorPodState]) + def apply(): ExecutorPodsSnapshot = ExecutorPodsSnapshot(Map.empty[Long, SparkPodState]) - private def toStatesByExecutorId(executorPods: Seq[Pod]): Map[Long, ExecutorPodState] = { + private def toStatesByExecutorId(executorPods: Seq[Pod]): Map[Long, SparkPodState] = { executorPods.map { pod => - (pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL).toLong, toState(pod)) + (pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL).toLong, SparkPodState.toState(pod)) }.toMap } - private def toState(pod: Pod): ExecutorPodState = { + private def toState(pod: Pod): SparkPodState = { if (isDeleted(pod)) { PodDeleted(pod) } else { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index ce10f766334ff..c9bbc51a6bde7 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -16,16 +16,13 @@ */ package org.apache.spark.scheduler.cluster.k8s -import java.io.File +import java.lang import java.util.concurrent.TimeUnit import com.google.common.cache.CacheBuilder -import io.fabric8.kubernetes.client.Config import org.apache.spark.SparkContext -import org.apache.spark.deploy.k8s.{KubernetesUtils, SparkKubernetesClientFactory} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} import org.apache.spark.util.{SystemClock, ThreadUtils} @@ -42,32 +39,8 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit sc: SparkContext, masterURL: String, scheduler: TaskScheduler): SchedulerBackend = { - val wasSparkSubmittedInClusterMode = sc.conf.get(KUBERNETES_DRIVER_SUBMIT_CHECK) - val (authConfPrefix, - apiServerUri, - defaultServiceAccountToken, - defaultServiceAccountCaCrt) = if (wasSparkSubmittedInClusterMode) { - require(sc.conf.get(KUBERNETES_DRIVER_POD_NAME).isDefined, - "If the application is deployed using spark-submit in cluster mode, the driver pod name " + - "must be provided.") - (KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX, - KUBERNETES_MASTER_INTERNAL_URL, - Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), - Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) - } else { - (KUBERNETES_AUTH_CLIENT_MODE_PREFIX, - KubernetesUtils.parseMasterUrl(masterURL), - None, - None) - } - - val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( - apiServerUri, - Some(sc.conf.get(KUBERNETES_NAMESPACE)), - authConfPrefix, - sc.conf, - defaultServiceAccountToken, - defaultServiceAccountCaCrt) + val kubernetesClient = SparkKubernetesClientFactory.getDriverKubernetesClient( 
+ sc.conf, masterURL) if (sc.conf.get(KUBERNETES_EXECUTOR_PODTEMPLATE_FILE).isDefined) { KubernetesUtils.loadPodFromTemplate( @@ -85,7 +58,7 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit val snapshotsStore = new ExecutorPodsSnapshotsStoreImpl(subscribersExecutor) val removedExecutorsCache = CacheBuilder.newBuilder() .expireAfterWrite(3, TimeUnit.MINUTES) - .build[java.lang.Long, java.lang.Long]() + .build[lang.Long, lang.Long]() val executorPodsLifecycleEventHandler = new ExecutorPodsLifecycleManager( sc.conf, kubernetesClient, diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/SparkPodStates.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/SparkPodStates.scala new file mode 100644 index 0000000000000..0910b787b7f4b --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/SparkPodStates.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod + +import org.apache.spark.internal.Logging + +sealed trait SparkPodState { + def pod: Pod +} + +case class PodRunning(pod: Pod) extends SparkPodState + +case class PodPending(pod: Pod) extends SparkPodState + +sealed trait FinalPodState extends SparkPodState + +case class PodSucceeded(pod: Pod) extends FinalPodState + +case class PodFailed(pod: Pod) extends FinalPodState + +case class PodDeleted(pod: Pod) extends FinalPodState + +case class PodUnknown(pod: Pod) extends SparkPodState + +object SparkPodState extends Logging { + def toState(pod: Pod): SparkPodState = { + if (isDeleted(pod)) { + PodDeleted(pod) + } else { + val phase = pod.getStatus.getPhase.toLowerCase + phase match { + case "pending" => + PodPending(pod) + case "running" => + PodRunning(pod) + case "failed" => + PodFailed(pod) + case "succeeded" => + PodSucceeded(pod) + case _ => + logWarning(s"Received unknown phase $phase for executor pod with name" + + s" ${pod.getMetadata.getName} in namespace ${pod.getMetadata.getNamespace}") + PodUnknown(pod) + } + } + } + + private def isDeleted(pod: Pod): Boolean = pod.getMetadata.getDeletionTimestamp != null +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala new file mode 100644 index 0000000000000..00adbc6426a28 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.shuffle.k8s + +import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture, TimeUnit} +import java.util.concurrent.locks.ReentrantReadWriteLock + +import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watch, Watcher} +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.cluster.k8s._ +import org.apache.spark.shuffle.ShuffleServiceAddressProvider +import org.apache.spark.util.Utils + +class KubernetesShuffleServiceAddressProvider( + kubernetesClient: KubernetesClient, + pollForPodsExecutor: ScheduledExecutorService, + podLabels: Map[String, String], + namespace: String, + portNumber: Int) + extends ShuffleServiceAddressProvider with Logging { + + // General implementation remark: this bears a strong resemblance to ExecutorPodsSnapshotsStore, + // but we don't need all "in-between" lists of all executor pods, just the latest known list + // when we query in getShuffleServiceAddresses. + + private val podsUpdateLock = new ReentrantReadWriteLock() + + private val shuffleServicePods = mutable.HashMap.empty[String, Pod] + + private var shuffleServicePodsWatch: Watch = _ + private var pollForPodsTask: ScheduledFuture[_] = _ + + override def start(): Unit = { + pollForPods() + pollForPodsTask = pollForPodsExecutor.scheduleWithFixedDelay( + () => pollForPods(), 0, 10, TimeUnit.SECONDS) + shuffleServicePodsWatch = kubernetesClient + .pods() + .inNamespace(namespace) + .withLabels(podLabels.asJava).watch(new PutPodsInCacheWatcher()) + } + + override def stop(): Unit = { + Utils.tryLogNonFatalError { + if (pollForPodsTask != null) { + pollForPodsTask.cancel(false) + } + } + + Utils.tryLogNonFatalError { + if (shuffleServicePodsWatch != null) { + shuffleServicePodsWatch.close() + } + } + + Utils.tryLogNonFatalError { + kubernetesClient.close() + } + } + + override def getShuffleServiceAddresses(): List[(String, Int)] = { + val readLock = podsUpdateLock.readLock() + readLock.lock() + try { + val addresses = shuffleServicePods.values.map(pod => { + (pod.getStatus.getPodIP, portNumber) + }).toList + logInfo(s"Found backup shuffle service addresses at $addresses.") + addresses + } finally { + readLock.unlock() + } + } + + private def pollForPods(): Unit = { + val writeLock = podsUpdateLock.writeLock() + writeLock.lock() + try { + val allPods = kubernetesClient + .pods() + .inNamespace(namespace) + .withLabels(podLabels.asJava) + .list() + shuffleServicePods.clear() + allPods.getItems.asScala.foreach(updatePod) + } finally { + writeLock.unlock() + } + } + + private def updatePod(pod: Pod): Unit = { + require(podsUpdateLock.isWriteLockedByCurrentThread, "Should only update pods under lock.") + val state = SparkPodState.toState(pod) + state match { + case PodPending(_) | PodFailed(_) | PodSucceeded(_) | PodDeleted(_) => + shuffleServicePods.remove(pod.getMetadata.getName) + case PodRunning(_) => + shuffleServicePods.put(pod.getMetadata.getName, pod) + case _ => + logWarning(s"Unknown state $state for pod named ${pod.getMetadata.getName}") + } + } + + private def deletePod(pod: Pod): Unit = { + require(podsUpdateLock.isWriteLockedByCurrentThread, "Should only delete under lock.") + shuffleServicePods.remove(pod.getMetadata.getName) + } + + private class PutPodsInCacheWatcher extends Watcher[Pod] { + override def eventReceived(action: Watcher.Action, pod: Pod): Unit = { + val writeLock = 
podsUpdateLock.writeLock() + writeLock.lock() + try { + updatePod(pod) + } finally { + writeLock.unlock() + } + } + + override def onClose(e: KubernetesClientException): Unit = {} + } + + private implicit def toRunnable(func: () => Unit): Runnable = { + new Runnable { + override def run(): Unit = func() + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProviderFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProviderFactory.scala new file mode 100644 index 0000000000000..32f56ee6b4396 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProviderFactory.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle.k8s + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory +import org.apache.spark.shuffle.{DefaultShuffleServiceAddressProvider, ShuffleServiceAddressProvider, ShuffleServiceAddressProviderFactory} +import org.apache.spark.util.ThreadUtils + +class KubernetesShuffleServiceAddressProviderFactory extends ShuffleServiceAddressProviderFactory { + override def canCreate(masterUrl: String): Boolean = masterUrl.startsWith("k8s://") + + override def create(conf: SparkConf): ShuffleServiceAddressProvider = { + if (conf.get(KUBERNETES_BACKUP_SHUFFLE_SERVICE_ENABLED)) { + val kubernetesClient = SparkKubernetesClientFactory.getDriverKubernetesClient( + conf, conf.get("spark.master")) + val pollForPodsExecutor = ThreadUtils.newDaemonThreadPoolScheduledExecutor( + "poll-shuffle-service-pods", 1) + val shuffleServiceLabels = conf.getAllWithPrefix(KUBERNETES_BACKUP_SHUFFLE_SERVICE_LABELS) + val shuffleServicePodsNamespace = conf.get(KUBERNETES_BACKUP_SHUFFLE_SERVICE_PODS_NAMESPACE) + require(shuffleServicePodsNamespace.isDefined, "Namespace for the pods running the backup" + + s" shuffle service must be defined by" + + s" ${KUBERNETES_BACKUP_SHUFFLE_SERVICE_PODS_NAMESPACE.key}") + require(shuffleServiceLabels.nonEmpty, "Requires labels for the backup shuffle service pods.") + + val port: Int = conf.get(KUBERNETES_BACKUP_SHUFFLE_SERVICE_PORT) + new KubernetesShuffleServiceAddressProvider( + kubernetesClient, + pollForPodsExecutor, + shuffleServiceLabels.toMap, + shuffleServicePodsNamespace.get, + port) + } else DefaultShuffleServiceAddressProvider + } +} diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala index 
859aa836a3157..918606e40112e 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala @@ -29,7 +29,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler import org.apache.spark.network.shuffle.protocol.BlockTransferMessage -import org.apache.spark.network.shuffle.protocol.mesos.{RegisterDriver, ShuffleServiceHeartbeat} +import org.apache.spark.network.shuffle.protocol.{RegisterDriver, ShuffleServiceHeartbeat} import org.apache.spark.network.util.TransportConf import org.apache.spark.util.ThreadUtils diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala index da71f8f9e407c..9f7336af3f8ea 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala @@ -20,6 +20,8 @@ package org.apache.spark.scheduler.cluster.mesos import org.apache.spark.SparkContext import org.apache.spark.internal.config._ import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} +import org.apache.spark.scheduler.cluster.DefaultShuffleServiceAddressProvider +import org.apache.spark.shuffle.{DefaultShuffleServiceAddressProvider, ShuffleServiceAddressProvider} /** * Cluster Manager for creation of Mesos scheduler and backend @@ -60,5 +62,8 @@ private[spark] class MesosClusterManager extends ExternalClusterManager { override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend) } + + override def createShuffleServiceAddressProvider(): ShuffleServiceAddressProvider = + DefaultShuffleServiceAddressProvider } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala index 64cd1bd088001..f924088913fa4 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler.cluster import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} +import org.apache.spark.shuffle.{DefaultShuffleServiceAddressProvider, ShuffleServiceAddressProvider} /** * Cluster Manager for creation of Yarn scheduler and backend @@ -53,4 +54,7 @@ private[spark] class YarnClusterManager extends ExternalClusterManager { override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend) } + + override def createShuffleServiceAddressProvider(): ShuffleServiceAddressProvider = + DefaultShuffleServiceAddressProvider } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala 
b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index fe65353b9d502..232e1c9828e5b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -24,7 +24,6 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.postfixOps import scala.reflect.ClassTag - import org.apache.hadoop.conf.Configuration import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ @@ -39,6 +38,7 @@ import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.serializer.{KryoSerializer, SerializerManager} +import org.apache.spark.shuffle.DefaultShuffleServiceAddressProvider import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage._ import org.apache.spark.streaming.receiver._ @@ -70,7 +70,8 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean) val streamId = 1 val securityMgr = new SecurityManager(conf, encryptionKey) val broadcastManager = new BroadcastManager(true, conf, securityMgr) - val mapOutputTracker = new MapOutputTrackerMaster(conf, broadcastManager, true) + val mapOutputTracker = new MapOutputTrackerMaster( + conf, broadcastManager, true, None, DefaultShuffleServiceAddressProvider) val shuffleManager = new SortShuffleManager(conf) val serializer = new KryoSerializer(conf) var serializerManager = new SerializerManager(serializer, conf, encryptionKey) From 1ff391c6802aa9d447d330875fdcbe3086cae516 Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Thu, 29 Nov 2018 23:06:54 -0800 Subject: [PATCH 3/4] WIP tinkering --- .../shuffle/ExternalShuffleBlockResolver.java | 3 + .../shuffle/protocol/ExecutorShuffleInfo.java | 33 +++- .../org/apache/spark/MapOutputTracker.scala | 4 +- .../scala/org/apache/spark/SparkEnv.scala | 5 +- .../spark/internal/config/package.scala | 6 + .../shuffle/IndexShuffleBlockResolver.scala | 26 ++- .../BackingUpShuffleWriter.scala | 8 +- .../ExternalFallbackShuffleClient.scala | 2 +- .../shuffle/external/ShuffleDataIO.scala | 4 +- .../external/ShufflePartitionReader.scala | 8 +- .../external/ShufflePartitionWriter.scala | 8 +- .../shuffle/external/ShuffleReadSupport.scala | 4 +- .../ShuffleServiceAddressProvider.scala | 2 +- ...ShuffleServiceAddressProviderFactory.scala | 3 +- .../external/ShuffleWriteSupport.scala | 4 +- .../default/DefaultShuffleDataIO.scala | 53 ++++++ .../shuffle/sort/SortShuffleManager.scala | 5 +- .../apache/spark/storage/BlockManager.scala | 71 ++++---- .../apache/spark/storage/BlockMapper.scala | 26 +++ .../{DiskStore.scala => BlockStore.scala} | 157 ++++++++++-------- .../spark/storage/DiskBlockManager.scala | 18 +- .../spark/storage/RemoteBlockManager.scala | 60 +++++++ .../apache/spark/MapOutputTrackerSuite.scala | 4 +- .../spark/scheduler/DAGSchedulerSuite.scala | 4 +- .../BlockManagerReplicationSuite.scala | 3 +- .../spark/storage/BlockManagerSuite.scala | 19 +-- .../apache/spark/storage/DiskStoreSuite.scala | 10 +- ...rnal.ShuffleServiceAddressProviderFactory} | 0 ...ernetesShuffleServiceAddressProvider.scala | 4 +- ...ShuffleServiceAddressProviderFactory.scala | 3 +- .../streaming/ReceivedBlockHandlerSuite.scala | 3 +- 31 files changed, 398 insertions(+), 162 deletions(-) rename core/src/main/scala/org/apache/spark/shuffle/{ => 
external}/BackingUpShuffleWriter.scala (94%) rename core/src/main/scala/org/apache/spark/shuffle/{ => external}/ExternalFallbackShuffleClient.scala (98%) rename core/src/main/scala/org/apache/spark/shuffle/{ => external}/ShuffleServiceAddressProvider.scala (96%) rename core/src/main/scala/org/apache/spark/shuffle/{ => external}/ShuffleServiceAddressProviderFactory.scala (95%) create mode 100644 core/src/main/scala/org/apache/spark/shuffle/external/default/DefaultShuffleDataIO.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockMapper.scala rename core/src/main/scala/org/apache/spark/storage/{DiskStore.scala => BlockStore.scala} (70%) create mode 100644 core/src/main/scala/org/apache/spark/storage/RemoteBlockManager.scala rename resource-managers/kubernetes/core/src/main/resources/META-INF/services/{org.apache.spark.shuffle.ShuffleServiceAddressProviderFactory => org.apache.spark.shuffle.external.ShuffleServiceAddressProviderFactory} (100%) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 6ee40b5dc0a6e..8d9f3ecf0fcbb 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -201,7 +201,10 @@ public void registerExecutorForBackups(String appId, String execId, String shuff "Unsupported shuffle manager of executor: %s.", fullId)); } + // TODO: Finish ExecutorShuffleInfo backupShuffleInfo = new ExecutorShuffleInfo( + new String[] { }, + new String[] { }, new String[] { executorBackupDir.getAbsolutePath() }, 1, shuffleManager); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java index 93758bdc58fb0..c3e974dd6d773 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java @@ -29,6 +29,10 @@ /** Contains all configuration necessary for locating the shuffle files of an executor. */ public class ExecutorShuffleInfo implements Encodable { + /** The set of remote hosts that the executor stores its backup shuffle files in. */ + public final String[] backupHosts; + /** The set of remote ports that the executor stores its backup shuffle files in. */ + public final String[] backupPorts; /** The base set of local directories that the executor stores its shuffle files in. */ public final String[] localDirs; /** Number of subdirectories created within each localDir. 
*/ @@ -38,9 +42,14 @@ public class ExecutorShuffleInfo implements Encodable { @JsonCreator public ExecutorShuffleInfo( + @JsonProperty("backupHosts") String [] backupHosts, + @JsonProperty("backupPorts") String[] backupPorts, @JsonProperty("localDirs") String[] localDirs, @JsonProperty("subDirsPerLocalDir") int subDirsPerLocalDir, @JsonProperty("shuffleManager") String shuffleManager) { + this.backupHosts = backupHosts; + this.backupPorts = backupPorts; + assert(backupHosts.length == backupPorts.length); this.localDirs = localDirs; this.subDirsPerLocalDir = subDirsPerLocalDir; this.shuffleManager = shuffleManager; @@ -48,12 +57,15 @@ public ExecutorShuffleInfo( @Override public int hashCode() { - return Objects.hashCode(subDirsPerLocalDir, shuffleManager) * 41 + Arrays.hashCode(localDirs); + return Objects.hashCode(subDirsPerLocalDir, shuffleManager) * 41 + Arrays.hashCode(localDirs) + + Arrays.hashCode(backupHosts) + Arrays.hashCode(backupPorts); } @Override public String toString() { return Objects.toStringHelper(this) + .add("backupHosts", Arrays.toString(backupHosts)) + .add("backupPorts", Arrays.toString(backupPorts)) .add("localDirs", Arrays.toString(localDirs)) .add("subDirsPerLocalDir", subDirsPerLocalDir) .add("shuffleManager", shuffleManager) @@ -64,7 +76,9 @@ public String toString() { public boolean equals(Object other) { if (other != null && other instanceof ExecutorShuffleInfo) { ExecutorShuffleInfo o = (ExecutorShuffleInfo) other; - return Arrays.equals(localDirs, o.localDirs) + return Arrays.equals(backupHosts, o.backupHosts) + && Arrays.equals(backupPorts, o.backupPorts) + && Arrays.equals(localDirs, o.localDirs) && Objects.equal(subDirsPerLocalDir, o.subDirsPerLocalDir) && Objects.equal(shuffleManager, o.shuffleManager); } @@ -73,22 +87,33 @@ public boolean equals(Object other) { @Override public int encodedLength() { - return Encoders.StringArrays.encodedLength(localDirs) + return Encoders.StringArrays.encodedLength(backupHosts) + + Encoders.StringArrays.encodedLength(backupPorts) + + Encoders.StringArrays.encodedLength(localDirs) + 4 // int + Encoders.Strings.encodedLength(shuffleManager); } @Override public void encode(ByteBuf buf) { + Encoders.StringArrays.encode(buf, backupHosts); + Encoders.StringArrays.encode(buf, backupPorts); Encoders.StringArrays.encode(buf, localDirs); buf.writeInt(subDirsPerLocalDir); Encoders.Strings.encode(buf, shuffleManager); } public static ExecutorShuffleInfo decode(ByteBuf buf) { + String[] backupHosts = Encoders.StringArrays.decode(buf); + String[] backupPorts = Encoders.StringArrays.decode(buf); String[] localDirs = Encoders.StringArrays.decode(buf); int subDirsPerLocalDir = buf.readInt(); String shuffleManager = Encoders.Strings.decode(buf); - return new ExecutorShuffleInfo(localDirs, subDirsPerLocalDir, shuffleManager); + return new ExecutorShuffleInfo( + backupHosts, + backupPorts, + localDirs, + subDirsPerLocalDir, + shuffleManager); } } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index a33c7cadea5fb..9f1c20ff9e92b 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -28,13 +28,13 @@ import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration import scala.reflect.ClassTag import scala.util.control.NonFatal - import org.apache.spark.broadcast.{Broadcast, BroadcastManager} import org.apache.spark.internal.Logging import 
org.apache.spark.internal.config._ import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.MapStatus -import org.apache.spark.shuffle.{MetadataFetchFailedException, ShuffleServiceAddressProvider} +import org.apache.spark.shuffle.MetadataFetchFailedException +import org.apache.spark.shuffle.external.ShuffleServiceAddressProvider import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util._ diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 50a7c80d73474..b3587326caae1 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -22,10 +22,10 @@ import java.net.Socket import java.util.{Locale, ServiceLoader} import com.google.common.collect.MapMaker + import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Properties - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager @@ -39,7 +39,8 @@ import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator} import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager} -import org.apache.spark.shuffle.{DefaultShuffleServiceAddressProvider, ShuffleManager, ShuffleServiceAddressProviderFactory} +import org.apache.spark.shuffle.external.{DefaultShuffleServiceAddressProvider, ShuffleServiceAddressProviderFactory} +import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage._ import org.apache.spark.util.{RpcUtils, Utils} diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index d3fae7e4eb4f7..25064ea6e496e 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -108,6 +108,12 @@ package object config { .timeConf(TimeUnit.MILLISECONDS) .createWithDefaultString("10s") + private[spark] val SHUFFLE_REMOTE_READ_OVERRIDE = + ConfigBuilder("spark.shuffle.externalShuffleBackup.remote") + .booleanConf + .createWithDefault(false) + + private[spark] val EXECUTOR_JAVA_OPTIONS = ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_JAVA_OPTIONS).stringConf.createOptional diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index e5aad891541f6..b611c9bdcc93b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -27,6 +27,7 @@ import org.apache.spark.io.NioBufferedFileInputStream import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID +import org.apache.spark.shuffle.external._ import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -43,14 +44,28 @@ import org.apache.spark.util.Utils // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getSortBasedShuffleBlockData(). 
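// Editor's illustration (not part of this patch): a minimal sketch of the calling pattern the
// pluggable write path below is expected to follow, based only on the ShuffleWriteSupport and
// ShufflePartitionWriter signatures introduced in this patch. The helper name, the partitionBytes
// argument, and the close-on-completion / abort-on-failure discipline are assumptions made for
// illustration.
import java.io.ByteArrayInputStream

import org.apache.spark.shuffle.external.ShuffleDataIO

private[spark] object ShuffleBackupSketch {
  // Mirror one map task's reduce partitions to whichever pluggable backend is configured.
  def backupMapOutput(
      dataIo: ShuffleDataIO,
      appId: String,
      shuffleId: Int,
      mapId: Int,
      partitionBytes: Seq[Array[Byte]]): Unit = {
    val writer = dataIo.writeSupport().newPartitionWriter(appId, shuffleId, mapId)
    try {
      partitionBytes.zipWithIndex.foreach { case (bytes, reduceId) =>
        // One appendPartition call per reduce partition of this map output.
        writer.appendPartition(reduceId.toLong, new ByteArrayInputStream(bytes))
      }
    } catch {
      case e: Exception =>
        writer.abort(e)
        throw e
    } finally {
      writer.close()
    }
  }
}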
private[spark] class IndexShuffleBlockResolver( conf: SparkConf, - _blockManager: BlockManager = null) + _blockManager: BlockManager = null, + shuffleDataIO: ShuffleDataIO = null) extends ShuffleBlockResolver with Logging { + private lazy val appId = conf.getAppId private lazy val blockManager = Option(_blockManager).getOrElse(SparkEnv.get.blockManager) private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") + private var shuffleWriteSupport: ShuffleWriteSupport = _ + + private var shuffleReadSupport: ShuffleReadSupport = _ + + private var isExternalFileSystem = false + + if (shuffleDataIO != null) { + shuffleWriteSupport = shuffleDataIO.writeSupport() + shuffleReadSupport = shuffleDataIO.readSupport() + isExternalFileSystem = true + } + def getDataFile(shuffleId: Int, mapId: Int): File = { blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) } @@ -219,6 +234,15 @@ private[spark] class IndexShuffleBlockResolver( offset, nextOffset - offset) } finally { + if (isExternalFileSystem) { + val writer = shuffleWriteSupport.newPartitionWriter(appId, blockId.shuffleId, blockId.mapId) + try { + writer.appendPartition(blockId.reduceId, in) + } catch { + case e: Exception => + writer.abort(e) + } + } in.close() } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BackingUpShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/external/BackingUpShuffleWriter.scala similarity index 94% rename from core/src/main/scala/org/apache/spark/shuffle/BackingUpShuffleWriter.scala rename to core/src/main/scala/org/apache/spark/shuffle/external/BackingUpShuffleWriter.scala index 4b8427d78b1f1..c69e5905604eb 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BackingUpShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/external/BackingUpShuffleWriter.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.shuffle +package org.apache.spark.shuffle.external import java.io.File import java.nio.ByteBuffer @@ -32,6 +32,7 @@ import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, UploadShuffleFileStream, UploadShuffleIndexFileStream} import org.apache.spark.network.util.TransportConf import org.apache.spark.scheduler.{MapStatus, RelocatedMapStatus} +import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter} import org.apache.spark.storage.BlockManagerId class BackingUpShuffleWriter[K, V]( @@ -46,11 +47,13 @@ class BackingUpShuffleWriter[K, V]( appId: String, execId: String, shuffleId: Int, - mapId: Int) + mapId: Int, + backupShuffleDataIO: ShuffleDataIO = null) extends ShuffleWriter[K, V] with Logging { private implicit val backupExecutorContext = ExecutionContext.fromExecutorService(backupExecutor) + private val writeSupport = backupShuffleDataIO.writeSupport() /** Write a sequence of records to this task's output */ override def write(records: Iterator[Product2[K, V]]): Unit = { delegateWriter.write(records) @@ -90,6 +93,7 @@ class BackingUpShuffleWriter[K, V]( private def backupFile( fileToBackUp: File, backupFileRequest: BlockTransferMessage) { + backupShuffleDataIO.writeSupport() val dataFileBuffer = new FileSegmentManagedBuffer( transportConf, fileToBackUp, 0, fileToBackUp.length()) val uploadBackupRequestBuffer = new NioManagedBuffer(backupFileRequest.toByteBuffer) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ExternalFallbackShuffleClient.scala b/core/src/main/scala/org/apache/spark/shuffle/external/ExternalFallbackShuffleClient.scala similarity index 98% rename from core/src/main/scala/org/apache/spark/shuffle/ExternalFallbackShuffleClient.scala rename to core/src/main/scala/org/apache/spark/shuffle/external/ExternalFallbackShuffleClient.scala index 1b0a6941b99d0..373b8a3060d33 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ExternalFallbackShuffleClient.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/external/ExternalFallbackShuffleClient.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.shuffle +package org.apache.spark.shuffle.external import org.apache.spark.network.BlockTransferService import org.apache.spark.network.shuffle._ diff --git a/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleDataIO.scala b/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleDataIO.scala index 68fb80c4e8010..fd4b35e42ed79 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleDataIO.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleDataIO.scala @@ -17,9 +17,7 @@ package org.apache.spark.shuffle.external -private[spark] - -trait ShuffleDataIO { +private[spark] trait ShuffleDataIO { def writeSupport(): ShuffleWriteSupport def readSupport(): ShuffleReadSupport } diff --git a/core/src/main/scala/org/apache/spark/shuffle/external/ShufflePartitionReader.scala b/core/src/main/scala/org/apache/spark/shuffle/external/ShufflePartitionReader.scala index 354f452090237..44a1c74de5458 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/external/ShufflePartitionReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/external/ShufflePartitionReader.scala @@ -19,9 +19,7 @@ package org.apache.spark.shuffle.external import java.io.InputStream -private[spark] - -// TODO: Support batch-fetch -trait ShufflePartitionReader { - def fetchPartition(reduceId: Int): InputStream +// TODO: Support batch +private[spark] trait ShufflePartitionReader { + def fetchPartition(reduceId: Long): InputStream } diff --git a/core/src/main/scala/org/apache/spark/shuffle/external/ShufflePartitionWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/external/ShufflePartitionWriter.scala index 33b9f2f8428d2..0250eb4659dde 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/external/ShufflePartitionWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/external/ShufflePartitionWriter.scala @@ -19,11 +19,7 @@ package org.apache.spark.shuffle.external import java.io.{Closeable, InputStream} -private[spark] - -trait ShufflePartitionWriter extends Closeable { - // Reduce ID == PartitionID ? 
- def appendPartition(partitionId: Int, partitionInput: InputStream): Unit - +private[spark] trait ShufflePartitionWriter extends Closeable { + def appendPartition(partitionId: Long, partitionInput: InputStream): Unit def abort(exception: Throwable): Unit } diff --git a/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleReadSupport.scala index 01d54c9953aab..87772d73d7211 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleReadSupport.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleReadSupport.scala @@ -17,8 +17,6 @@ package org.apache.spark.shuffle.external -private[spark] - -trait ShuffleReadSupport { +private[spark] trait ShuffleReadSupport { def newPartitionReader(appId: String, shuffleId: Int, mapId: Int): ShufflePartitionReader } diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProvider.scala b/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleServiceAddressProvider.scala similarity index 96% rename from core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProvider.scala rename to core/src/main/scala/org/apache/spark/shuffle/external/ShuffleServiceAddressProvider.scala index eb86eca47e538..8f5f7634f8161 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProvider.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleServiceAddressProvider.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.shuffle +package org.apache.spark.shuffle.external trait ShuffleServiceAddressProvider { diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProviderFactory.scala b/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleServiceAddressProviderFactory.scala similarity index 95% rename from core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProviderFactory.scala rename to core/src/main/scala/org/apache/spark/shuffle/external/ShuffleServiceAddressProviderFactory.scala index abfe7dc156b25..7379d9948cce9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProviderFactory.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleServiceAddressProviderFactory.scala @@ -14,7 +14,8 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.shuffle + +package org.apache.spark.shuffle.external import org.apache.spark.SparkConf diff --git a/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleWriteSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleWriteSupport.scala index f985cc1853142..ffc0a75be4764 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleWriteSupport.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleWriteSupport.scala @@ -17,8 +17,6 @@ package org.apache.spark.shuffle.external -private[spark] - -trait ShuffleWriteSupport { +private[spark] trait ShuffleWriteSupport { def newPartitionWriter(appId: String, shuffleId: Int, mapId: Int): ShufflePartitionWriter } diff --git a/core/src/main/scala/org/apache/spark/shuffle/external/default/DefaultShuffleDataIO.scala b/core/src/main/scala/org/apache/spark/shuffle/external/default/DefaultShuffleDataIO.scala new file mode 100644 index 0000000000000..676c2e083f396 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/external/default/DefaultShuffleDataIO.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.external.default + +import java.io.InputStream + +import org.apache.spark.shuffle.external._ + +// Default ShuffleDataIO, declared as a class (not an object) so call sites can instantiate it with +// `new DefaultShuffleDataIO()`. The concrete read and write paths are left unimplemented (???) in +// this WIP patch. +private[spark] class DefaultShuffleDataIO extends ShuffleDataIO { + override def writeSupport(): ShuffleWriteSupport = new ShuffleWriteSupport { + override def newPartitionWriter(appId: String, shuffleId: Int, mapId: Int): ShufflePartitionWriter = + new ShufflePartitionWriter { + override def appendPartition(partitionId: Long, partitionInput: InputStream): Unit = ??? + + override def abort(exception: Throwable): Unit = ??? + + override def close(): Unit = ??? + } + } + + override def readSupport(): ShuffleReadSupport = new ShuffleReadSupport { + override def newPartitionReader(appId: String, shuffleId: Int, mapId: Int): ShufflePartitionReader = + new ShufflePartitionReader { + override def fetchPartition(reduceId: Long): InputStream = ??? 
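        // Editor's illustration (not part of this patch): a caller obtains one reader per
        // (appId, shuffleId, mapId) and streams a single reduce partition at a time, e.g.
        //   dataIo.readSupport().newPartitionReader(appId, shuffleId, mapId).fetchPartition(reduceId)
        // where `dataIo` stands for whichever ShuffleDataIO implementation is configured.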
+ } + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index f50f9b18e36e2..7c7f0d89958ae 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -21,13 +21,13 @@ import java.net.URI import java.util.concurrent.ConcurrentHashMap import scala.util.Random - import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.network.TransportContext import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.server.NoOpRpcHandler import org.apache.spark.shuffle._ +import org.apache.spark.shuffle.external.BackingUpShuffleWriter import org.apache.spark.util.ThreadUtils /** @@ -178,7 +178,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager addressAndClient._3.createClient( addressAndClient._2.getHost, addressAndClient._2.getPort) new BackingUpShuffleWriter( - shuffleBlockResolver, + new IndexShuffleBlockResolver(conf, shuffleDataIO = new DefaultShuffleDataIO()), baseWriter, transportClient, backupShuffleTransportConf, @@ -186,6 +186,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager ThreadUtils.newDaemonCachedThreadPool("backup-shuffle-files"), addressAndClient._1._1, addressAndClient._1._2, + new DefaultShuffleDataIO(), conf.getAppId, env.blockManager.blockManagerId.executorId, handle.shuffleId, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 2994911cdcef3..b4a214de81fb5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -18,13 +18,14 @@ package org.apache.spark.storage import java.io._ -import java.lang.ref.{ReferenceQueue => JReferenceQueue, WeakReference} +import java.lang.ref.{WeakReference, ReferenceQueue => JReferenceQueue} import java.nio.ByteBuffer import java.nio.channels.Channels import java.util.Collections import java.util.concurrent.ConcurrentHashMap import com.codahale.metrics.{MetricRegistry, MetricSet} + import scala.collection.mutable import scala.collection.mutable.HashMap import scala.concurrent.{ExecutionContext, Future} @@ -32,10 +33,9 @@ import scala.concurrent.duration._ import scala.reflect.ClassTag import scala.util.Random import scala.util.control.NonFatal - import org.apache.spark._ import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} -import org.apache.spark.internal.{config, Logging} +import org.apache.spark.internal.{Logging, config} import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.metrics.source.Source import org.apache.spark.network._ @@ -48,7 +48,8 @@ import org.apache.spark.network.util.TransportConf import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.serializer.{SerializerInstance, SerializerManager} -import org.apache.spark.shuffle.{ExternalFallbackShuffleClient, ShuffleManager} +import org.apache.spark.shuffle.ShuffleManager +import org.apache.spark.shuffle.external.{ExternalFallbackShuffleClient, ShuffleDataIO} import org.apache.spark.storage.memory._ import org.apache.spark.unsafe.Platform import org.apache.spark.util._ @@ -126,11 +127,16 @@ private[spark] class BlockManager( shuffleManager: 
ShuffleManager, val blockTransferService: BlockTransferService, securityManager: SecurityManager, - numUsableCores: Int) + numUsableCores: Int, + shuffleDataIO: ShuffleDataIO = null) extends BlockDataManager with BlockEvictionHandler with Logging { private[spark] val externalShuffleServiceEnabled = conf.get(config.SHUFFLE_SERVICE_ENABLED) + + private[spark] val readFromRemote = + conf.get(config.SHUFFLE_REMOTE_READ_OVERRIDE) + private val remoteReadNioBufferConversion = conf.getBoolean("spark.network.remoteReadNioBufferConversion", false) @@ -141,6 +147,13 @@ private[spark] class BlockManager( new DiskBlockManager(conf, deleteFilesOnStop) } + val blockMapper: BlockMapper = + if (readFromRemote) { + new RemoteBlockManager(conf, shuffleDataIO.readSupport()) + } else { + diskBlockManager + } + // Visible for testing private[storage] val blockInfoManager = new BlockInfoManager @@ -150,7 +163,7 @@ private[spark] class BlockManager( // Actual storage of where blocks are kept private[spark] val memoryStore = new MemoryStore(conf, blockInfoManager, serializerManager, memoryManager, this) - private[spark] val diskStore = new DiskStore(conf, diskBlockManager, securityManager) + private[spark] val blockStore = new BlockStore(conf, blockMapper, securityManager) memoryManager.setMemoryStore(memoryStore) // Note: depending on the memory manager, `maxMemory` may actually vary over time. @@ -324,6 +337,8 @@ private[spark] class BlockManager( private def registerWithExternalShuffleServer() { logInfo("Registering executor with local external shuffle service.") val shuffleConfig = new ExecutorShuffleInfo( + backupShuffleServiceAddresses.map(_._1).toArray, + backupShuffleServiceAddresses.map(_._2.toString).toArray, diskBlockManager.localDirs.map(_.toString), diskBlockManager.subDirsPerLocalDir, shuffleManager.getClass.getName) @@ -543,7 +558,7 @@ private[spark] class BlockManager( def getStatus(blockId: BlockId): Option[BlockStatus] = { blockInfoManager.get(blockId).map { info => val memSize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L - val diskSize = if (diskStore.contains(blockId)) diskStore.getSize(blockId) else 0L + val diskSize = if (blockStore.contains(blockId)) blockStore.getSize(blockId) else 0L BlockStatus(info.level, memSize = memSize, diskSize = diskSize) } } @@ -611,7 +626,7 @@ private[spark] class BlockManager( BlockStatus.empty case level => val inMem = level.useMemory && memoryStore.contains(blockId) - val onDisk = level.useDisk && diskStore.contains(blockId) + val onDisk = level.useDisk && blockStore.contains(blockId) val deserialized = if (inMem) level.deserialized else false val replication = if (inMem || onDisk) level.replication else 1 val storageLevel = StorageLevel( @@ -621,7 +636,7 @@ private[spark] class BlockManager( deserialized = deserialized, replication = replication) val memSize = if (inMem) memoryStore.getSize(blockId) else 0L - val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L + val diskSize = if (onDisk) blockStore.getSize(blockId) else 0L BlockStatus(storageLevel, memSize, diskSize) } } @@ -675,8 +690,8 @@ private[spark] class BlockManager( releaseLock(blockId, taskAttemptId) }) Some(new BlockResult(ci, DataReadMethod.Memory, info.size)) - } else if (level.useDisk && diskStore.contains(blockId)) { - val diskData = diskStore.getBytes(blockId) + } else if (level.useDisk && blockStore.contains(blockId)) { + val diskData = blockStore.getBytes(blockId) val iterToReturn: Iterator[Any] = { if (level.deserialized) { val diskValues = 
serializerManager.dataDeserializeStream( @@ -732,12 +747,12 @@ private[spark] class BlockManager( // serializing in-memory objects, and, finally, throw an exception if the block does not exist. if (level.deserialized) { // Try to avoid expensive serialization by reading a pre-serialized copy from disk: - if (level.useDisk && diskStore.contains(blockId)) { + if (level.useDisk && blockStore.contains(blockId)) { // Note: we purposely do not try to put the block back into memory here. Since this branch // handles deserialized blocks, this block may only be cached in memory as objects, not // serialized bytes. Because the caller only requested bytes, it doesn't make sense to // cache the block's deserialized objects since that caching may not have a payoff. - diskStore.getBytes(blockId) + blockStore.getBytes(blockId) } else if (level.useMemory && memoryStore.contains(blockId)) { // The block was not found on disk, so serialize an in-memory copy: new ByteBufferBlockData(serializerManager.dataSerializeWithExplicitClassTag( @@ -748,8 +763,8 @@ private[spark] class BlockManager( } else { // storage level is serialized if (level.useMemory && memoryStore.contains(blockId)) { new ByteBufferBlockData(memoryStore.getBytes(blockId).get, false) - } else if (level.useDisk && diskStore.contains(blockId)) { - val diskData = diskStore.getBytes(blockId) + } else if (level.useDisk && blockStore.contains(blockId)) { + val diskData = blockStore.getBytes(blockId) maybeCacheDiskBytesInMemory(info, blockId, level, diskData) .map(new ByteBufferBlockData(_, false)) .getOrElse(diskData) @@ -1093,10 +1108,10 @@ private[spark] class BlockManager( } if (!putSucceeded && level.useDisk) { logWarning(s"Persisting block $blockId to disk instead.") - diskStore.putBytes(blockId, bytes) + blockStore.putBytes(blockId, bytes) } } else if (level.useDisk) { - diskStore.putBytes(blockId, bytes) + blockStore.putBytes(blockId, bytes) } val putBlockStatus = getCurrentBlockStatus(blockId, info) @@ -1242,11 +1257,11 @@ private[spark] class BlockManager( // Not enough space to unroll this block; drop to disk if applicable if (level.useDisk) { logWarning(s"Persisting block $blockId to disk instead.") - diskStore.put(blockId) { channel => + blockStore.put(blockId) { channel => val out = Channels.newOutputStream(channel) serializerManager.dataSerializeStream(blockId, out, iter)(classTag) } - size = diskStore.getSize(blockId) + size = blockStore.getSize(blockId) } else { iteratorFromFailedMemoryStorePut = Some(iter) } @@ -1259,11 +1274,11 @@ private[spark] class BlockManager( // Not enough space to unroll this block; drop to disk if applicable if (level.useDisk) { logWarning(s"Persisting block $blockId to disk instead.") - diskStore.put(blockId) { channel => + blockStore.put(blockId) { channel => val out = Channels.newOutputStream(channel) partiallySerializedValues.finishWritingToStream(out) } - size = diskStore.getSize(blockId) + size = blockStore.getSize(blockId) } else { iteratorFromFailedMemoryStorePut = Some(partiallySerializedValues.valuesIterator) } @@ -1271,11 +1286,11 @@ private[spark] class BlockManager( } } else if (level.useDisk) { - diskStore.put(blockId) { channel => + blockStore.put(blockId) { channel => val out = Channels.newOutputStream(channel) serializerManager.dataSerializeStream(blockId, out, iterator())(classTag) } - size = diskStore.getSize(blockId) + size = blockStore.getSize(blockId) } val putBlockStatus = getCurrentBlockStatus(blockId, info) @@ -1574,11 +1589,11 @@ private[spark] class BlockManager( val level = 
info.level // Drop to disk, if storage level requires - if (level.useDisk && !diskStore.contains(blockId)) { + if (level.useDisk && !blockStore.contains(blockId)) { logInfo(s"Writing block $blockId to disk") data() match { case Left(elements) => - diskStore.put(blockId) { channel => + blockStore.put(blockId) { channel => val out = Channels.newOutputStream(channel) serializerManager.dataSerializeStream( blockId, @@ -1586,7 +1601,7 @@ private[spark] class BlockManager( elements.toIterator)(info.classTag.asInstanceOf[ClassTag[T]]) } case Right(bytes) => - diskStore.putBytes(blockId, bytes) + blockStore.putBytes(blockId, bytes) } blockIsUpdated = true } @@ -1658,7 +1673,7 @@ private[spark] class BlockManager( private def removeBlockInternal(blockId: BlockId, tellMaster: Boolean): Unit = { // Removals are idempotent in disk store and memory store. At worst, we get a warning. val removedFromMemory = memoryStore.remove(blockId) - val removedFromDisk = diskStore.remove(blockId) + val removedFromDisk = blockStore.remove(blockId) if (!removedFromMemory && !removedFromDisk) { logWarning(s"Block $blockId could not be removed as it was not found on disk or in memory") } @@ -1693,7 +1708,7 @@ private[spark] class BlockManager( shuffleClient.close() } remoteBlockTempFileManager.stop() - diskBlockManager.stop() + blockMapper.stop() rpcEnv.stop(slaveEndpoint) blockInfoManager.clear() memoryStore.clear() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMapper.scala b/core/src/main/scala/org/apache/spark/storage/BlockMapper.scala new file mode 100644 index 0000000000000..094ab8487cb0b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockMapper.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +/** + * Creates and maintains the logical mapping between logical blocks and physical on-disk + * locations. One block is mapped to one file with a name given by its BlockId. 
+ */ +private[spark] trait BlockMapper { + def containsBlock(blockId: BlockId): Boolean +} diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala similarity index 70% rename from core/src/main/scala/org/apache/spark/storage/DiskStore.scala rename to core/src/main/scala/org/apache/spark/storage/BlockStore.scala index 29963a95cb074..db1a4a8c7d2f1 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala @@ -28,7 +28,7 @@ import scala.collection.mutable.ListBuffer import com.google.common.io.Closeables import io.netty.channel.DefaultFileRegion -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.internal.{config, Logging} import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.util.{AbstractFileRegion, JavaUtils} @@ -38,11 +38,11 @@ import org.apache.spark.util.Utils import org.apache.spark.util.io.ChunkedByteBuffer /** - * Stores BlockManager blocks on disk. + * Stores BlockManager blocks. */ -private[spark] class DiskStore( +private[spark] class BlockStore( conf: SparkConf, - diskManager: DiskBlockManager, + blockMapper: BlockMapper, securityManager: SecurityManager) extends Logging { private val minMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m") @@ -60,35 +60,40 @@ private[spark] class DiskStore( if (contains(blockId)) { throw new IllegalStateException(s"Block $blockId is already present in the disk store") } - logDebug(s"Attempting to put block $blockId") - val startTime = System.currentTimeMillis - val file = diskManager.getFile(blockId) - val out = new CountingWritableChannel(openForWrite(file)) - var threwException: Boolean = true - try { - writeFunc(out) - blockSizes.put(blockId, out.getCount) - threwException = false - } finally { - try { - out.close() - } catch { - case ioe: IOException => - if (!threwException) { - threwException = true - throw ioe + blockMapper match { + case _: RemoteBlockManager => + throw new IllegalAccessError("Remote Block Mapper does not support this writing feature") + case d: DiskBlockManager => + logDebug(s"Attempting to put block $blockId") + val startTime = System.currentTimeMillis + val file = d.getFile(blockId) + val out = new CountingWritableChannel(openForWrite(file)) + var threwException: Boolean = true + try { + writeFunc(out) + blockSizes.put(blockId, out.getCount) + threwException = false + } finally { + try { + out.close() + } catch { + case ioe: IOException => + if (!threwException) { + threwException = true + throw ioe + } + } finally { + if (threwException) { + remove(blockId) + } } - } finally { - if (threwException) { - remove(blockId) } - } + val finishTime = System.currentTimeMillis + logDebug("Block %s stored as %s file on disk in %d ms".format( + file.getName, + Utils.bytesToString(file.length()), + finishTime - startTime)) } - val finishTime = System.currentTimeMillis - logDebug("Block %s stored as %s file on disk in %d ms".format( - file.getName, - Utils.bytesToString(file.length()), - finishTime - startTime)) } def putBytes(blockId: BlockId, bytes: ChunkedByteBuffer): Unit = { @@ -98,53 +103,60 @@ private[spark] class DiskStore( } def getBytes(blockId: BlockId): BlockData = { - val file = diskManager.getFile(blockId.name) val blockSize = getSize(blockId) securityManager.getIOEncryptionKey() match { case Some(key) => 
// Encrypted blocks cannot be memory mapped; return a special object that does decryption // and provides InputStream / FileRegion implementations for reading the data. - new EncryptedBlockData(file, blockSize, conf, key) - + blockMapper match { + case d: DiskBlockManager => + new EncryptedBlockData(d.getFile(blockId), blockSize, conf, key) + case r: RemoteBlockManager => + new EncryptedBlockData(null, blockSize, conf, key, r.getInputStream(blockId)) + } case _ => - new DiskBlockData(minMemoryMapBytes, maxMemoryMapBytes, file, blockSize) + blockMapper match { + case d: DiskBlockManager => + new DiskBlockData(minMemoryMapBytes, maxMemoryMapBytes, d.getFile(blockId), blockSize) + case _: RemoteBlockManager => + throw new SparkException("Cannot read a non-encrypted remote block") + } } } def remove(blockId: BlockId): Boolean = { - blockSizes.remove(blockId) - val file = diskManager.getFile(blockId.name) - if (file.exists()) { - val ret = file.delete() - if (!ret) { - logWarning(s"Error deleting ${file.getPath()}") - } - ret - } else { - false + blockMapper match { + case d: DiskBlockManager => + blockSizes.remove(blockId) + d.removeBlock(blockId) + case _: RemoteBlockManager => + throw new IllegalAccessError("Remote Block Mapper does not support removing blocks") } } def contains(blockId: BlockId): Boolean = { - val file = diskManager.getFile(blockId.name) - file.exists() + blockMapper.containsBlock(blockId) } private def openForWrite(file: File): WritableByteChannel = { - val out = new FileOutputStream(file).getChannel() - try { - securityManager.getIOEncryptionKey().map { key => - CryptoStreamUtils.createWritableChannel(out, conf, key) - }.getOrElse(out) - } catch { - case e: Exception => - Closeables.close(out, true) - file.delete() - throw e + blockMapper match { + case _: DiskBlockManager => + val out = new FileOutputStream(file).getChannel() + try { + securityManager.getIOEncryptionKey().map { key => - CryptoStreamUtils.createWritableChannel(out, conf, key) - }.getOrElse(out) + CryptoStreamUtils.createWritableChannel(out, conf, key) + }.getOrElse(out) + } catch { + case e: Exception => + Closeables.close(out, true) + file.delete() + throw e + } + case _: RemoteBlockManager => + throw new IllegalAccessError("Remote Block Mapper does not support this writing feature") } } - } private class DiskBlockData( @@ -156,10 +168,10 @@ private class DiskBlockData( override def toInputStream(): InputStream = new FileInputStream(file) /** - * Returns a Netty-friendly wrapper for the block's data. - * - * Please see `ManagedBuffer.convertToNetty()` for more details. - */ + * Returns a Netty-friendly wrapper for the block's data. + * + * Please see `ManagedBuffer.convertToNetty()` for more details. + */ override def toNetty(): AnyRef = new DefaultFileRegion(file, 0, size) override def toChunkedByteBuffer(allocator: (Int) => ByteBuffer): ChunkedByteBuffer = { @@ -181,7 +193,7 @@ private class DiskBlockData( override def toByteBuffer(): ByteBuffer = { require(blockSize < maxMemoryMapBytes, s"can't create a byte buffer of size $blockSize" + - s" since it exceeds ${Utils.bytesToString(maxMemoryMapBytes)}.") + s" since it exceeds ${Utils.bytesToString(maxMemoryMapBytes)}.") Utils.tryWithResource(open()) { channel => if (blockSize < minMemoryMapBytes) { // For small files, directly read rather than memory map. 
@@ -206,7 +218,8 @@ private[spark] class EncryptedBlockData( file: File, blockSize: Long, conf: SparkConf, - key: Array[Byte]) extends BlockData { + key: Array[Byte], + iStream: InputStream = null) extends BlockData { override def toInputStream(): InputStream = Channels.newInputStream(open()) @@ -254,13 +267,17 @@ private[spark] class EncryptedBlockData( override def dispose(): Unit = { } private def open(): ReadableByteChannel = { - val channel = new FileInputStream(file).getChannel() - try { - CryptoStreamUtils.createReadableChannel(channel, conf, key) - } catch { - case e: Exception => - Closeables.close(channel, true) - throw e + if (iStream != null) { + Channels.newChannel(iStream) + } else { + val channel = new FileInputStream(file).getChannel() + try { + CryptoStreamUtils.createReadableChannel(channel, conf, key) + } catch { + case e: Exception => + Closeables.close(channel, true) + throw e + } } } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index a69bcc9259995..8d765910023fd 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -32,7 +32,8 @@ import org.apache.spark.util.{ShutdownHookManager, Utils} * Block files are hashed among the directories listed in spark.local.dir (or in * SPARK_LOCAL_DIRS, if it's set). */ -private[spark] class DiskBlockManager(conf: SparkConf, deleteFilesOnStop: Boolean) extends Logging { +private[spark] class DiskBlockManager(conf: SparkConf, deleteFilesOnStop: Boolean) + extends Logging with BlockMapper { private[spark] val subDirsPerLocalDir = conf.getInt("spark.diskStore.subDirectories", 64) @@ -80,10 +81,23 @@ private[spark] class DiskBlockManager(conf: SparkConf, deleteFilesOnStop: Boolea def getFile(blockId: BlockId): File = getFile(blockId.name) /** Check if disk block manager has a block. */ - def containsBlock(blockId: BlockId): Boolean = { + override def containsBlock(blockId: BlockId): Boolean = { getFile(blockId.name).exists() } + def removeBlock(blockId: BlockId): Boolean = { + val file = getFile(blockId) + if (file.exists()) { + val ret = file.delete() + if (!ret) { + logWarning(s"Error deleting ${file.getPath()}") + } + ret + } else { + false + } + } + /** List all the files currently stored on disk by the disk manager. */ def getAllFiles(): Seq[File] = { // Get all the files inside the array of array of directories diff --git a/core/src/main/scala/org/apache/spark/storage/RemoteBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/RemoteBlockManager.scala new file mode 100644 index 0000000000000..7bff63248d635 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/RemoteBlockManager.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.io.InputStream + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.internal.Logging +import org.apache.spark.shuffle.external.ShuffleReadSupport + /** + * Maps logical shuffle blocks to data served by a remote shuffle service. Instead of + * resolving a block to a local on-disk file, each shuffle block is fetched through the + * pluggable ShuffleReadSupport API. + */ +private[spark] class RemoteBlockManager( + conf: SparkConf, + shuffleReadSupport: ShuffleReadSupport) + extends Logging with BlockMapper { + + def getInputStream(blockId: BlockId): InputStream = { + blockId match { + case ShuffleBlockId(shufId, mapId, reduceId) => + val reader = shuffleReadSupport.newPartitionReader( + conf.getAppId, shufId, mapId) + reader.fetchPartition(reduceId) + case _ => + throw new SparkException( + "Failed to get block " + blockId + ", which is not a shuffle block") + } + } + + override def containsBlock(blockId: BlockId): Boolean = + blockId match { + case ShuffleBlockId(shufId, mapId, reduceId) => + val reader = shuffleReadSupport.newPartitionReader( + conf.getAppId, shufId, mapId) + reader.fetchPartition(reduceId).available() > 0 + case _ => + throw new SparkException( + "Failed to get block " + blockId + ", which is not a shuffle block") + } + } \ No newline at end of file diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index e5e97a496b944..119a443a15e15 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -20,12 +20,12 @@ package org.apache.spark import scala.collection.mutable.ArrayBuffer import org.mockito.Matchers.any import org.mockito.Mockito._ - import org.apache.spark.LocalSparkContext._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} -import org.apache.spark.shuffle.{DefaultShuffleServiceAddressProvider, FetchFailedException} +import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.shuffle.external.DefaultShuffleServiceAddressProvider import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} class MapOutputTrackerSuite extends SparkFunSuite { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index adad66e788ce7..eeec7f684c668 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -26,14 +26,14 @@ import scala.language.reflectiveCalls import scala.util.control.NonFatal import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.SpanSugar._ - import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.executor.ExecutorMetrics import org.apache.spark.internal.config import org.apache.spark.rdd.{DeterministicLevel, RDD} import org.apache.spark.scheduler.SchedulingMode.SchedulingMode -import org.apache.spark.shuffle.{DefaultShuffleServiceAddressProvider, FetchFailedException, 
MetadataFetchFailedException} +import org.apache.spark.shuffle.external.DefaultShuffleServiceAddressProvider +import org.apache.spark.shuffle.{FetchFailedException, MetadataFetchFailedException} import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, CallSite, LongAccumulator, Utils} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index a176ca100807f..6f81a3c7482a2 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -26,7 +26,6 @@ import scala.language.postfixOps import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ - import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.internal.Logging @@ -37,7 +36,7 @@ import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{KryoSerializer, SerializerManager} -import org.apache.spark.shuffle.DefaultShuffleServiceAddressProvider +import org.apache.spark.shuffle.external.DefaultShuffleServiceAddressProvider import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage.StorageLevel._ import org.apache.spark.util.Utils diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 692563cc8192f..55c66b8e04200 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -31,7 +31,6 @@ import org.mockito.Mockito.{mock, times, verify, when} import org.scalatest._ import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.concurrent.Eventually._ - import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.executor.DataReadMethod @@ -49,7 +48,7 @@ import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager} -import org.apache.spark.shuffle.DefaultShuffleServiceAddressProvider +import org.apache.spark.shuffle.external.DefaultShuffleServiceAddressProvider import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util._ @@ -1018,7 +1017,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE case _ => fail("Updated block is neither list2 nor list4") } } - assert(store.diskStore.contains("list2"), "list2 was not in disk store") + assert(store.blockStore.contains("list2"), "list2 was not in disk store") assert(store.memoryStore.contains("list4"), "list4 was not in memory store") // No updated blocks - list5 is too big to fit in store and nothing is kicked out @@ -1036,11 +1035,11 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!store.memoryStore.contains("list5"), "list5 was in memory store") // disk store contains only list2 - 
assert(!store.diskStore.contains("list1"), "list1 was in disk store") - assert(store.diskStore.contains("list2"), "list2 was not in disk store") - assert(!store.diskStore.contains("list3"), "list3 was in disk store") - assert(!store.diskStore.contains("list4"), "list4 was in disk store") - assert(!store.diskStore.contains("list5"), "list5 was in disk store") + assert(!store.blockStore.contains("list1"), "list1 was in disk store") + assert(store.blockStore.contains("list2"), "list2 was not in disk store") + assert(!store.blockStore.contains("list3"), "list3 was in disk store") + assert(!store.blockStore.contains("list4"), "list4 was in disk store") + assert(!store.blockStore.contains("list5"), "list5 was in disk store") // remove block - list2 should be removed from disk val updatedBlocks6 = getUpdatedBlocks { @@ -1050,7 +1049,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(updatedBlocks6.size === 1) assert(updatedBlocks6.head._1 === TestBlockId("list2")) assert(updatedBlocks6.head._2.storageLevel == StorageLevel.NONE) - assert(!store.diskStore.contains("list2"), "list2 was in disk store") + assert(!store.blockStore.contains("list2"), "list2 was in disk store") } test("query block statuses") { @@ -1159,7 +1158,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE test("safely unroll blocks through putIterator (disk)") { store = makeBlockManager(12000) val memoryStore = store.memoryStore - val diskStore = store.diskStore + val diskStore = store.blockStore val smallList = List.fill(40)(new Array[Byte](100)) val bigList = List.fill(40)(new Array[Byte](1000)) def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] diff --git a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala index eec961a491101..461716585f49f 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala @@ -48,13 +48,13 @@ class DiskStoreSuite extends SparkFunSuite { val blockId = BlockId("rdd_1_2") val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) - val diskStoreMapped = new DiskStore(conf.clone().set(confKey, "0"), diskBlockManager, + val diskStoreMapped = new BlockStore(conf.clone().set(confKey, "0"), diskBlockManager, securityManager) diskStoreMapped.putBytes(blockId, byteBuffer) val mapped = diskStoreMapped.getBytes(blockId).toByteBuffer() assert(diskStoreMapped.remove(blockId)) - val diskStoreNotMapped = new DiskStore(conf.clone().set(confKey, "1m"), diskBlockManager, + val diskStoreNotMapped = new BlockStore(conf.clone().set(confKey, "1m"), diskBlockManager, securityManager) diskStoreNotMapped.putBytes(blockId, byteBuffer) val notMapped = diskStoreNotMapped.getBytes(blockId).toByteBuffer() @@ -78,7 +78,7 @@ class DiskStoreSuite extends SparkFunSuite { test("block size tracking") { val conf = new SparkConf() val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) - val diskStore = new DiskStore(conf, diskBlockManager, new SecurityManager(conf)) + val diskStore = new BlockStore(conf, diskBlockManager, new SecurityManager(conf)) val blockId = BlockId("rdd_1_2") diskStore.put(blockId) { chan => @@ -97,7 +97,7 @@ class DiskStoreSuite extends SparkFunSuite { val conf = new SparkConf() .set(config.MEMORY_MAP_LIMIT_FOR_TESTS.key, "10k") val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) 
- val diskStore = new DiskStore(conf, diskBlockManager, new SecurityManager(conf)) + val diskStore = new BlockStore(conf, diskBlockManager, new SecurityManager(conf)) val blockId = BlockId("rdd_1_2") diskStore.put(blockId) { chan => @@ -139,7 +139,7 @@ class DiskStoreSuite extends SparkFunSuite { val conf = new SparkConf() val securityManager = new SecurityManager(conf, Some(CryptoStreamUtils.createKey(conf))) val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) - val diskStore = new DiskStore(conf, diskBlockManager, securityManager) + val diskStore = new BlockStore(conf, diskBlockManager, securityManager) val blockId = BlockId("rdd_1_2") diskStore.put(blockId) { chan => diff --git a/resource-managers/kubernetes/core/src/main/resources/META-INF/services/org.apache.spark.shuffle.ShuffleServiceAddressProviderFactory b/resource-managers/kubernetes/core/src/main/resources/META-INF/services/org.apache.spark.shuffle.external.ShuffleServiceAddressProviderFactory similarity index 100% rename from resource-managers/kubernetes/core/src/main/resources/META-INF/services/org.apache.spark.shuffle.ShuffleServiceAddressProviderFactory rename to resource-managers/kubernetes/core/src/main/resources/META-INF/services/org.apache.spark.shuffle.external.ShuffleServiceAddressProviderFactory diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala index 00adbc6426a28..959b77f40a14c 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala @@ -21,12 +21,12 @@ import java.util.concurrent.locks.ReentrantReadWriteLock import io.fabric8.kubernetes.api.model.Pod import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watch, Watcher} + import scala.collection.JavaConverters._ import scala.collection.mutable - import org.apache.spark.internal.Logging import org.apache.spark.scheduler.cluster.k8s._ -import org.apache.spark.shuffle.ShuffleServiceAddressProvider +import org.apache.spark.shuffle.external.ShuffleServiceAddressProvider import org.apache.spark.util.Utils class KubernetesShuffleServiceAddressProvider( diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProviderFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProviderFactory.scala index 32f56ee6b4396..5bd0fc410f4ba 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProviderFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProviderFactory.scala @@ -19,7 +19,8 @@ package org.apache.spark.shuffle.k8s import org.apache.spark.SparkConf import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory -import org.apache.spark.shuffle.{DefaultShuffleServiceAddressProvider, ShuffleServiceAddressProvider, ShuffleServiceAddressProviderFactory} +import org.apache.spark.shuffle.external.{DefaultShuffleServiceAddressProvider, ShuffleServiceAddressProvider, 
ShuffleServiceAddressProviderFactory} +import org.apache.spark.shuffle.DefaultShuffleServiceAddressProvider import org.apache.spark.util.ThreadUtils class KubernetesShuffleServiceAddressProviderFactory extends ShuffleServiceAddressProviderFactory { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 232e1c9828e5b..324f7e2e16069 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -27,7 +27,6 @@ import scala.reflect.ClassTag import org.apache.hadoop.conf.Configuration import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ - import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.internal.Logging @@ -38,7 +37,7 @@ import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.serializer.{KryoSerializer, SerializerManager} -import org.apache.spark.shuffle.DefaultShuffleServiceAddressProvider +import org.apache.spark.shuffle.external.DefaultShuffleServiceAddressProvider import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage._ import org.apache.spark.streaming.receiver._ From f59663f439cc621a956e4002e9b6e3dc7758e41d Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Fri, 7 Dec 2018 12:59:17 -0800 Subject: [PATCH 4/4] musings --- .../shuffle/ExternalShuffleBlockResolver.java | 58 --------- .../shuffle/ExternalShuffleClient.java | 7 +- .../spark/network/shuffle/ShuffleClient.java | 2 +- .../shuffle/protocol/ExecutorShuffleInfo.java | 28 +---- .../RegisterExecutorForBackupsOnly.java | 21 +++- .../shuffle/sort/UnsafeShuffleWriter.java | 59 +++++++++ .../org/apache/spark/MapOutputTracker.scala | 78 ++++-------- .../spark/network/BlockTransferService.scala | 6 +- .../netty/NettyBlockTransferService.scala | 2 +- .../shuffle/BlockStoreShuffleReader.scala | 69 +++++++--- .../shuffle/IndexShuffleBlockResolver.scala | 18 +-- .../external/BackingUpShuffleWriter.scala | 118 ------------------ ...cala => ExternalRemoteShuffleClient.scala} | 16 +-- .../external/ShufflePartitionWriter.scala | 5 +- .../shuffle/sort/SortShuffleManager.scala | 41 +++--- .../shuffle/sort/SortShuffleWriter.scala | 10 +- .../apache/spark/storage/BlockManager.scala | 47 ++++--- .../apache/spark/storage/BlockManagerId.scala | 12 +- .../org/apache/spark/storage/BlockStore.scala | 2 +- .../storage/RemoteBlockObjectWriter.scala | 106 ++++++++++++++++ .../storage/ShuffleBlockFetcherIterator.scala | 4 +- .../util/collection/ExternalSorter.scala | 7 +- .../WritablePartitionedPairCollection.scala | 2 + .../org/apache/spark/DistributedSuite.scala | 2 +- .../spark/storage/BlockManagerSuite.scala | 6 +- 25 files changed, 354 insertions(+), 372 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/shuffle/external/BackingUpShuffleWriter.scala rename core/src/main/scala/org/apache/spark/shuffle/external/{ExternalFallbackShuffleClient.scala => ExternalRemoteShuffleClient.scala} (83%) create mode 100644 core/src/main/scala/org/apache/spark/storage/RemoteBlockObjectWriter.scala diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java 
b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 8d9f3ecf0fcbb..4304a421e7167 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -103,7 +103,6 @@ public class ExternalShuffleBlockResolver { "org.apache.spark.shuffle.sort.SortShuffleManager", "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager"); - private final File shuffleBackupsDir; public ExternalShuffleBlockResolver(TransportConf conf, File registeredExecutorFile) throws IOException { @@ -141,8 +140,6 @@ public int weigh(File file, ShuffleIndexInformation indexInfo) { } else { executors = Maps.newConcurrentMap(); } - this.backupExecutors = Maps.newConcurrentMap(); - this.shuffleBackupsDir = Files.createTempDirectory("spark-shuffle-backups").toFile(); this.directoryCleaner = directoryCleaner; } @@ -179,61 +176,6 @@ public void registerExecutor( executors.put(fullId, executorInfo); } - public void registerExecutorForBackups(String appId, String execId, String shuffleManager) { - AppExecId fullId = new AppExecId(appId, execId); - if (executors.containsKey(fullId)) { - throw new UnsupportedOperationException( - String.format( - "Executor %s cannot be registered for both primary shuffle management and backup" + - " shuffle management.", fullId)); - } - File executorBackupDir = Paths.get( - shuffleBackupsDir.getAbsolutePath(), appId, execId).toFile(); - if (!executorBackupDir.mkdirs()) { - throw new RuntimeException( - String.format( - "Failed to create directories for executor backup shuffle files at %s.", - executorBackupDir.getAbsolutePath())); - } - if (!knownManagers.contains(shuffleManager)) { - throw new UnsupportedOperationException( - String.format( - "Unsupported shuffle manager of executor: %s.", fullId)); - } - - // TODO: Finish - ExecutorShuffleInfo backupShuffleInfo = new ExecutorShuffleInfo( - new String[] { }, - new String[] { }, - new String[] { executorBackupDir.getAbsolutePath() }, - 1, - shuffleManager); - logger.info("Registering executor {} with {} for backups.", fullId, backupShuffleInfo); - backupExecutors.put(fullId, backupShuffleInfo); - } - - public StreamCallbackWithID openShuffleFileForBackup( - String appId, String execId, int shuffleId, int mapId) { - return getFileWriterStreamCallback( - appId, - execId, - shuffleId, - mapId, - "data", - FileWriterStreamCallback.BackupFileType.DATA); - } - - public StreamCallbackWithID openShuffleIndexFileForBackup( - String appId, String execId, int shuffleId, int mapId) { - return getFileWriterStreamCallback( - appId, - execId, - shuffleId, - mapId, - "index", - FileWriterStreamCallback.BackupFileType.INDEX); - } - private StreamCallbackWithID getFileWriterStreamCallback( String appId, String execId, diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 531a6c146e6a7..06a8b51f85190 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -92,7 +92,7 @@ public void fetchBlocks( int port, String execId, String[] blockIds, - boolean isBackup, + boolean isRemote, BlockFetchingListener listener, DownloadFileManager 
downloadFileManager) { checkInit(); @@ -148,7 +148,8 @@ public void registerWithShuffleServer( } } - public void registerWithShuffleServerForBackups( + public void registerWithRemoteShuffleServer( + String driverHostPort, String host, int port, String execId, @@ -156,7 +157,7 @@ public void registerWithShuffleServerForBackups( checkInit(); try (TransportClient client = clientFactory.createUnmanagedClient(host, port)) { ByteBuffer registerMessage = new RegisterExecutorForBackupsOnly( - appId, execId, shuffleManager).toByteBuffer(); + driverHostPort, appId, execId, shuffleManager).toByteBuffer(); client.sendRpcSync(registerMessage, registrationTimeoutMs); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java index 5263e38d32a8e..5f42cefab9080 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -53,7 +53,7 @@ public abstract void fetchBlocks( int port, String execId, String[] blockIds, - boolean isBackup, + boolean isRemote, BlockFetchingListener listener, DownloadFileManager downloadFileManager); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java index c3e974dd6d773..8daaf6b772b64 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java @@ -29,10 +29,6 @@ /** Contains all configuration necessary for locating the shuffle files of an executor. */ public class ExecutorShuffleInfo implements Encodable { - /** The set of remote hosts that the executor stores its back shuffle files in. */ - public final String[] backupHosts; - /** The set of remote ports that the executor stores its back shuffle files in. */ - public final String[] backupPorts; /** The base set of local directories that the executor stores its shuffle files in. */ public final String[] localDirs; /** Number of subdirectories created within each localDir. 
*/ @@ -42,14 +38,9 @@ public class ExecutorShuffleInfo implements Encodable { @JsonCreator public ExecutorShuffleInfo( - @JsonProperty("backupHosts") String [] backupHosts, - @JsonProperty("backupPorts") String[] backupPorts, @JsonProperty("localDirs") String[] localDirs, @JsonProperty("subDirsPerLocalDir") int subDirsPerLocalDir, @JsonProperty("shuffleManager") String shuffleManager) { - this.backupHosts = backupHosts; - this.backupPorts = backupPorts; - assert(backupHosts.length == backupPorts.length); this.localDirs = localDirs; this.subDirsPerLocalDir = subDirsPerLocalDir; this.shuffleManager = shuffleManager; @@ -57,15 +48,12 @@ public ExecutorShuffleInfo( @Override public int hashCode() { - return Objects.hashCode(subDirsPerLocalDir, shuffleManager) * 41 + Arrays.hashCode(localDirs) - + Arrays.hashCode(backupHosts) + Arrays.hashCode(backupPorts); + return Objects.hashCode(subDirsPerLocalDir, shuffleManager) * 41 + Arrays.hashCode(localDirs); } @Override public String toString() { return Objects.toStringHelper(this) - .add("backupHosts", Arrays.toString(backupHosts)) - .add("backupPorts", Arrays.toString(backupPorts)) .add("localDirs", Arrays.toString(localDirs)) .add("subDirsPerLocalDir", subDirsPerLocalDir) .add("shuffleManager", shuffleManager) @@ -76,9 +64,7 @@ public String toString() { public boolean equals(Object other) { if (other != null && other instanceof ExecutorShuffleInfo) { ExecutorShuffleInfo o = (ExecutorShuffleInfo) other; - return Arrays.equals(backupHosts, o.backupHosts) - && Arrays.equals(backupPorts, o.backupPorts) - && Arrays.equals(localDirs, o.localDirs) + return Arrays.equals(localDirs, o.localDirs) && Objects.equal(subDirsPerLocalDir, o.subDirsPerLocalDir) && Objects.equal(shuffleManager, o.shuffleManager); } @@ -87,31 +73,23 @@ public boolean equals(Object other) { @Override public int encodedLength() { - return Encoders.StringArrays.encodedLength(backupHosts) - + Encoders.StringArrays.encodedLength(backupPorts) - + Encoders.StringArrays.encodedLength(localDirs) + return Encoders.StringArrays.encodedLength(localDirs) + 4 // int + Encoders.Strings.encodedLength(shuffleManager); } @Override public void encode(ByteBuf buf) { - Encoders.StringArrays.encode(buf, backupHosts); - Encoders.StringArrays.encode(buf, backupPorts); Encoders.StringArrays.encode(buf, localDirs); buf.writeInt(subDirsPerLocalDir); Encoders.Strings.encode(buf, shuffleManager); } public static ExecutorShuffleInfo decode(ByteBuf buf) { - String[] backupHosts = Encoders.StringArrays.decode(buf); - String[] backupPorts = Encoders.StringArrays.decode(buf); String[] localDirs = Encoders.StringArrays.decode(buf); int subDirsPerLocalDir = buf.readInt(); String shuffleManager = Encoders.Strings.decode(buf); return new ExecutorShuffleInfo( - backupHosts, - backupPorts, localDirs, subDirsPerLocalDir, shuffleManager); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutorForBackupsOnly.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutorForBackupsOnly.java index 3986f9519f81d..7e2806a7b014e 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutorForBackupsOnly.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutorForBackupsOnly.java @@ -6,13 +6,17 @@ import org.apache.spark.network.protocol.Encoders; public class RegisterExecutorForBackupsOnly extends BlockTransferMessage { - + public final String 
driverHostPort; public final String appId; public final String execId; public final String shuffleManager; public RegisterExecutorForBackupsOnly( - String appId, String execId, String shuffleManager) { + String driverHostPort, + String appId, + String execId, + String shuffleManager) { + this.driverHostPort = driverHostPort; this.appId = appId; this.execId = execId; this.shuffleManager = shuffleManager; @@ -25,13 +29,15 @@ protected Type type() { @Override public int encodedLength() { - return Encoders.Strings.encodedLength(appId) + return Encoders.Strings.encodedLength(driverHostPort) + + Encoders.Strings.encodedLength(appId) + Encoders.Strings.encodedLength(execId) + Encoders.Strings.encodedLength(shuffleManager); } @Override public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, driverHostPort); Encoders.Strings.encode(buf, appId); Encoders.Strings.encode(buf, execId); Encoders.Strings.encode(buf, shuffleManager); @@ -41,7 +47,8 @@ public void encode(ByteBuf buf) { public boolean equals(Object other) { if (other instanceof RegisterExecutorForBackupsOnly) { RegisterExecutorForBackupsOnly o = (RegisterExecutorForBackupsOnly) other; - return Objects.equal(appId, o.appId) + return Objects.equal(driverHostPort, o.driverHostPort) + && Objects.equal(appId, o.appId) && Objects.equal(execId, o.execId) && Objects.equal(shuffleManager, o.shuffleManager); } @@ -50,12 +57,13 @@ public boolean equals(Object other) { @Override public int hashCode() { - return Objects.hashCode(appId, execId, shuffleManager); + return Objects.hashCode(driverHostPort, appId, execId, shuffleManager); } @Override public String toString() { return Objects.toStringHelper(RegisterExecutorForBackupsOnly.class) + .add("driverHostPort", driverHostPort) .add("appId", appId) .add("execId", execId) .add("shuffleManager", shuffleManager) @@ -63,9 +71,10 @@ public String toString() { } public static RegisterExecutorForBackupsOnly decode(ByteBuf buf) { + String driverHostPort = Encoders.Strings.decode(buf); String appId = Encoders.Strings.decode(buf); String execId = Encoders.Strings.decode(buf); String shuffleManager = Encoders.Strings.decode(buf); - return new RegisterExecutorForBackupsOnly(appId, execId, shuffleManager); + return new RegisterExecutorForBackupsOnly(driverHostPort, appId, execId, shuffleManager); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 4839d04522f10..5b6f39c37e335 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -424,6 +424,65 @@ private long[] mergeSpillsWithFileStream( return partitionLengths; } + /** + * Merges spill files using the ShufflePartitionWriter API. + */ + private long[] mergeSpillsWithPluggableWriter( + SpillInfo[] spills, + @Nullable CompressionCodec compressionCodec) throws IOException { + assert (spills.length >= 2); + assert(blockManager 
!= null); + final int numPartitions = partitioner.numPartitions(); + final long[] partitionLengths = new long[numPartitions]; + final InputStream[] spillInputStreams = new InputStream[spills.length]; + boolean threwException = true; + try (ShufflePartitionWriter writer = writeSupport.newPartitionWriter( + sparkConf.getAppId(), shuffleId, mapId)) { + try { + for (int i = 0; i < spills.length; i++) { + spillInputStreams[i] = new NioBufferedFileInputStream( + spills[i].file, + inputBufferSizeInBytes); + } + for (int partition = 0; partition < numPartitions; partition++) { + for (int i = 0; i < spills.length; i++) { + final long partitionLengthInSpill = spills[i].partitionLengths[partition]; + if (partitionLengthInSpill > 0) { + InputStream partitionInputStream = new LimitedInputStream(spillInputStreams[i], + partitionLengthInSpill, false); + try { + partitionInputStream = blockManager.serializerManager().wrapForEncryption( + partitionInputStream); + if (compressionCodec != null) { + partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); + } + partitionLengths[partition] = writer.appendPartition(partition, partitionInputStream); + } finally { + partitionInputStream.close(); + } + } + } + } + } catch (Exception e) { + try { + writer.abort(); + } catch (Exception e2) { + logger.warn("Failed to close shuffle writer upon aborting.", e2); + } + } + threwException = false; + } finally { + // To avoid masking exceptions that caused us to prematurely enter the finally block, only + // throw exceptions during cleanup if threwException == false. + for (InputStream stream : spillInputStreams) { + Closeables.close(stream, threwException); + } + } + return partitionLengths; + } + + + /** * Merges spill files by using NIO's transferTo to concatenate spill partitions' bytes. 
* This is only safe when the IO compression codec and serializer support concatenation of diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 9f1c20ff9e92b..925947a0c6fd2 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -211,13 +211,10 @@ private class ShuffleStatus(numPartitions: Int) { } private[spark] sealed trait MapOutputTrackerMessage -private[spark] case class GetMapOutputStatuses(shuffleId: Int, getBackup: Boolean) - extends MapOutputTrackerMessage -private[spark] case class ReportBackedUpMapOutput( - shuffleId: Int, mapId: Int, backedUpStatus: MapStatus) +private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage -private[spark] case object GetBackupShuffleServiceAddresses extends MapOutputTrackerMessage +private[spark] case object GetRemoteShuffleServiceAddresses extends MapOutputTrackerMessage private[spark] sealed trait BackupMessage private[spark] case class HeartbeaterMessage(appId: String) extends BackupMessage @@ -231,29 +228,20 @@ private[spark] class MapOutputTrackerMasterEndpoint( logDebug("init") // force eager creation of logger override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case GetMapOutputStatuses(shuffleId: Int, getBackup: Boolean) => + case GetMapOutputStatuses(shuffleId: Int) => val hostPort = context.senderAddress.hostPort logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) val message = GetMapOutputMessage(shuffleId, context) - if (getBackup) { - tracker.postToBackup[GetMapOutputMessage](message) - } else { - tracker.post[GetMapOutputMessage](message) - } + tracker.post[GetMapOutputMessage](message) - case GetBackupShuffleServiceAddresses => - context.reply(tracker.getBackupShuffleServiceAddresses) + case GetRemoteShuffleServiceAddresses => + context.reply(tracker.getRemoteShuffleServiceAddresses) case StopMapOutputTracker => logInfo("MapOutputTrackerMasterEndpoint stopped!") context.reply(true) stop() } - - override def receive(): PartialFunction[Any, Unit] = { - case ReportBackedUpMapOutput(shuffleId, mapId, backedUpStatus) => - tracker.registerBackupMapOutput(shuffleId, mapId, backedUpStatus) - } } /** @@ -303,7 +291,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging // For testing def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { - getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1, false) + getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1) } /** @@ -316,7 +304,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * describing the shuffle blocks that are stored at that block manager. 
*/ def getMapSizesByExecutorId( - shuffleId: Int, startPartition: Int, endPartition: Int, getBackup: Boolean) + shuffleId: Int, startPartition: Int, endPartition: Int) : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] /** @@ -462,10 +450,6 @@ private[spark] class MapOutputTrackerMaster( shuffleStatuses(shuffleId).addMapOutput(mapId, status) } - def registerBackupMapOutput(shuffleId: Int, mapId: Int, status: MapStatus): Unit = { - backupMaster.foreach(_.registerMapOutput(shuffleId, mapId, status)) - } - /** Unregister map output information of the given shuffle, mapper and block manager */ def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { shuffleStatuses.get(shuffleId) match { @@ -687,27 +671,21 @@ private[spark] class MapOutputTrackerMaster( } } - def getBackupShuffleServiceAddresses: List[(String, Int)] = + def getRemoteShuffleServiceAddresses: List[(String, Int)] = shuffleServiceAddressProvider.getShuffleServiceAddresses() // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. // This method is only called in local-mode. - override def getMapSizesByExecutorId( - shuffleId: Int, startPartition: Int, endPartition: Int, getBackup: Boolean) + override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { - if (getBackup) { - require(backupMaster.isDefined, "Backup master not defined") - backupMaster.get.getMapSizesByExecutorId(shuffleId, startPartition, endPartition, false) - } else { - logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") - shuffleStatuses.get(shuffleId) match { - case Some(shuffleStatus) => - shuffleStatus.withMapStatuses { statuses => - MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) - } - case None => - Iterator.empty - } + logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") + shuffleStatuses.get(shuffleId) match { + case Some(shuffleStatus) => + shuffleStatus.withMapStatuses { statuses => + MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) + } + case None => + Iterator.empty } } @@ -731,32 +709,21 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr val mapStatuses: Map[Int, Array[MapStatus]] = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala - val backupMapStatuses: Map[Int, Array[MapStatus]] = - new ConcurrentHashMap[Int, Array[MapStatus]]().asScala - /** Remembers which map output locations are currently being fetched on an executor. */ private val fetching = new HashSet[Int] - /** Remembers which backup map output locations are currently being fetched on an executor. */ - private val fetchingBackup = new HashSet[Int] - // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. 
override def getMapSizesByExecutorId( - shuffleId: Int, startPartition: Int, endPartition: Int, getBackup: Boolean) + shuffleId: Int, startPartition: Int, endPartition: Int) : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") - val statuses = if (getBackup) { - getStatuses(shuffleId, backupMapStatuses, fetchingBackup, getBackup) - } else { - getStatuses(shuffleId, mapStatuses, fetching, getBackup) - } + val statuses = getStatuses(shuffleId, mapStatuses, fetching) try { MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) } catch { case e: MetadataFetchFailedException => // We experienced a fetch failure so our mapStatuses cache is outdated; clear it: mapStatuses.clear() - backupMapStatuses.clear() throw e } } @@ -770,8 +737,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr private def getStatuses( shuffleId: Int, statusesToInspect: Map[Int, Array[MapStatus]], - statusesBeingFetched: mutable.HashSet[Int], - getBackup: Boolean) + statusesBeingFetched: mutable.HashSet[Int]) : Array[MapStatus] = { val statuses = statusesToInspect.get(shuffleId).orNull if (statuses == null) { @@ -802,7 +768,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) // This try-finally prevents hangs due to timeouts: try { - val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId, getBackup)) + val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) logInfo("Got the output locations") statusesToInspect.put(shuffleId, fetchedStatuses) diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index ac59119a36230..df44cc26a8cc2 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -67,7 +67,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo port: Int, execId: String, blockIds: Array[String], - isBackup: Boolean, + isRemote: Boolean, listener: BlockFetchingListener, tempFileManager: DownloadFileManager): Unit @@ -93,11 +93,11 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo port: Int, execId: String, blockId: String, - isBackup: Boolean, + isRemote: Boolean, tempFileManager: DownloadFileManager): ManagedBuffer = { // A monitor for the thread to wait on. 
val result = Promise[ManagedBuffer]() - fetchBlocks(host, port, execId, Array(blockId), isBackup, + fetchBlocks(host, port, execId, Array(blockId), isRemote, new BlockFetchingListener { override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { result.failure(exception) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index cc68d825fc169..3089b94576063 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -105,7 +105,7 @@ private[spark] class NettyBlockTransferService( port: Int, execId: String, blockIds: Array[String], - isBackup: Boolean, + isRemote: Boolean, listener: BlockFetchingListener, tempFileManager: DownloadFileManager): Unit = { logTrace(s"Fetch blocks from $host:$port (executor id $execId)") diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index b8c20cdceade9..ab2e0c7fec303 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -20,7 +20,8 @@ package org.apache.spark.shuffle import org.apache.spark._ import org.apache.spark.internal.{config, Logging} import org.apache.spark.serializer.SerializerManager -import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} +import org.apache.spark.shuffle.external.ShuffleReadSupport +import org.apache.spark.storage._ import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter @@ -33,6 +34,8 @@ private[spark] class BlockStoreShuffleReader[K, C]( startPartition: Int, endPartition: Int, context: TaskContext, + appId: String, + shuffleReadSupport: ShuffleReadSupport = null, serializerManager: SerializerManager = SparkEnv.get.serializerManager, blockManager: BlockManager = SparkEnv.get.blockManager, mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) @@ -42,20 +45,35 @@ private[spark] class BlockStoreShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - val wrappedStreams = new ShuffleBlockFetcherIterator( - context, - blockManager.shuffleClient, - blockManager, - mapOutputTracker.getMapSizesByExecutorId( - handle.shuffleId, startPartition, endPartition, context.attemptNumber() > 0), - serializerManager.wrapStream, - // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, - SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), - SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), - SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), - SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true)) - + val wrappedStreams = if (shuffleReadSupport != null) { + getBlockIdsGroupedByMapIds(handle.shuffleId, startPartition, endPartition) + .flatMap { case (mapId, blockIds) => + val reader = shuffleReadSupport.newPartitionReader( + appId, handle.shuffleId, mapId) + blockIds.map { + case ShuffleBlockId(_, _, reduceId) => reader.fetchPartition(reduceId) + case ShuffleDataBlockId(_, _, reduceId) => 
reader.fetchPartition(reduceId) + case invalid => + throw new IllegalArgumentException(s"Invalid block id $invalid") + } + } + } else { + val mapSizesByExecId = + mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition) + new ShuffleBlockFetcherIterator( + context, + blockManager.shuffleClient, + blockManager, + mapSizesByExecId, + serializerManager.wrapStream, + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, + SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), + SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), + SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), + SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true), + readMetrics) + } val serializerInstance = dep.serializer.newInstance() // Create a key/value iterator for each stream @@ -121,4 +139,25 @@ private[spark] class BlockStoreShuffleReader[K, C]( new InterruptibleIterator[Product2[K, C]](context, resultIter) } } + private def getBlockIdsGroupedByMapIds( + shuffleId: Int, startPartition: Int, endPartition: Int): Iterator[(Int, Seq[BlockId])] = { + mapOutputTracker.getMapSizesByExecutorId(shuffleId, startPartition, endPartition) + .flatMap(_._2) + .map(_._1) + .toStream + .filter { blockId => + blockId match { + case ShuffleBlockId(_, _, _) => true + case ShuffleDataBlockId(_, _, _) => true + case _ => false + } + } + .groupBy { + case ShuffleBlockId(_, mapId, _) => mapId + case ShuffleDataBlockId(_, mapId, _) => mapId + case blockId => + throw new IllegalArgumentException(s"Invalid block id: $blockId") + }.mapValues(_.toSeq) + .iterator + } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index b611c9bdcc93b..0a5e12c36c444 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -45,7 +45,7 @@ import org.apache.spark.util.Utils private[spark] class IndexShuffleBlockResolver( conf: SparkConf, _blockManager: BlockManager = null, - shuffleDataIO: ShuffleDataIO = null) + _shuffleDataIO: ShuffleDataIO = null) extends ShuffleBlockResolver with Logging { @@ -58,14 +58,18 @@ private[spark] class IndexShuffleBlockResolver( private var shuffleReadSupport: ShuffleReadSupport = _ - private var isExternalFileSystem = false + private var isRemote_ = false if (shuffleDataIO != null) { - shuffleWriteSupport = shuffleDataIO.writeSupport() - shuffleReadSupport = shuffleDataIO.readSupport() - isExternalFileSystem = true + shuffleWriteSupport = _shuffleDataIO.writeSupport() + shuffleReadSupport = _shuffleDataIO.readSupport() + isRemote_ = true } + def shuffleDataIO(): ShuffleDataIO = _shuffleDataIO + def isRemote(): Boolean = isRemote_ + + def getDataFile(shuffleId: Int, mapId: Int): File = { blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) } @@ -234,10 +238,10 @@ private[spark] class IndexShuffleBlockResolver( offset, nextOffset - offset) } finally { - if (isExternalFileSystem) { + if (isRemote()) { val writer = shuffleWriteSupport.newPartitionWriter(appId, blockId.shuffleId, blockId.mapId) try { - writer.appendPartition(blockId.reduceId, in) + writer.appendIndexFile(blockId.reduceId, in) } catch { case e: Exception => writer.abort(e) diff --git 
a/core/src/main/scala/org/apache/spark/shuffle/external/BackingUpShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/external/BackingUpShuffleWriter.scala deleted file mode 100644 index c69e5905604eb..0000000000000 --- a/core/src/main/scala/org/apache/spark/shuffle/external/BackingUpShuffleWriter.scala +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.external - -import java.io.File -import java.nio.ByteBuffer -import java.util.concurrent.ExecutorService - -import com.google.common.util.concurrent.SettableFuture -import scala.concurrent.{ExecutionContext, Future} -import scala.util.{Failure, Success} - -import org.apache.spark.{MapOutputTracker, ReportBackedUpMapOutput} -import org.apache.spark.internal.Logging -import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, NioManagedBuffer} -import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} -import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, UploadShuffleFileStream, UploadShuffleIndexFileStream} -import org.apache.spark.network.util.TransportConf -import org.apache.spark.scheduler.{MapStatus, RelocatedMapStatus} -import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter} -import org.apache.spark.storage.BlockManagerId - -class BackingUpShuffleWriter[K, V]( - shuffleBlockResolver: IndexShuffleBlockResolver, - delegateWriter: ShuffleWriter[K, V], - backupShuffleServiceClient: TransportClient, - transportConf: TransportConf, - mapOutputTracker: MapOutputTracker, - backupExecutor: ExecutorService, - backupHost: String, - backupPort: Int, - appId: String, - execId: String, - shuffleId: Int, - mapId: Int, - backupShuffleDataIO: ShuffleDataIO = null) - extends ShuffleWriter[K, V] with Logging { - - private implicit val backupExecutorContext = ExecutionContext.fromExecutorService(backupExecutor) - - private val writeSupport = backupShuffleDataIO.writeSupport() - /** Write a sequence of records to this task's output */ - override def write(records: Iterator[Product2[K, V]]): Unit = { - delegateWriter.write(records) - } - - /** Close this writer, passing along whether the map completed */ - override def stop(success: Boolean): Option[MapStatus] = { - val delegateMapStatus = delegateWriter.stop(success) - delegateMapStatus.foreach { _ => - val outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId) - val indexFile = shuffleBlockResolver.getIndexFile(shuffleId, mapId) - if (outputFile.isFile && indexFile.isFile) { - val uploadBackupFileRequest = new UploadShuffleFileStream( - appId, execId, shuffleId, mapId) - val uploadIndexFileRequest = new UploadShuffleIndexFileStream( - appId, execId, shuffleId, mapId) - - val backupFileTask: Future[Unit] = 
Future { - backupFile(outputFile, uploadBackupFileRequest) - backupFile(indexFile, uploadIndexFileRequest) - } - - backupFileTask.onComplete { - case Success(_) => - val backedUpMapStatus = RelocatedMapStatus( - delegateMapStatus.get, - BlockManagerId(execId, backupHost, backupPort, None, isBackup = true)) - mapOutputTracker.trackerEndpoint.send( - ReportBackedUpMapOutput(shuffleId, mapId, backedUpMapStatus)) - case Failure(_) => logError("An error has occured in backing up") - } - } - } - delegateMapStatus - } - - private def backupFile( - fileToBackUp: File, - backupFileRequest: BlockTransferMessage) { - backupShuffleDataIO.writeSupport() - val dataFileBuffer = new FileSegmentManagedBuffer( - transportConf, fileToBackUp, 0, fileToBackUp.length()) - val uploadBackupRequestBuffer = new NioManagedBuffer(backupFileRequest.toByteBuffer) - val awaitCompletion = SettableFuture.create[Boolean] - backupShuffleServiceClient.uploadStream( - uploadBackupRequestBuffer, dataFileBuffer, new RpcResponseCallback { - override def onSuccess(response: ByteBuffer): Unit = { - logInfo("Successfully backed up shuffle map data file" + - s" (shuffle id: $shuffleId, map id: $mapId, executor id: $execId)") - awaitCompletion.set(true) - } - - /** Exception either propagated from server or raised on client side. */ - override def onFailure(e: Throwable): Unit = { - logError("Failed to back up shuffle map data file" + - s" (shuffle id: $shuffleId, map id: $mapId, executor id: $execId)") - awaitCompletion.setException(e) - } - }) - awaitCompletion.get() - } -} diff --git a/core/src/main/scala/org/apache/spark/shuffle/external/ExternalFallbackShuffleClient.scala b/core/src/main/scala/org/apache/spark/shuffle/external/ExternalRemoteShuffleClient.scala similarity index 83% rename from core/src/main/scala/org/apache/spark/shuffle/external/ExternalFallbackShuffleClient.scala rename to core/src/main/scala/org/apache/spark/shuffle/external/ExternalRemoteShuffleClient.scala index 373b8a3060d33..c6988901a2259 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/external/ExternalFallbackShuffleClient.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/external/ExternalRemoteShuffleClient.scala @@ -20,7 +20,7 @@ package org.apache.spark.shuffle.external import org.apache.spark.network.BlockTransferService import org.apache.spark.network.shuffle._ -private[spark] class ExternalFallbackShuffleClient( +private[spark] class ExternalRemoteShuffleClient( externalShuffleClient: ExternalShuffleClient, baseBlockTransferService: BlockTransferService) extends ShuffleClient { @@ -34,21 +34,21 @@ private[spark] class ExternalFallbackShuffleClient( port: Int, execId: String, blockIds: Array[String], - isBackup: Boolean, + isRemote: Boolean, listener: BlockFetchingListener, downloadFileManager: DownloadFileManager): Unit = { - if (isBackup) { + if (isRemote) { externalShuffleClient.fetchBlocks( host, port, execId, blockIds, - isBackup, + isRemote, listener, downloadFileManager) } else { baseBlockTransferService.fetchBlocks( - host, port, execId, blockIds, isBackup, listener, downloadFileManager) + host, port, execId, blockIds, isRemote, listener, downloadFileManager) } } @@ -57,11 +57,13 @@ private[spark] class ExternalFallbackShuffleClient( externalShuffleClient.close() } - def registerWithShuffleServerForBackups( + def registerWithRemoteShuffleServer( + driverHostPort: String, host: String, port: Int, execId: String, shuffleManager: String) : Unit = { - externalShuffleClient.registerWithShuffleServerForBackups(host, port, execId, 
shuffleManager) + externalShuffleClient.registerWithRemoteShuffleServer( + driverHostPort, host, port, execId, shuffleManager) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/external/ShufflePartitionWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/external/ShufflePartitionWriter.scala index 0250eb4659dde..2479f98f04c54 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/external/ShufflePartitionWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/external/ShufflePartitionWriter.scala @@ -17,9 +17,10 @@ package org.apache.spark.shuffle.external -import java.io.{Closeable, InputStream} +import java.io.{Closeable, InputStream, OutputStream} private[spark] trait ShufflePartitionWriter extends Closeable { - def appendPartition(partitionId: Long, partitionInput: InputStream): Unit + def appendPartition(partitionId: Long, partitionOutput: OutputStream): Unit + def appendIndexFile(partitionId: Long, indexInput: InputStream): Unit def abort(exception: Throwable): Unit } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 7c7f0d89958ae..cd203b5bf38e9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -27,7 +27,7 @@ import org.apache.spark.network.TransportContext import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.server.NoOpRpcHandler import org.apache.spark.shuffle._ -import org.apache.spark.shuffle.external.BackingUpShuffleWriter +import org.apache.spark.shuffle.external.ShuffleDataIO import org.apache.spark.util.ThreadUtils /** @@ -91,10 +91,10 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager private val backupShuffleTransportConf = SparkTransportConf.fromSparkConf( conf, "shuffle", 2) - private lazy val backupShuffleTransportClients = SparkEnv + private lazy val remoteShuffleTransportClients = SparkEnv .get .blockManager - .getBackupShuffleServiceAddresses() + .getRemoteShuffleServiceAddresses() .map(address => { val addressAsUri = URI.create(s"spark://${address._1}:${address._2}") val transportContext = new TransportContext( @@ -104,6 +105,9 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager (address, addressAsUri, transportContext.createClientFactory()) }) + // TODO: Fill out defaults + private val shuffleDataIO: ShuffleDataIO = null + /** * Obtains a [[ShuffleHandle]] to pass to tasks. */ @@ -139,7 +143,12 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { new BlockStoreShuffleReader( - handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], + startPartition, + endPartition, + context, + conf.getAppId, + Option(shuffleDataIO).map(_.readSupport()).orNull) } /** Get a writer for a given partition. Called on executors by map tasks.
*/ @@ -150,7 +159,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager numMapsForShuffle.putIfAbsent( handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps) val env = SparkEnv.get - val baseWriter = handle match { + handle match { case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => new UnsafeShuffleWriter( env.blockManager, @@ -171,27 +180,6 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => new SortShuffleWriter(shuffleBlockResolver, other, mapId, context) } - Random.shuffle(backupShuffleTransportClients) - .headOption - .map(addressAndClient => { - val transportClient = - addressAndClient._3.createClient( - addressAndClient._2.getHost, addressAndClient._2.getPort) - new BackingUpShuffleWriter( - new IndexShuffleBlockResolver(conf, shuffleDataIO = new DefaultShuffleDataIO()), - baseWriter, - transportClient, - backupShuffleTransportConf, - env.mapOutputTracker, - ThreadUtils.newDaemonCachedThreadPool("backup-shuffle-files"), - addressAndClient._1._1, - addressAndClient._1._2, - new DefaultShuffleDataIO(), - conf.getAppId, - env.blockManager.blockManagerId.executorId, - handle.shuffleId, - mapId) - }).getOrElse(baseWriter) } /** Remove a shuffle's metadata from the ShuffleManager. */ @@ -207,7 +195,6 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager /** Shut down this ShuffleManager. */ override def stop(): Unit = { shuffleBlockResolver.stop() - backupShuffleTransportClients.foreach(_._3.close()) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 274399b9cc1f3..49972c6540d59 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -17,6 +17,8 @@ package org.apache.spark.shuffle.sort +import java.io.File + import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.scheduler.MapStatus @@ -64,11 +66,15 @@ private[spark] class SortShuffleWriter[K, V, C]( // Don't bother including the time to open the merged output file in the shuffle write time, // because it just opens a single file, so is typically too fast to measure accurately // (see SPARK-3570). 
+ var tmp: File = null val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId) - val tmp = Utils.tempFileWith(output) + tmp = Utils.tempFileWith(output) try { val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) - val partitionLengths = sorter.writePartitionedFile(blockId, tmp) + val partitionLengths = sorter.writePartitionedFile( + blockId, + tmp, + shuffleBlockResolver.shuffleDataIO) shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) } finally { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index b4a214de81fb5..cef9046207ceb 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -18,14 +18,13 @@ package org.apache.spark.storage import java.io._ -import java.lang.ref.{WeakReference, ReferenceQueue => JReferenceQueue} +import java.lang.ref.{ReferenceQueue => JReferenceQueue, WeakReference} import java.nio.ByteBuffer import java.nio.channels.Channels import java.util.Collections import java.util.concurrent.ConcurrentHashMap import com.codahale.metrics.{MetricRegistry, MetricSet} - import scala.collection.mutable import scala.collection.mutable.HashMap import scala.concurrent.{ExecutionContext, Future} @@ -33,9 +32,10 @@ import scala.concurrent.duration._ import scala.reflect.ClassTag import scala.util.Random import scala.util.control.NonFatal + import org.apache.spark._ import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} -import org.apache.spark.internal.{Logging, config} +import org.apache.spark.internal.{config, Logging} import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.metrics.source.Source import org.apache.spark.network._ @@ -49,7 +49,7 @@ import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.serializer.{SerializerInstance, SerializerManager} import org.apache.spark.shuffle.ShuffleManager -import org.apache.spark.shuffle.external.{ExternalFallbackShuffleClient, ShuffleDataIO} +import org.apache.spark.shuffle.external.{ExternalRemoteShuffleClient, ShuffleDataIO} import org.apache.spark.storage.memory._ import org.apache.spark.unsafe.Platform import org.apache.spark.util._ @@ -63,7 +63,7 @@ private[spark] class BlockResult( /** * Abstracts away how blocks are stored and provides different ways to read the underlying block - * data. Callers should call [[dispose()]] when they're done with the block. + * data. Callers should call [[ dispose() ]] when they're done with the block. */ private[spark] trait BlockData { @@ -194,7 +194,7 @@ private[spark] class BlockManager( // service, or just our own Executor's BlockManager. private[spark] var shuffleServerId: BlockManagerId = _ - private var backupShuffleServiceAddresses: List[(String, Int)] = _ + private var remoteShuffleServiceAddresses: List[(String, Int)] = _ // Client to read other executors' shuffle files. This is either an external service, or just the // standard BlockTransferService to directly connect to other Executors. 
@@ -260,36 +260,36 @@ private[spark] class BlockManager( blockManagerId = if (idFromMaster != null) idFromMaster else id - backupShuffleServiceAddresses = if (blockManagerId.isDriver) { + remoteShuffleServiceAddresses = if (blockManagerId.isDriver) { List.empty[(String, Int)] } else { Random.shuffle(mapOutputTracker .trackerEndpoint - .askSync[List[(String, Int)]](GetBackupShuffleServiceAddresses)) + .askSync[List[(String, Int)]](GetRemoteShuffleServiceAddresses)) .take(3) } - if (backupShuffleServiceAddresses.nonEmpty) { - require(!externalShuffleServiceEnabled, "Cannot use external shuffle service with backup" + + if (remoteShuffleServiceAddresses.nonEmpty) { + require(!externalShuffleServiceEnabled, "Cannot use external shuffle service with remote" + " shuffle services.") } shuffleClient = if (externalShuffleServiceEnabled) { - require(backupShuffleServiceAddresses.isEmpty, - "Cannot use the external shuffle service while using backup shuffle services.") + require(remoteShuffleServiceAddresses.isEmpty, + "Cannot use the external shuffle service while using remote shuffle services.") val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled(), conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT)) - } else if (backupShuffleServiceAddresses.nonEmpty) { - logInfo("Using BackupShuffleService") + } else if (remoteShuffleServiceAddresses.nonEmpty) { + logInfo("Using RemoteShuffleServices") val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) val externalShuffleClient = new ExternalShuffleClient( transConf, securityManager, securityManager.isAuthenticationEnabled(), conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT)) - new ExternalFallbackShuffleClient(externalShuffleClient, blockTransferService) + new ExternalRemoteShuffleClient(externalShuffleClient, blockTransferService) } else blockTransferService shuffleClient.init(appId) @@ -314,8 +314,8 @@ private[spark] class BlockManager( registerWithExternalShuffleServer() } - backupShuffleServiceAddresses.foreach(address => { - registerWithBackupShuffleServer( + remoteShuffleServiceAddresses.foreach(address => { + registerWithRemoteShuffleServer( address._1, address._2, appId) @@ -337,8 +337,6 @@ private[spark] class BlockManager( private def registerWithExternalShuffleServer() { logInfo("Registering executor with local external shuffle service.") val shuffleConfig = new ExecutorShuffleInfo( - backupShuffleServiceAddresses.map(_._1).toArray, - backupShuffleServiceAddresses.map(_._2.toString).toArray, diskBlockManager.localDirs.map(_.toString), diskBlockManager.subDirsPerLocalDir, shuffleManager.getClass.getName) @@ -364,7 +362,7 @@ private[spark] class BlockManager( } } - private def registerWithBackupShuffleServer( + private def registerWithRemoteShuffleServer( shuffleServerHost: String, shuffleServerPort: Int, appId: String) @@ -375,8 +373,9 @@ private[spark] class BlockManager( try { // Synchronous and will throw an exception if we cannot connect. 
shuffleClient - .asInstanceOf[ExternalFallbackShuffleClient] - .registerWithShuffleServerForBackups( + .asInstanceOf[ExternalRemoteShuffleClient] + .registerWithRemoteShuffleServer( + master.driverEndpoint.address.hostPort, shuffleServerHost, shuffleServerPort, shuffleServerId.executorId, @@ -843,7 +842,7 @@ private[spark] class BlockManager( logDebug(s"Getting remote block $blockId from $loc") val data = try { blockTransferService.fetchBlockSync( - loc.host, loc.port, loc.executorId, blockId.toString, loc.isBackup, tempFileManager) + loc.host, loc.port, loc.executorId, blockId.toString, loc.isRemote, tempFileManager) } catch { case NonFatal(e) => runningFailureCount += 1 @@ -1691,7 +1690,7 @@ private[spark] class BlockManager( } } - def getBackupShuffleServiceAddresses(): List[(String, Int)] = backupShuffleServiceAddresses + def getRemoteShuffleServiceAddresses(): List[(String, Int)] = remoteShuffleServiceAddresses def releaseLockAndDispose( blockId: BlockId, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index 3b8a03bb23713..1b92b41fedda5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -40,7 +40,7 @@ class BlockManagerId private ( private var host_ : String, private var port_ : Int, private var topologyInfo_ : Option[String], - private var isBackup_ : Boolean = false) + private var isRemote_ : Boolean = false) extends Externalizable { private def this() = this(null, null, 0, None) // For deserialization only @@ -63,7 +63,7 @@ class BlockManagerId private ( def port: Int = port_ - def isBackup: Boolean = isBackup_ + def isRemote: Boolean = isRemote_ def topologyInfo: Option[String] = topologyInfo_ @@ -77,7 +77,7 @@ class BlockManagerId private ( out.writeUTF(executorId_) out.writeUTF(host_) out.writeInt(port_) - out.writeBoolean(isBackup) + out.writeBoolean(isRemote) out.writeBoolean(topologyInfo_.isDefined) // we only write topologyInfo if we have it topologyInfo.foreach(out.writeUTF(_)) @@ -87,7 +87,7 @@ class BlockManagerId private ( executorId_ = in.readUTF() host_ = in.readUTF() port_ = in.readInt() - isBackup_ = in.readBoolean() + isRemote_ = in.readBoolean() val isTopologyInfoAvailable = in.readBoolean() topologyInfo_ = if (isTopologyInfoAvailable) Option(in.readUTF()) else None } @@ -131,9 +131,9 @@ private[spark] object BlockManagerId { host: String, port: Int, topologyInfo: Option[String] = None, - isBackup: Boolean = false): BlockManagerId = + isRemote: Boolean = false): BlockManagerId = getCachedBlockManagerId(new BlockManagerId( - execId, host, port, topologyInfo, isBackup)) + execId, host, port, topologyInfo, isRemote)) def apply(in: ObjectInput): BlockManagerId = { val obj = new BlockManagerId() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala index db1a4a8c7d2f1..547d0219249fb 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala @@ -62,7 +62,7 @@ private[spark] class BlockStore( } blockMapper match { case _: RemoteBlockManager => - throw new IllegalAccessError("Remote Block Mapper does not support this writing feature") + throw new IllegalAccessError("Remote BlockMapper does not support this writing feature") case d: DiskBlockManager => logDebug(s"Attempting to put block $blockId") val 
startTime = System.currentTimeMillis diff --git a/core/src/main/scala/org/apache/spark/storage/RemoteBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/RemoteBlockObjectWriter.scala new file mode 100644 index 0000000000000..001a2f59b848b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/RemoteBlockObjectWriter.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.io._ +import java.nio.channels.FileChannel + +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.internal.Logging +import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} +import org.apache.spark.shuffle.external.ShuffleWriteSupport +import org.apache.spark.util.Utils + +/** + * A class for writing JVM objects to a remote shuffle storage backend through a + * [[ShuffleWriteSupport]] plugin, rather than directly to a file on disk. This class allows + * data to be appended to an existing block. Serialized bytes are buffered in memory + * before being handed off to the write support, so there is no underlying file channel + * to keep open across commits. + * + * This class does not support concurrent writes. Also, once the writer has been opened it cannot be + * reopened again. + */ +private[spark] class RemoteBlockObjectWriter( + shuffleWriteSupport: ShuffleWriteSupport, + serializerManager: SerializerManager, + serializerInstance: SerializerInstance, + bufferSize: Int, + syncWrites: Boolean, + // These write metrics concurrently shared with other active BlockObjectWriters who + // are themselves performing writes. All updates must be relative. + writeMetrics: ShuffleWriteMetrics, + val blockId: BlockId = null) extends Logging { + + private var byteArrayOutput: ByteArrayOutputStream = null + private var bs: OutputStream = null + private var objectOutputStream: ObjectOutputStream = null + private var ts: TimeTrackingOutputStream = null + private var bufferedOS: BufferedOutputStream = null + private var serializationStream: SerializationStream = null + private var mcOS: ManualCloseOutputStream = null + private var initialized = false + private var streamOpen = false + private var hasBeenClosed = false + + private def initialize(): Unit = { + byteArrayOutput = new ByteArrayOutputStream() + objectOutputStream = new ObjectOutputStream(byteArrayOutput) + ts = new TimeTrackingOutputStream(writeMetrics, objectOutputStream) + mcOS = new BufferedOutputStream(ts, bufferSize) with ManualCloseOutputStream // assumed wiring, mirrors DiskBlockObjectWriter + } + + /** + * Guards against close calls, e.g. from a wrapping stream. + * Call manualClose to close the stream that was extended by this trait.
+ * Commit uses this trait to close object streams without paying the + * cost of closing and opening the underlying file. + */ + private trait ManualCloseOutputStream extends OutputStream { + abstract override def close(): Unit = { + flush() + } + + def manualClose(): Unit = { + super.close() + } + } + + /** + * Keep track of number of records written and also use this to periodically + * output bytes written since the latter is expensive to do for each record. + * And we reset it after every commitAndGet called. + */ + private var numRecordsWritten = 0 + + def open(): RemoteBlockObjectWriter = { + if (hasBeenClosed) { + throw new IllegalStateException("Writer already closed. Cannot be reopened.") + } + if (!initialized) { + initialize() + initialized = true + } + + bs = serializerManager.wrapStream(blockId, new BufferedOutputStream(ts, bufferSize)) + serializationStream = serializerInstance.serializeStream(bs) + streamOpen = true + this + } + +} diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index f7b7b68e53a7e..d91d86a12c6f1 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -259,7 +259,7 @@ final class ShuffleBlockFetcherIterator( address.port, address.executorId, blockIds.toArray, - address.isBackup, + address.isRemote, blockFetchingListener, this) } else { shuffleClient.fetchBlocks( @@ -267,7 +267,7 @@ final class ShuffleBlockFetcherIterator( address.port, address.executorId, blockIds.toArray, - address.isBackup, + address.isRemote, blockFetchingListener, null) } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index b159200d79222..567fb935100c0 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -22,13 +22,12 @@ import java.util.Comparator import scala.collection.mutable import scala.collection.mutable.ArrayBuffer - import com.google.common.io.ByteStreams - import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.Logging import org.apache.spark.serializer._ +import org.apache.spark.shuffle.external.ShuffleDataIO import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} /** @@ -682,13 +681,13 @@ private[spark] class ExternalSorter[K, V, C]( */ def writePartitionedFile( blockId: BlockId, - outputFile: File): Array[Long] = { + outputFile: File, + _shuffleDataIO: ShuffleDataIO = null): Array[Long] = { // Track location of each range in the output file val lengths = new Array[Long](numPartitions) val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize, context.taskMetrics().shuffleWriteMetrics) - if (spills.isEmpty) { // Case where we only have in-memory data val collection = if (aggregator.isDefined) map else buffer diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala index 5232c2bd8d6f6..f061eb57bf6c2 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala +++ 
b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala @@ -57,6 +57,8 @@ private[spark] trait WritablePartitionedPairCollection[K, V] { cur = if (it.hasNext) it.next() else null } + + def hasNext(): Boolean = cur != null def nextPartition(): Int = cur._1._1 diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index a96ea420e7f0a..462769188a13f 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -194,7 +194,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex cmId.port, cmId.executorId, blockId.toString, - cmId.isBackup, + cmId.isRemote, null) val deserialized = serializerManager.dataDeserializeStream(blockId, new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream())(data.elementClassTag).toList diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 55c66b8e04200..5cd1954531e4e 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -1446,7 +1446,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE port: Int, execId: String, blockIds: Array[String], - isBackup: Boolean, + isRemote: Boolean, listener: BlockFetchingListener, tempFileManager: DownloadFileManager): Unit = { listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1))) @@ -1475,14 +1475,14 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE port: Int, execId: String, blockId: String, - isBackup: Boolean, + isRemote: Boolean, tempFileManager: DownloadFileManager): ManagedBuffer = { numCalls += 1 this.tempFileManager = tempFileManager if (numCalls <= maxFailures) { throw new RuntimeException("Failing block fetch in the mock block transfer service") } - super.fetchBlockSync(host, port, execId, blockId, isBackup, tempFileManager) + super.fetchBlockSync(host, port, execId, blockId, isRemote, tempFileManager) } } }
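For illustration, a minimal read-side plugin built against the API exercised in this patch could look like the sketch below. Only the ShuffleReadSupport and ShufflePartitionReader method signatures (newPartitionReader(appId, shuffleId, mapId) and fetchPartition(reduceId)) come from the code above; the class names, the rootDir constructor argument, and the <rootDir>/<appId>/<shuffleId>/<mapId>/<reduceId>.data layout are assumptions made purely for the example.

package org.apache.spark.shuffle.external

import java.io.{File, FileInputStream, InputStream}

// Illustrative sketch only: serves map output out of a shared local directory whose
// layout (<rootDir>/<appId>/<shuffleId>/<mapId>/<reduceId>.data) is assumed for this example.
private[spark] class LocalDirShuffleReadSupport(rootDir: File) extends ShuffleReadSupport {
  override def newPartitionReader(
      appId: String, shuffleId: Int, mapId: Int): ShufflePartitionReader = {
    // One reader per map output, as created by BlockStoreShuffleReader for each map id.
    new LocalDirShufflePartitionReader(new File(rootDir, s"$appId/$shuffleId/$mapId"))
  }
}

private[spark] class LocalDirShufflePartitionReader(mapDir: File) extends ShufflePartitionReader {
  // Returns the raw serialized bytes of a single reduce partition; the shuffle reader
  // is responsible for wrapping and deserializing the returned stream.
  override def fetchPartition(reduceId: Int): InputStream = {
    new FileInputStream(new File(mapDir, s"$reduceId.data"))
  }
}

A read support like this would be returned from a ShuffleDataIO implementation and picked up where SortShuffleManager currently leaves private val shuffleDataIO: ShuffleDataIO = null behind a TODO; the write side (newPartitionWriter and appendIndexFile, as used by IndexShuffleBlockResolver above) would be implemented analogously against the same storage layout.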