From 1a02af231f2bfe60a43b4da9f362638ff97254f1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=BB=B2=E7=94=AB?=
Date: Thu, 8 May 2025 22:27:30 +0800
Subject: [PATCH 1/3] [WIP] dynamically adjust partition write parallelism - client side

---
 .../celeborn/client/LocationManager.java      | 342 +++++++++++
 .../apache/celeborn/client/ReviveManager.java | 176 +++---
 .../apache/celeborn/client/ShuffleClient.java |   8 +
 .../celeborn/client/ShuffleClientImpl.java    | 532 +++++++++---------
 .../celeborn/client/write/DataPushQueue.java  |  13 +-
 .../client/RequestLocationCallContext.scala   |   5 +-
 .../common/protocol/ReviveRequest.java        |  50 +-
 .../common/protocol/message/StatusCode.java   |   3 +-
 common/src/main/proto/TransportMessages.proto |  10 +-
 .../protocol/message/ControlMessages.scala    |  43 +-
 10 files changed, 831 insertions(+), 351 deletions(-)
 create mode 100644 client/src/main/java/org/apache/celeborn/client/LocationManager.java

diff --git a/client/src/main/java/org/apache/celeborn/client/LocationManager.java b/client/src/main/java/org/apache/celeborn/client/LocationManager.java
new file mode 100644
index 00000000000..d741dbe1514
--- /dev/null
+++ b/client/src/main/java/org/apache/celeborn/client/LocationManager.java
@@ -0,0 +1,342 @@
+package org.apache.celeborn.client;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.apache.celeborn.common.protocol.PartitionLocation;
+import org.apache.celeborn.common.protocol.ReviveRequest;
+import org.apache.celeborn.common.protocol.message.StatusCode;
+import org.apache.celeborn.common.util.JavaUtils;
+import org.apache.celeborn.common.util.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReadWriteLock;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
+
+public class LocationManager {
+
+  private static final Logger logger = LoggerFactory.getLogger(LocationManager.class);
+
+  private class PartitionLocationList {
+    List<PartitionLocation> locations = new ArrayList<>();
+    Set<Integer> locationSet = new HashSet<>();
+    // epoch id -> status
+    Map<Integer, StatusCode> locationStatusCode = new HashMap<>();
+    ReviveRequest latestReviveRequest = null;
+    int[] index = null;
+
+    int used = 0;
+    int size = 0;
+    int maxEpoch = -1;
+
+    final int shuffleId;
+    final int partitionId;
+    public PartitionLocationList(int shuffleId, int partitionId) {
+      this.shuffleId = shuffleId;
+      this.partitionId = partitionId;
+    }
+
+    ReadWriteLock lock = new ReentrantReadWriteLock();
+    Lock readLock = lock.readLock();
+    Lock writeLock = lock.writeLock();
+
+    private void update(List<PartitionLocation> newLocs) {
+      if (newLocs.isEmpty()) {
+        return;
+      }
+      newLocs.sort(Comparator.comparing(PartitionLocation::getEpoch));
+      int newMaxEpoch = newLocs.get(newLocs.size() - 1).getEpoch();
+      try {
+        writeLock.lock();
+        if (newMaxEpoch <= maxEpoch) {
+          return;
+        }
+        int newSize = newLocs.size() + size - used;
+        ArrayList<PartitionLocation> newLocations = new ArrayList<>(newSize);
+        for (PartitionLocation oldLoc : locations) {
+          if (locationStatusCode.remove(oldLoc.getEpoch()) != null) {
+            used--;
+          } else {
+            newLocations.add(oldLoc);
+          }
+        }
+        for (PartitionLocation l : newLocs) {
+          if (l.getEpoch() >= maxEpoch) {
+            newLocations.add(l);
+          }
+        }
+        size = newLocations.size();
+        index = new int[size];
+        for (int i = 0; i < size; i++) {
+          index[i] 
= i; + } + locations = newLocations; + locationSet.clear(); + locations.forEach(l -> locationSet.add(l.getEpoch())); + maxEpoch = Math.max(maxEpoch, newMaxEpoch); + logger.info("Location updated for shuffleId {}, partitionId {}, new locations: {}, maxEpoch: {}", shuffleId, partitionId, locationSet, maxEpoch); + if (latestReviveRequest != null && latestReviveRequest.clientMaxEpoch < maxEpoch) { + logger.debug("outdated latestReviveRequest {}", latestReviveRequest); + this.latestReviveRequest = null; + } + } finally { + writeLock.unlock(); + } + } + + // return partitionLocation for specified mapId + // if allowSoftSplit = true, soft split location can be returned + // if liveOnly = false, non-living (soft split/hard split/push fail) location can be returned + private PartitionLocation nextLoc(int mapId, boolean allowSoftSplit, boolean liveOnly) { + try { + readLock.lock(); + int pos = mapId % size; + int idx = index[pos]; + while (locationStatusCode.get(locations.get(idx).getEpoch()) != null) { + if (allowSoftSplit && locationStatusCode.get(locations.get(idx).getEpoch()) == StatusCode.SOFT_SPLIT) { + break; + } + idx = (idx + 1) % size; + // all locations are checked + if (idx == index[pos]) { + break; + } + } + if (idx != index[pos]) { + index[pos] = idx; + } + if (locationStatusCode.get(locations.get(idx).getEpoch()) != null) { + if ((allowSoftSplit && locationStatusCode.get(locations.get(idx).getEpoch()) == StatusCode.SOFT_SPLIT) || + !liveOnly) { + return locations.get(idx); + } + return null; + } else { + return locations.get(idx); + } + } finally { + readLock.unlock(); + } + } + + private void reviveBatch(int shuffleId, int partitionId, int mapId, int attemptId) { + ReviveRequest reviveRequest = null; + try { + writeLock.lock(); + reviveRequest = new ReviveRequest(shuffleId, mapId, attemptId, partitionId, null, StatusCode.URGENT_REVIVE, maxEpoch, true); + this.latestReviveRequest = reviveRequest; + logger.debug("in reviveBatch latestReviveRequest = {}", reviveRequest); + } finally { + writeLock.unlock(); + } + reviveManager.addRequest(reviveRequest); + } + + private void reportUnusableLocation(int shuffleId, int mapId, int attemptId, PartitionLocation loc, StatusCode reportedStatus) { + ReviveRequest reviveRequest = null; + try { + writeLock.lock(); + if (!locationSet.contains(loc.getEpoch())) { + return; + } + + StatusCode currentStatus = locationStatusCode.get(loc.getEpoch()); + if (currentStatus != reportedStatus) { + // allow normal/soft split to transition to hard split/push failure + if (currentStatus == null || currentStatus == StatusCode.SOFT_SPLIT) { + locationStatusCode.put(loc.getEpoch(), reportedStatus); + } + if (currentStatus == null) { + used++; + } + boolean urgent = ((used == size) && !hasActiveReviveRequest()); + reviveRequest = new ReviveRequest(shuffleId, mapId, attemptId, loc.getId(), loc, reportedStatus, maxEpoch, urgent); + if (urgent) { + this.latestReviveRequest = reviveRequest; + logger.debug("in reportUnusableLocation latestReviveRequest = {}", reviveRequest); + } + } + + } finally { + writeLock.unlock(); + } + if (reviveRequest != null) { + logger.info("Reported worker {}, partitionId {}, epoch {}, shuffle {} map {} attempt {} is unusable, status: {}, urgent: {}", + loc.hostAndPushPort(), loc.getId(), loc.getEpoch(), shuffleId, mapId, attemptId, reportedStatus.name(), reviveRequest.urgent); + reviveManager.addRequest(reviveRequest); + } + } + + private boolean newerPartitionLocationExists(int epoch) { + try { + readLock.lock(); + for (PartitionLocation loc 
: locations) {
+          if (locationStatusCode.get(loc.getEpoch()) != null && loc.getEpoch() > epoch) {
+            return true;
+          }
+        }
+        return false;
+      } finally {
+        readLock.unlock();
+      }
+    }
+
+    private boolean locationExists(int epoch) {
+      try {
+        readLock.lock();
+        return locationSet.contains(epoch);
+      } finally {
+        readLock.unlock();
+      }
+    }
+
+    public boolean hasActiveReviveRequest() {
+      try {
+        readLock.lock();
+        return latestReviveRequest != null && latestReviveRequest.reviveStatus == StatusCode.REVIVE_INITIALIZED.getValue();
+      } finally {
+        readLock.unlock();
+      }
+    }
+
+    public StatusCode getLatestReviveStatus() {
+      try {
+        readLock.lock();
+        if (latestReviveRequest == null) {
+          return StatusCode.REVIVE_INITIALIZED;
+        } else {
+          return StatusCode.fromValue(latestReviveRequest.reviveStatus);
+        }
+      } finally {
+        readLock.unlock();
+      }
+    }
+  }
+
+  final Map<Integer, ConcurrentHashMap<Integer, PartitionLocationList>> reducePartitionMap =
+      JavaUtils.newConcurrentHashMap();
+
+  final ReviveManager reviveManager;
+
+  final ShuffleClientImpl shuffleClient;
+
+  public LocationManager(ShuffleClientImpl shuffleClient, ReviveManager reviveManager) {
+    this.shuffleClient = shuffleClient;
+    this.reviveManager = reviveManager;
+  }
+
+  public void registerShuffleLocs(int shuffleId, ConcurrentHashMap<Integer, List<PartitionLocation>> map) {
+    reducePartitionMap.computeIfAbsent(shuffleId, (id) -> {
+      ConcurrentHashMap<Integer, PartitionLocationList> locationMap = JavaUtils.newConcurrentHashMap();
+      for (Map.Entry<Integer, List<PartitionLocation>> e : map.entrySet()) {
+        int partitionId = e.getKey();
+        List<PartitionLocation> locs = e.getValue();
+        PartitionLocationList list = new PartitionLocationList(shuffleId, partitionId);
+        list.update(locs);
+        locationMap.put(partitionId, list);
+        logger.debug("in registerShuffleLocs, shuffleId {}, partitionId {}", id, partitionId);
+      }
+      return locationMap;
+    });
+  }
+
+  public boolean registered(int shuffleId) {
+    return reducePartitionMap.containsKey(shuffleId);
+  }
+
+  public boolean exists(int shuffleId, int partitionId) {
+    if (!registered(shuffleId)) {
+      throw new UnsupportedOperationException(
+          "unexpected! shuffle must be registered before checking whether a partition exists");
+    }
+    return reducePartitionMap.get(shuffleId).containsKey(partitionId);
+  }
+
+  public StatusCode getReviveStatus(int shuffleId, int partitionId) {
+    PartitionLocationList locationList = reducePartitionMap.get(shuffleId).get(partitionId);
+    return locationList.getLatestReviveStatus();
+  }
+
+  public PartitionLocation getLocationOrReviveAsync(int shuffleId, int partitionId, int mapId, int attemptId, boolean doRevive, boolean liveOnly) {
+    PartitionLocationList locationList = reducePartitionMap.get(shuffleId).get(partitionId);
+    // first, try to find a live partition location
+    PartitionLocation loc = locationList.nextLoc(mapId, false, true);
+    if (loc == null) {
+      if (doRevive && !locationList.hasActiveReviveRequest()) {
+        locationList.reviveBatch(shuffleId, partitionId, mapId, attemptId);
+      } else if (doRevive && locationList.hasActiveReviveRequest()) {
+        logger.debug("in getLocationOrReviveAsync, do nothing, current latestReviveRequest is {}", locationList.latestReviveRequest);
+      }
+      // no live partition location is available, so fall back to a location in soft split status;
+      // if liveOnly = false, a hard split/push failed location can also be returned
+      loc = locationList.nextLoc(mapId, true, liveOnly);
+    }
+    return loc;
+  }
+
+  public boolean reviveSync(int shuffleId, int partitionId, int mapId, int attemptId, StatusCode cause) {
+    Set<Integer> mapIds = new HashSet<>();
+    mapIds.add(mapId);
+    List<ReviveRequest> requests = new ArrayList<>();
+    ReviveRequest request = new ReviveRequest(shuffleId, mapId, attemptId, partitionId, null, cause, 0, true);
+    requests.add(request);
+    Map<Integer, Integer> results = shuffleClient.reviveBatch(shuffleId, mapIds, requests, true);
+
+    if (shuffleClient.mapperEnded(shuffleId, mapId)) {
+      logger.debug(
+          "Revive success, but the mapper ended for shuffle {} map {} attempt {} partition {}, just return true (assume revive succeeded).",
+          shuffleId,
+          mapId,
+          attemptId,
+          partitionId);
+      return true;
+    } else {
+      return results != null
+          && results.containsKey(partitionId)
+          && results.get(partitionId) == StatusCode.SUCCESS.getValue();
+    }
+  }
+
+  public void updateLocation(int shuffleId, int partitionId, List<PartitionLocation> newLocations) {
+    PartitionLocationList locationList = reducePartitionMap.get(shuffleId).get(partitionId);
+    locationList.update(newLocations);
+  }
+
+  public void reportUnusableLocation(int shuffleId, int mapId, int attemptId, PartitionLocation reportedPartition, StatusCode reportedStatus) {
+    int partitionId = reportedPartition.getId();
+    PartitionLocationList locationList = reducePartitionMap.get(shuffleId).get(partitionId);
+    locationList.reportUnusableLocation(shuffleId, mapId, attemptId, reportedPartition, reportedStatus);
+  }
+
+  public boolean newerPartitionLocationExists(int shuffleId, int partitionId, int epoch) {
+    PartitionLocationList locationList = reducePartitionMap.get(shuffleId).get(partitionId);
+    return locationList.newerPartitionLocationExists(epoch);
+  }
+
+  public boolean locationExists(int shuffleId, int partitionId, int epoch) {
+    PartitionLocationList locationList = reducePartitionMap.get(shuffleId).get(partitionId);
+    return locationList.locationExists(epoch);
+  }
+
+  public boolean hasActiveReviveRequest(int shuffleId, int partitionId) {
+    PartitionLocationList locationList = reducePartitionMap.get(shuffleId).get(partitionId);
+    return locationList.hasActiveReviveRequest();
+  }
+
+  public void removeShuffle(int shuffleId) {
+    reducePartitionMap.remove(shuffleId);
+  }
+
+  @VisibleForTesting
+  public 
StatusCode getLocationStatus(int shuffleId, int partitionId, int epochId) { + PartitionLocationList locationList = reducePartitionMap.get(shuffleId).get(partitionId); + return locationList.locationStatusCode.getOrDefault(epochId, StatusCode.SUCCESS); + } +} diff --git a/client/src/main/java/org/apache/celeborn/client/ReviveManager.java b/client/src/main/java/org/apache/celeborn/client/ReviveManager.java index 875952f0d9c..888e4d12794 100644 --- a/client/src/main/java/org/apache/celeborn/client/ReviveManager.java +++ b/client/src/main/java/org/apache/celeborn/client/ReviveManager.java @@ -17,16 +17,22 @@ package org.apache.celeborn.client; -import java.util.*; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Map; +import java.util.Set; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.celeborn.common.CelebornConf; -import org.apache.celeborn.common.protocol.PartitionLocation; import org.apache.celeborn.common.protocol.ReviveRequest; import org.apache.celeborn.common.protocol.message.StatusCode; import org.apache.celeborn.common.util.ThreadUtils; @@ -35,84 +41,119 @@ class ReviveManager { private static final Logger logger = LoggerFactory.getLogger(ReviveManager.class); LinkedBlockingQueue requestQueue = new LinkedBlockingQueue<>(); + private final long interval; private final int batchSize; ShuffleClientImpl shuffleClient; - private final ScheduledExecutorService batchReviveRequestScheduler = - ThreadUtils.newDaemonSingleThreadScheduledExecutor( - "celeborn-client-lifecycle-manager-batch-revive-scheduler"); + private ScheduledExecutorService batchReviveRequestScheduler = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("batch-revive-scheduler"); + private ThreadPoolExecutor batchReviveRequestHandler = + ThreadUtils.newDaemonFixedThreadPool(2, "batch-revive-handler"); + private ThreadPoolExecutor batchReportRequestHandler = + ThreadUtils.newDaemonFixedThreadPool(2, "batch-report-handler"); public ReviveManager(ShuffleClientImpl shuffleClient, CelebornConf conf) { this.shuffleClient = shuffleClient; + this.interval = conf.clientPushReviveInterval(); this.batchSize = conf.clientPushReviveBatchSize(); - long interval = conf.clientPushReviveInterval(); batchReviveRequestScheduler.scheduleWithFixedDelay( - () -> { - Map> shuffleMap = new HashMap<>(); - do { - ArrayList batchRequests = new ArrayList<>(); - requestQueue.drainTo(batchRequests, batchSize); - for (ReviveRequest req : batchRequests) { - Set set = - shuffleMap.computeIfAbsent(req.shuffleId, id -> new HashSet<>()); - set.add(req); - } - for (Map.Entry> shuffleEntry : shuffleMap.entrySet()) { - // Call reviveBatch for requests in the same (appId, shuffleId) - int shuffleId = shuffleEntry.getKey(); - Set requests = shuffleEntry.getValue(); - Set mapIds = new HashSet<>(); - ArrayList filteredRequests = new ArrayList<>(); - Map requestsToSend = new HashMap<>(); - - Map partitionMap = - shuffleClient.reducePartitionMap.get(shuffleId); - // Insert request that is not MapperEnded and with the max epoch - // into requestsToSend - Iterator iter = requests.iterator(); - while (iter.hasNext()) { - ReviveRequest req = iter.next(); - if (shuffleClient.newerPartitionLocationExists( - partitionMap, req.partitionId, req.epoch, 
false) - || shuffleClient.mapperEnded(shuffleId, req.mapId)) { - req.reviveStatus = StatusCode.SUCCESS.getValue(); - } else { - filteredRequests.add(req); - mapIds.add(req.mapId); - if (!requestsToSend.containsKey(req.partitionId) - || requestsToSend.get(req.partitionId).epoch < req.epoch) { - requestsToSend.put(req.partitionId, req); - } - } - } - - if (!requestsToSend.isEmpty()) { - // Call reviveBatch. Return null means Exception caught or - // SHUFFLE_NOT_REGISTERED - Map results = - shuffleClient.reviveBatch(shuffleId, mapIds, requestsToSend.values()); - if (results == null) { - for (ReviveRequest req : filteredRequests) { - req.reviveStatus = StatusCode.REVIVE_FAILED.getValue(); - } - } else { - for (ReviveRequest req : filteredRequests) { - if (shuffleClient.mapperEnded(shuffleId, req.mapId)) { - req.reviveStatus = StatusCode.SUCCESS.getValue(); + () -> { + try { + Map> urgentMap = new HashMap<>(); + Map> nonUrgentMap = new HashMap<>(); + do { + ArrayList batchRequests = new ArrayList<>(); + requestQueue.drainTo(batchRequests, batchSize); + for (ReviveRequest req : batchRequests) { + Set set = null; + if (req.urgent) { + set = urgentMap.computeIfAbsent(req.shuffleId, id -> new HashSet<>()); } else { - req.reviveStatus = results.get(req.partitionId); + set = nonUrgentMap.computeIfAbsent(req.shuffleId, id -> new HashSet<>()); } + set.add(req); + } + if (!urgentMap.isEmpty()) { + reviveInternal(urgentMap, true); + } + if (!nonUrgentMap.isEmpty()) { + reviveInternal(nonUrgentMap, false); } - } + // break the loop if remaining requests is less than half of + // `celeborn.client.push.revive.batchSize` + } while (requestQueue.size() > batchSize / 2); + } catch (Throwable e) { + logger.error("Exception when batchRevive: ", e); + throw e; + } + }, + interval, + interval, + TimeUnit.MILLISECONDS); + } + + public void reviveInternal(Map> shuffleMap, boolean urgent) { + for (Map.Entry> shuffleEntry : shuffleMap.entrySet()) { + // Call reviveBatch for requests in the same (appId, shuffleId) + int shuffleId = shuffleEntry.getKey(); + Set requests = shuffleEntry.getValue(); + processRequests(shuffleId, requests, urgent); + } + } + + public void processRequests(int shuffleId, Collection requests, boolean urgent) { + Set mapIds = new HashSet<>(); + ArrayList filteredRequests = new ArrayList<>(); + Map requestsToSend = new HashMap<>(); + + // Insert request that is not MapperEnded and with the max epoch + // into requestsToSend + Iterator iter = requests.iterator(); + while (iter.hasNext()) { + ReviveRequest req = iter.next(); + if ((urgent && shuffleClient.newerPartitionLocationExists(shuffleId, req.partitionId, req.clientMaxEpoch)) + || shuffleClient.mapperEnded(shuffleId, req.mapId)) { + req.reviveStatus = StatusCode.SUCCESS.getValue(); + } else { + filteredRequests.add(req); + mapIds.add(req.mapId); + if (!requestsToSend.containsKey(req.partitionId) + || requestsToSend.get(req.partitionId).clientMaxEpoch < req.clientMaxEpoch) { + requestsToSend.put(req.partitionId, req); + } + } + } + + ThreadPoolExecutor handler = urgent ? batchReviveRequestHandler : batchReportRequestHandler; + if (!requestsToSend.isEmpty()) { + handler.submit(() -> { + try { + // Call reviveBatch. 
Return null means Exception caught or
+          // SHUFFLE_NOT_REGISTERED
+          // Do not use WriterTracer here because traceInfo is set afterward
+          long reviveStartTime = System.nanoTime();
+          Map<Integer, Integer> results =
+              shuffleClient.reviveBatch(shuffleId, mapIds, requestsToSend.values(), urgent);
+          long reviveCostTime = System.nanoTime() - reviveStartTime;
+          if (results == null) {
+            for (ReviveRequest req : filteredRequests) {
+              req.reviveStatus = StatusCode.REVIVE_FAILED.getValue();
+            }
+          } else {
+            for (ReviveRequest req : filteredRequests) {
+              if (shuffleClient.mapperEnded(shuffleId, req.mapId)) {
+                req.reviveStatus = StatusCode.SUCCESS.getValue();
+              } else {
+                req.reviveStatus = results.get(req.partitionId);
               }
             }
-                // break the loop if remaining requests are fewer than half of
-                // `celeborn.client.push.revive.batchSize`
-              } while (requestQueue.size() > batchSize / 2);
-            },
-            interval,
-            interval,
-            TimeUnit.MILLISECONDS);
+          }
+        } catch (Throwable e) {
+          logger.error("Exception when processRequests: ", e);
+          throw e;
+        }
+      });
+    }
   }
 
   public void addRequest(ReviveRequest request) {
@@ -120,6 +161,7 @@ public void addRequest(ReviveRequest request) {
     // This sync is necessary to ensure the add action is atomic
     try {
       requestQueue.put(request);
+      logger.debug("Add revive request: {}", request);
     } catch (InterruptedException e) {
       logger.error("Exception when put into requests!", e);
     }
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
index dde2b36c4d5..3ef5a3548ee 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
@@ -283,6 +283,14 @@ public abstract PartitionLocation registerMapPartitionTask(
   public abstract ConcurrentHashMap<Integer, PartitionLocation> getPartitionLocation(
       int shuffleId, int numMappers, int numPartitions) throws CelebornIOException;
 
+  public boolean ensureRegistered(int shuffleId, int numMappers, int numPartitions) {
+    return false;
+  }
+
+  public LocationManager getLocationManager() {
+    return null;
+  }
+
   public abstract PushState getPushState(String mapKey);
 
   public abstract int getShuffleId(
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index c642223ead6..4d2bfac125d 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -105,10 +105,6 @@ public class ShuffleClientImpl extends ShuffleClient {
   // key: appShuffleIdentifier, value: shuffleId
   protected Map<String, Integer> shuffleIdCache = JavaUtils.newConcurrentHashMap();
 
-  // key: shuffleId, value: (partitionId, PartitionLocation)
-  final Map<Integer, ConcurrentHashMap<Integer, PartitionLocation>> reducePartitionMap =
-      JavaUtils.newConcurrentHashMap();
-
   // key: shuffleId, value: Set(mapId)
   protected final ConcurrentHashMap<Integer, Set<Integer>> mapperEndMap =
       JavaUtils.newConcurrentHashMap();
@@ -146,6 +142,8 @@ protected Compressor initialValue() {
 
   private final ReviveManager reviveManager;
 
+  private final LocationManager locationManager;
+
   private final boolean dataPushFailureTrackingEnabled;
 
   public static class ReduceFileGroups {
@@ -230,7 +228,7 @@ public ShuffleClientImpl(String appUniqueId, CelebornConf conf, UserIdentifier u
         ThreadUtils.newDaemonCachedThreadPool("celeborn-retry-sender", pushDataRetryThreads, 60);
 
     reviveManager = new ReviveManager(this, conf);
-
+    locationManager = new LocationManager(this, reviveManager);
     logger.info("Created 
ShuffleClientImpl, appUniqueId: {}", appUniqueId); } @@ -291,26 +289,39 @@ private boolean isPushTargetWorkerExcluded( private void submitRetryPushData( int shuffleId, + int partitionId, + int mapId, + int attemptId, byte[] body, int batchId, PushDataRpcResponseCallback pushDataRpcResponseCallback, PushState pushState, - ReviveRequest request, + PartitionLocation oldLoc, + StatusCode cause, int remainReviveTimes, long dueTime) { - int mapId = request.mapId; - int attemptId = request.attemptId; - PartitionLocation loc = request.loc; - StatusCode cause = request.cause; - int partitionId = loc.getId(); long reviveWaitTime = dueTime - System.currentTimeMillis(); + long resubmitWaitTime = conf.clientRpcRequestPartitionLocationAskTimeout() + .duration() + .toMillis() - reviveWaitTime; final long delta = 50; long accumulatedTime = 0; - while (request.reviveStatus == StatusCode.REVIVE_INITIALIZED.getValue() - && accumulatedTime <= reviveWaitTime) { + + PartitionLocation loc = locationManager.getLocationOrReviveAsync(shuffleId, partitionId, mapId, attemptId, true, true); + while (loc == null && accumulatedTime <= reviveWaitTime) { try { Thread.sleep(delta); accumulatedTime += delta; + boolean hasActiveReviveRequest = locationManager.hasActiveReviveRequest(shuffleId, partitionId); + loc = locationManager.getLocationOrReviveAsync(shuffleId, partitionId, mapId, attemptId, false, true); + if (!hasActiveReviveRequest) { + if (loc == null) { + logger.warn("There is no active revive request, however, a new location has not yet been assigned for" + + " shuffle {} map {} attempt {} partition {} batch {} ", + shuffleId, mapId, attemptId, partitionId, batchId); + } + break; + } } catch (InterruptedException e) { logger.error("Interrupted while waiting for Revive result!"); Thread.currentThread().interrupt(); @@ -325,22 +336,15 @@ private void submitRetryPushData( partitionId, batchId, loc); - pushState.removeBatch(batchId, loc.hostAndPushPort()); - } else if (request.reviveStatus != StatusCode.SUCCESS.getValue()) { + pushState.removeBatch(batchId, oldLoc.hostAndPushPort()); + } else if (loc == null) { + StatusCode reviveStatus = locationManager.getReviveStatus(shuffleId, partitionId); pushDataRpcResponseCallback.onFailure( - new CelebornIOException( - cause - + " then revive but " - + StatusCode.REVIVE_FAILED - + ", revive status " - + request.reviveStatus - + "(" - + StatusCode.fromValue(request.reviveStatus) - + ")" - + ", old location: " - + request.loc)); + new CelebornIOException( + cause + " then revive but failed, revive status " + + reviveStatus + + ", old location: " + oldLoc)); } else { - PartitionLocation newLoc = reducePartitionMap.get(shuffleId).get(partitionId); logger.info( "Revive for push data success, new location for shuffle {} map {} attempt {} partition {} batch {} is location {}.", shuffleId, @@ -348,18 +352,18 @@ private void submitRetryPushData( attemptId, partitionId, batchId, - newLoc); - pushDataRpcResponseCallback.updateLatestPartition(newLoc); + loc); + try { - if (!isPushTargetWorkerExcluded(newLoc, pushDataRpcResponseCallback)) { + if (!isPushTargetWorkerExcluded(loc, pushDataRpcResponseCallback)) { if (!testRetryRevive || remainReviveTimes < 1) { assert dataClientFactory != null; TransportClient client = - dataClientFactory.createClient(newLoc.getHost(), newLoc.getPushPort(), partitionId); + dataClientFactory.createClient(loc.getHost(), loc.getPushPort(), partitionId); NettyManagedBuffer newBuffer = new NettyManagedBuffer(Unpooled.wrappedBuffer(body)); String shuffleKey = 
Utils.makeShuffleKey(appUniqueId, shuffleId); PushData newPushData = - new PushData(PRIMARY_MODE, shuffleKey, newLoc.getUniqueId(), newBuffer); + new PushData(PRIMARY_MODE, shuffleKey, loc.getUniqueId(), newBuffer); client.pushData(newPushData, pushDataTimeout, pushDataRpcResponseCallback); } else { throw new RuntimeException( @@ -376,7 +380,7 @@ private void submitRetryPushData( attemptId, partitionId, batchId, - newLoc, + loc, e); if (e instanceof InterruptedException) { pushDataRpcResponseCallback.onFailure(e); @@ -388,22 +392,21 @@ private void submitRetryPushData( } } - public ReviveRequest[] addAndGetReviveRequests( - int shuffleId, - int mapId, - int attemptId, - ArrayList batches, - StatusCode cause) { - ReviveRequest[] reviveRequests = new ReviveRequest[batches.size()]; + private void submitRetryPushMergedData( + PushState pushState, + int shuffleId, + int mapId, + int attemptId, + ArrayList batches, + StatusCode cause, + Integer oldGroupedBatchId, + int remainReviveTimes, + long reviveResponseDueTime) { + List causeList = new ArrayList<>(); for (int i = 0; i < batches.size(); i++) { - DataBatches.DataBatch batch = batches.get(i); - PartitionLocation loc = batch.loc; - ReviveRequest reviveRequest = - new ReviveRequest(shuffleId, mapId, attemptId, loc.getId(), loc.getEpoch(), loc, cause); - reviveManager.addRequest(reviveRequest); - reviveRequests[i] = reviveRequest; + causeList.add(cause); } - return reviveRequests; + submitRetryPushMergedData(pushState, shuffleId, mapId, attemptId, batches, causeList, oldGroupedBatchId, remainReviveTimes, reviveResponseDueTime); } private void submitRetryPushMergedData( @@ -412,58 +415,99 @@ private void submitRetryPushMergedData( int mapId, int attemptId, ArrayList batches, - StatusCode cause, + List causeList, Integer oldGroupedBatchId, - ReviveRequest[] reviveRequests, int remainReviveTimes, long reviveResponseDueTime) { + long reviveWaitTime = reviveResponseDueTime - System.currentTimeMillis(); + long resubmitWaitTime = conf.clientRpcRequestPartitionLocationAskTimeout() + .duration() + .toMillis() - reviveWaitTime; + HashMap, DataBatches> newDataBatchesMap = new HashMap<>(); ArrayList reviveFailedBatchesMap = new ArrayList<>(); + ArrayList reviveFailedCauseList = new ArrayList<>(); + + String oldAddressPair = ""; + StringBuilder dataBatchReviveInfos = new StringBuilder(); + for (int i = 0; i < batches.size(); i++) { + DataBatches.DataBatch batch = batches.get(i); + StatusCode cause = causeList.get(i); + oldAddressPair = batch.loc.hostAndPushPort(); + dataBatchReviveInfos.append( + String.format( + "(batchId=%d, partitionId=%d, epochId=%d, cause=%s)", + batch.batchId, + batch.loc.getId(), + batch.loc.getEpoch(), + cause)); + } + logger.error("Push merged data to {} failed for shuffle {} map {} attempt {} groupedBatch {}, split batches {}, remain revive times {}.", + oldAddressPair, + shuffleId, + mapId, + attemptId, + oldGroupedBatchId, + dataBatchReviveInfos, + remainReviveTimes); - long reviveWaitTime = reviveResponseDueTime - System.currentTimeMillis(); final long delta = 50; long accumulatedTime = 0; int index = 0; - while (index < reviveRequests.length && accumulatedTime <= reviveWaitTime) { - ReviveRequest request = reviveRequests[index]; + boolean doRevive = true; + + while (index < batches.size() && accumulatedTime <= reviveWaitTime) { DataBatches.DataBatch batch = batches.get(index); - if (request.reviveStatus != StatusCode.REVIVE_INITIALIZED.getValue()) { + int partitionId = batch.loc.getId(); + PartitionLocation loc = 
locationManager.getLocationOrReviveAsync(shuffleId, partitionId, mapId, attemptId, doRevive, true); + doRevive = false; + if (!locationManager.hasActiveReviveRequest(shuffleId, partitionId)) { if (mapperEnded(shuffleId, mapId)) { logger.debug( - "Revive for push merged data success, but the mapper already ended for shuffle {} map {} attempt {} partition {} batch {}.", - shuffleId, - mapId, - attemptId, - request.partitionId, - oldGroupedBatchId); - } else if (request.reviveStatus == StatusCode.SUCCESS.getValue()) { - PartitionLocation newLoc = reducePartitionMap.get(shuffleId).get(request.partitionId); - DataBatches newDataBatches = - newDataBatchesMap.computeIfAbsent(genAddressPair(newLoc), (s) -> new DataBatches()); - newDataBatches.addDataBatch(newLoc, batch.batchId, batch.body); + "Revive for push merged data success, but the mapper already ended for shuffle {} map {} attempt {} partition {} batch {}.", + shuffleId, + mapId, + attemptId, + partitionId, + oldGroupedBatchId); + } else if (loc != null) { + logger.info( + "Revive for push merged data success, new location for shuffle {} map {} attempt {} partition {} groupedBatch {} batch {} is location {}.", + shuffleId, + mapId, + attemptId, + partitionId, + oldGroupedBatchId, + batch.batchId, + loc); + DataBatches newDataBatches = newDataBatchesMap.computeIfAbsent(genAddressPair(loc), (s) -> new DataBatches()); + newDataBatches.addDataBatch(loc, batch.batchId, batch.body); } else { + logger.warn("There is no active revive request, however, a new location has not yet been assigned for" + + " shuffle {} map {} attempt {} partition {} groupedBatch {}", + shuffleId, mapId, attemptId, partitionId, oldGroupedBatchId); if (remainReviveTimes > 0) { reviveFailedBatchesMap.add(batch); + reviveFailedCauseList.add(causeList.get(index)); } else { String errorMsg = - String.format( - "Revive failed while pushing merged for shuffle %d map %d attempt %d partition %d batch %d location %s.", - shuffleId, mapId, attemptId, request.partitionId, oldGroupedBatchId, batch.loc); + String.format( + "Revive failed while pushing merged for shuffle %d map %d attempt %d partition %d batch %d location %s.", + shuffleId, mapId, attemptId, partitionId, oldGroupedBatchId, batch.loc); + StatusCode reviveStatus = locationManager.getReviveStatus(shuffleId, partitionId); pushState.exception.compareAndSet( - null, - new CelebornIOException( - errorMsg, + null, new CelebornIOException( - cause - + " then revive but " - + request.reviveStatus - + "(" - + StatusCode.fromValue(request.reviveStatus) - + ")"))); + errorMsg, + new CelebornIOException( + causeList.get(index) + + " then revive but failed, revive status " + + reviveStatus))); return; } } index++; + doRevive = true; } else { try { Thread.sleep(delta); @@ -475,27 +519,26 @@ private void submitRetryPushMergedData( } } - for (int i = index; i < reviveRequests.length; i++) { - ReviveRequest request = reviveRequests[index]; + for (int i = index; i < batches.size(); i++) { DataBatches.DataBatch batch = batches.get(i); if (remainReviveTimes > 0) { reviveFailedBatchesMap.add(batch); + reviveFailedCauseList.add(causeList.get(i)); } else { + int partitionId = batch.loc.getId(); String errorMsg = - String.format( - "Revive failed while pushing merged for shuffle %d map %d attempt %d partition %d batch %d location %s.", - shuffleId, mapId, attemptId, request.partitionId, oldGroupedBatchId, batch.loc); + String.format( + "Revive failed while pushing merged for shuffle %d map %d attempt %d partition %d batch %d location %s.", + 
shuffleId, mapId, attemptId, partitionId, oldGroupedBatchId, batch.loc); + StatusCode reviveStatus = locationManager.getReviveStatus(shuffleId, partitionId); pushState.exception.compareAndSet( - null, - new CelebornIOException( - errorMsg, + null, new CelebornIOException( - cause - + " then revive but " - + request.reviveStatus - + "(" - + StatusCode.fromValue(request.reviveStatus) - + ")"))); + errorMsg, + new CelebornIOException( + causeList.get(index) + + " then revive but failed, revive status " + + reviveStatus))); return; } } @@ -515,8 +558,6 @@ private void submitRetryPushMergedData( if (reviveFailedBatchesMap.isEmpty()) { pushState.removeBatch(oldGroupedBatchId, batches.get(0).loc.hostAndPushPort()); } else { - ReviveRequest[] requests = - addAndGetReviveRequests(shuffleId, mapId, attemptId, reviveFailedBatchesMap, cause); pushDataRetryPool.submit( () -> submitRetryPushMergedData( @@ -525,9 +566,8 @@ private void submitRetryPushMergedData( mapId, attemptId, reviveFailedBatchesMap, - cause, + reviveFailedCauseList, oldGroupedBatchId, - requests, remainReviveTimes - 1, System.currentTimeMillis() + conf.clientRpcRequestPartitionLocationAskTimeout().duration().toMillis())); @@ -542,7 +582,7 @@ private Pair genAddressPair(PartitionLocation loc) { } } - private ConcurrentHashMap registerShuffle( + private ConcurrentHashMap> registerShuffle( int shuffleId, int numMappers, int numPartitions) throws CelebornIOException { return registerShuffleInternal( shuffleId, @@ -578,7 +618,8 @@ public PartitionLocation registerMapPartitionTask( attemptId, partitionId, numMappers); - ConcurrentHashMap partitionLocationMap = + // TODO check + ConcurrentHashMap> partitionLocationMap = registerShuffleInternal( shuffleId, numMappers, @@ -595,29 +636,47 @@ public PartitionLocation registerMapPartitionTask( conf.clientRpcRegisterShuffleAskTimeout(), ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class))); - return partitionLocationMap.get(partitionId); + return partitionLocationMap.get(partitionId).get(0); } @Override public ConcurrentHashMap getPartitionLocation( int shuffleId, int numMappers, int numPartitions) throws CelebornIOException { - try { - return reducePartitionMap.computeIfAbsent( - shuffleId, - (id) -> { - try { - return registerShuffle(shuffleId, numMappers, numPartitions); - } catch (CelebornIOException e) { - throw new RuntimeException(e); - } - }); - } catch (RuntimeException e) { - if (e.getCause() instanceof CelebornIOException) { - throw (CelebornIOException) e.getCause(); - } else { - throw e; + // TODO only UT related usages, fix later + return null; +// try { +// return reducePartitionMap.computeIfAbsent( +// shuffleId, +// (id) -> { +// try { +// return registerShuffle(shuffleId, numMappers, numPartitions); +// } catch (CelebornIOException e) { +// throw new RuntimeException(e); +// } +// }); +// } catch (RuntimeException e) { +// if (e.getCause() instanceof CelebornIOException) { +// throw (CelebornIOException) e.getCause(); +// } else { +// throw e; +// } +// } + } + + @Override + public boolean ensureRegistered(int shuffleId, int numMappers, int numPartitions) { + if (!locationManager.registered(shuffleId)) { + try { + ConcurrentHashMap> map = + registerShuffle(shuffleId, numMappers, numPartitions); + locationManager.registerShuffleLocs(shuffleId, map); + } catch (CelebornIOException e) { + //TODO log exception + return false; } } + + return true; } @Override @@ -678,7 +737,7 @@ public boolean reportBarrierTaskFailure(int appShuffleId, String appShuffleIdent return 
pbReportBarrierStageAttemptFailureResponse.getSuccess(); } - private ConcurrentHashMap registerShuffleInternal( + private ConcurrentHashMap> registerShuffleInternal( int shuffleId, int numMappers, int numPartitions, @@ -691,16 +750,20 @@ private ConcurrentHashMap registerShuffleInternal( PbRegisterShuffleResponse response = callable.call(); StatusCode respStatus = StatusCode.fromValue(response.getStatus()); if (StatusCode.SUCCESS.equals(respStatus)) { - ConcurrentHashMap result = JavaUtils.newConcurrentHashMap(); - Tuple2, List> locations = - PbSerDeUtils.fromPbPackedPartitionLocationsPair( - response.getPackedPartitionLocationsPair()); - for (PartitionLocation location : locations._1) { - pushExcludedWorkers.remove(location.hostAndPushPort()); - if (location.hasPeer()) { - pushExcludedWorkers.remove(location.getPeer().hostAndPushPort()); + ConcurrentHashMap> result = JavaUtils.newConcurrentHashMap(); + for (int i = 0; i < response.getPartitionLocationsList().size(); i++) { + Tuple2, List> locations = + PbSerDeUtils.fromPbPackedPartitionLocationsPair( + response.getPackedPartitionLocationsPair()); + for (PartitionLocation location : locations._1) { + pushExcludedWorkers.remove(location.hostAndPushPort()); + if (location.hasPeer()) { + pushExcludedWorkers.remove(location.getPeer().hostAndPushPort()); + } + List list = result.computeIfAbsent(location.getId(), x -> new ArrayList<>()); + list.add(location); + result.put(location.getId(), list); } - result.put(location.getId(), location); } return result; } else if (StatusCode.SLOT_NOT_AVAILABLE.equals(respStatus)) { @@ -773,33 +836,14 @@ protected void limitZeroInFlight(String mapKey, PushState pushState) throws IOEx /** * Check if a newer PartitionLocation(with larger epoch) exists in local cache. * - * @param shuffleMap The mapping between shuffle id and partition location. + * @param shuffleId The shuffle id. * @param partitionId The id of partition. * @param epoch The epoch of revive. - * @param wait Whether to wait for some time for a newer partition location. * @return whether newer partition location exists in local cache. 
*/ boolean newerPartitionLocationExists( - Map shuffleMap, int partitionId, int epoch, boolean wait) { - PartitionLocation currentLocation = shuffleMap.get(partitionId); - if (currentLocation != null && currentLocation.getEpoch() > epoch) { - return true; - } else if (wait) { - long sleepTimeMs = RND.nextInt(50); - if (sleepTimeMs > 30) { - try { - TimeUnit.MILLISECONDS.sleep(sleepTimeMs); - } catch (InterruptedException e) { - logger.error("Waiting revived location was interrupted.", e); - Thread.currentThread().interrupt(); - } - } - - currentLocation = shuffleMap.get(partitionId); - return currentLocation != null && currentLocation.getEpoch() > epoch; - } else { - return false; - } + int shuffleId, int partitionId, int epoch) { + return locationManager.newerPartitionLocationExists(shuffleId, partitionId, epoch); } void excludeWorkerByCause(StatusCode cause, PartitionLocation oldLocation) { @@ -820,61 +864,32 @@ void excludeWorkerByCause(StatusCode cause, PartitionLocation oldLocation) { } } - private boolean revive( - int shuffleId, - int mapId, - int attemptId, - int partitionId, - int epoch, - PartitionLocation oldLocation, - StatusCode cause) { - excludeWorkerByCause(cause, oldLocation); - - Set mapIds = new HashSet<>(); - mapIds.add(mapId); - List requests = new ArrayList<>(); - ReviveRequest req = - new ReviveRequest(shuffleId, mapId, attemptId, partitionId, epoch, oldLocation, cause); - requests.add(req); - Map results = reviveBatch(shuffleId, mapIds, requests); - - if (mapperEnded(shuffleId, mapId)) { - logger.debug( - "Revive success, but the mapper ended for shuffle {} map {} attempt {} partition {}, just return true(Assume revive successfully).", - shuffleId, - mapId, - attemptId, - partitionId); - return true; - } else { - return results != null - && results.containsKey(partitionId) - && results.get(partitionId) == StatusCode.SUCCESS.getValue(); - } - } - /** @return partitionId -> StatusCode#getValue */ Map reviveBatch( - int shuffleId, Set mapIds, Collection requests) { + int shuffleId, Set mapIds, Collection requests, boolean urgent) { // partitionId -> StatusCode#getValue Map results = new HashMap<>(); - // Local cached map of (partitionId -> PartitionLocation) - ConcurrentHashMap partitionLocationMap = - reducePartitionMap.get(shuffleId); - Map oldLocMap = new HashMap<>(); Iterator iter = requests.iterator(); while (iter.hasNext()) { ReviveRequest req = iter.next(); oldLocMap.put(req.partitionId, req.loc); } + try { - PbChangeLocationResponse response = - lifecycleManagerRef.askSync( - Revive$.MODULE$.apply(shuffleId, mapIds, requests), - conf.clientRpcRequestPartitionLocationAskTimeout(), - ClassTag$.MODULE$.apply(PbChangeLocationResponse.class)); + PbChangeLocationResponse response; + if (urgent) { + response = lifecycleManagerRef.askSync( + Revive$.MODULE$.apply(shuffleId, mapIds, requests), + conf.clientRpcRequestPartitionLocationAskTimeout(), + ClassTag$.MODULE$.apply(PbChangeLocationResponse.class)); + } else { + response = lifecycleManagerRef.askSync( + PartitionSplitReport$.MODULE$.apply(shuffleId, mapIds, requests), + conf.clientRpcRequestPartitionLocationAskTimeout(), + ClassTag$.MODULE$.apply(PbChangeLocationResponse.class)); + } for (int i = 0; i < response.getEndedMapIdCount(); i++) { int mapId = response.getEndedMapId(i); @@ -892,13 +907,23 @@ Map reviveBatch( } if (StatusCode.SUCCESS.getValue() == statusCode) { - PartitionLocation loc = - PbSerDeUtils.fromPbPartitionLocation(partitionInfo.getPartition()); - partitionLocationMap.put(partitionId, loc); - 
pushExcludedWorkers.remove(loc.hostAndPushPort()); - if (loc.hasPeer()) { - pushExcludedWorkers.remove(loc.getPeer().hostAndPushPort()); + List pbPartitionLocations = partitionInfo.getPartitionList(); + if (pbPartitionLocations.isEmpty()) { + continue; + } + ArrayList locations = new ArrayList<>(pbPartitionLocations.size()); + for (PbPartitionLocation pbPartitionLoc : pbPartitionLocations) { + PartitionLocation loc = PbSerDeUtils.fromPbPartitionLocation(pbPartitionLoc); + if (locationManager.locationExists(shuffleId, loc.getId(), loc.getEpoch())) { + continue; + } + locations.add(loc); + pushExcludedWorkers.remove(loc.hostAndPushPort()); + if (loc.hasPeer()) { + pushExcludedWorkers.remove(loc.getPeer().hostAndPushPort()); + } } + locationManager.updateLocation(shuffleId, partitionId, locations); } else if (StatusCode.STAGE_ENDED.getValue() == statusCode) { stageEndShuffleSet.add(shuffleId); return results; @@ -916,7 +941,11 @@ Map reviveBatch( requests.forEach( (req) -> { partitionIds.append(req.partitionId).append(","); - epochs.append(req.epoch).append(","); + Integer epochId = null; + if (req.loc != null) { + epochId = req.loc.getEpoch(); + } + epochs.append(epochId).append(","); }); logger.error( "Exception raised while reviving for shuffle {} partitionIds {} epochs {}.", @@ -965,23 +994,23 @@ public int pushOrMergeData( return 0; } // register shuffle if not registered - final ConcurrentHashMap map = - getPartitionLocation(shuffleId, numMappers, numPartitions); + boolean registered = ensureRegistered(shuffleId, numMappers, numPartitions); + if (!registered) { + throw new CelebornIOException("Register shuffle failed for shuffle " + shuffleId + "."); + } // get location // If rerun or speculation task running after LifecycleManager call stageEnd, // register shuffle will return an empty location map, client need revive for a new location. 
- if (!map.containsKey(partitionId)) { - if (!revive( - shuffleId, - mapId, - attemptId, - partitionId, - -1, - null, - StatusCode.PUSH_DATA_FAIL_NON_CRITICAL_CAUSE_PRIMARY)) { + if (!locationManager.exists(shuffleId, partitionId)) { + if (!locationManager.reviveSync( + shuffleId, + partitionId, + mapId, + attemptId, + StatusCode.PUSH_DATA_FAIL_NON_CRITICAL_CAUSE_PRIMARY)) { throw new CelebornIOException( - String.format("Revive for shuffle %s partition %d failed.", shuffleId, partitionId)); + String.format("Revive for shuffle %s partition %d failed.", shuffleId, partitionId)); } } @@ -999,7 +1028,7 @@ public int pushOrMergeData( return 0; } - final PartitionLocation loc = map.get(partitionId); + PartitionLocation loc = locationManager.getLocationOrReviveAsync(shuffleId, partitionId, mapId, attemptId, false, false); if (loc == null) { throw new CelebornIOException( String.format( @@ -1095,26 +1124,15 @@ public void onSuccess(ByteBuffer response) { attemptId, partitionId, nextBatchId); - if (!newerPartitionLocationExists( - reducePartitionMap.get(shuffleId), partitionId, latest.getEpoch(), false)) { - ReviveRequest reviveRequest = - new ReviveRequest( - shuffleId, - mapId, - attemptId, - partitionId, - latest.getEpoch(), - latest, - StatusCode.SOFT_SPLIT); - reviveManager.addRequest(reviveRequest); - } + locationManager.reportUnusableLocation(shuffleId, mapId, attemptId, latest, StatusCode.SOFT_SPLIT); pushState.onSuccess(latest.hostAndPushPort()); pushState.removeBatch(nextBatchId, latest.hostAndPushPort()); callback.onSuccess(response); } else if (reason == StatusCode.HARD_SPLIT.getValue()) { logger.debug( - "Push data to {} hard split required for shuffle {} map {} attempt {} partition {} batch {}.", + "Push data to {} epoch {}, hard split required for shuffle {} map {} attempt {} partition {} batch {}.", latest.hostAndPushPort(), + latest.getEpoch(), shuffleId, mapId, attemptId, @@ -1124,16 +1142,7 @@ public void onSuccess(ByteBuffer response) { pushState.addFailedBatch( latest.getUniqueId(), new PushFailedBatch(mapId, attemptId, nextBatchId)); } - ReviveRequest reviveRequest = - new ReviveRequest( - shuffleId, - mapId, - attemptId, - partitionId, - latest.getEpoch(), - latest, - StatusCode.HARD_SPLIT); - reviveManager.addRequest(reviveRequest); + locationManager.reportUnusableLocation(shuffleId, mapId, attemptId, latest, StatusCode.HARD_SPLIT); long dueTime = System.currentTimeMillis() + conf.clientRpcRequestPartitionLocationAskTimeout() @@ -1143,11 +1152,15 @@ public void onSuccess(ByteBuffer response) { () -> submitRetryPushData( shuffleId, + partitionId, + mapId, + attemptId, body, nextBatchId, this, pushState, - reviveRequest, + latest, + StatusCode.HARD_SPLIT, remainReviveTimes, dueTime)); } else if (reason == StatusCode.PUSH_DATA_SUCCESS_PRIMARY_CONGESTED.getValue()) { @@ -1225,10 +1238,7 @@ public void onFailure(Throwable e) { // async retry push data if (!mapperEnded(shuffleId, mapId)) { remainReviveTimes = remainReviveTimes - 1; - ReviveRequest reviveRequest = - new ReviveRequest( - shuffleId, mapId, attemptId, partitionId, latest.getEpoch(), latest, cause); - reviveManager.addRequest(reviveRequest); + locationManager.reportUnusableLocation(shuffleId, mapId, attemptId, latest, cause); long dueTime = System.currentTimeMillis() + conf.clientRpcRequestPartitionLocationAskTimeout().duration().toMillis(); @@ -1236,11 +1246,15 @@ public void onFailure(Throwable e) { () -> submitRetryPushData( shuffleId, + partitionId, + mapId, + attemptId, body, nextBatchId, this, pushState, - 
reviveRequest, + latest, + cause, remainReviveTimes, dueTime)); } else { @@ -1491,6 +1505,7 @@ public void onSuccess(ByteBuffer response) { byte reason = response.get(); if (reason == StatusCode.HARD_SPLIT.getValue()) { ArrayList batchesNeedResubmit; + ArrayList causeList = new ArrayList<>(); if (response.remaining() > 0) { batchesNeedResubmit = new ArrayList<>(); PbPushMergedDataSplitPartitionInfo partitionInfo; @@ -1506,7 +1521,8 @@ public void onSuccess(ByteBuffer response) { StringBuilder dataBatchReviveInfos = new StringBuilder(); for (int i = 0; i < splitPartitionIndexes.size(); i++) { int partitionIndex = splitPartitionIndexes.get(i); - int batchId = batches.get(partitionIndex).batchId; + DataBatches.DataBatch currentBatch = batches.get(partitionIndex); + int batchId = currentBatch.batchId; dataBatchReviveInfos.append( String.format( "(batchId=%d, partitionId=%d, cause=%s)", @@ -1514,22 +1530,11 @@ public void onSuccess(ByteBuffer response) { partitionIds[partitionIndex], StatusCode.fromValue(statusCodeList.get(i).byteValue()))); if (statusCodeList.get(i) == StatusCode.SOFT_SPLIT.getValue()) { - PartitionLocation loc = batches.get(partitionIndex).loc; - if (!newerPartitionLocationExists( - reducePartitionMap.get(shuffleId), loc.getId(), loc.getEpoch(), false)) { - ReviveRequest reviveRequest = - new ReviveRequest( - shuffleId, - mapId, - attemptId, - loc.getId(), - loc.getEpoch(), - loc, - StatusCode.SOFT_SPLIT); - reviveManager.addRequest(reviveRequest); - } + locationManager.reportUnusableLocation(shuffleId, mapId, attemptId, currentBatch.loc, StatusCode.SOFT_SPLIT); } else { + locationManager.reportUnusableLocation(shuffleId, mapId, attemptId, currentBatch.loc, StatusCode.HARD_SPLIT); batchesNeedResubmit.add(batches.get(partitionIndex)); + causeList.add(StatusCode.HARD_SPLIT); } } logger.info( @@ -1546,6 +1551,10 @@ public void onSuccess(ByteBuffer response) { // but will not include a PbPushMergedDataSplitPartitionInfo. // For backward compatibility, all batches must be resubmitted. 
batchesNeedResubmit = batches; + for (int i = 0; i < numBatches; i++) { + causeList.add(StatusCode.HARD_SPLIT); + locationManager.reportUnusableLocation(shuffleId, mapId, attemptId, batches.get(i).loc, StatusCode.HARD_SPLIT); + } logger.info( "Push merged data to {} hard split required for shuffle {} map {} attempt {} partition {} groupedBatch {} batch {}.", addressPair, @@ -1567,9 +1576,6 @@ public void onSuccess(ByteBuffer response) { new PushFailedBatch(mapId, attemptId, resubmitBatch.batchId)); } } - ReviveRequest[] requests = - addAndGetReviveRequests( - shuffleId, mapId, attemptId, batchesNeedResubmit, StatusCode.HARD_SPLIT); pushDataRetryPool.submit( () -> submitRetryPushMergedData( @@ -1578,9 +1584,8 @@ public void onSuccess(ByteBuffer response) { mapId, attemptId, batchesNeedResubmit, - StatusCode.HARD_SPLIT, + causeList, groupedBatchId, - requests, remainReviveTimes, System.currentTimeMillis() + conf.clientRpcRequestPartitionLocationAskTimeout() @@ -1657,8 +1662,9 @@ public void onFailure(Throwable e) { remainReviveTimes, e); if (!mapperEnded(shuffleId, mapId)) { - ReviveRequest[] requests = - addAndGetReviveRequests(shuffleId, mapId, attemptId, batches, cause); + for (DataBatches.DataBatch batch : batches) { + locationManager.reportUnusableLocation(shuffleId, mapId, attemptId, batch.loc, cause); + } pushDataRetryPool.submit( () -> submitRetryPushMergedData( @@ -1669,7 +1675,6 @@ public void onFailure(Throwable e) { batches, cause, groupedBatchId, - requests, remainReviveTimes - 1, System.currentTimeMillis() + conf.clientRpcRequestPartitionLocationAskTimeout() @@ -1779,7 +1784,7 @@ public void cleanup(int shuffleId, int mapId, int attemptId) { @Override public boolean cleanupShuffle(int shuffleId) { // clear status - reducePartitionMap.remove(shuffleId); + locationManager.removeShuffle(shuffleId); reduceFileGroupsMap.remove(shuffleId); mapperEndMap.remove(shuffleId); stageEndShuffleSet.remove(shuffleId); @@ -2075,6 +2080,11 @@ private StatusCode getPushDataFailCause(String message) { return cause; } + @Override + public LocationManager getLocationManager() { + return locationManager; + } + @VisibleForTesting @Override public TransportClientFactory getDataClientFactory() { diff --git a/client/src/main/java/org/apache/celeborn/client/write/DataPushQueue.java b/client/src/main/java/org/apache/celeborn/client/write/DataPushQueue.java index aec6a69b5d4..7df04a9158d 100644 --- a/client/src/main/java/org/apache/celeborn/client/write/DataPushQueue.java +++ b/client/src/main/java/org/apache/celeborn/client/write/DataPushQueue.java @@ -23,6 +23,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import org.apache.celeborn.client.LocationManager; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -47,6 +48,8 @@ public class DataPushQueue { private final PushState pushState; private final DataPusher dataPusher; private final int shuffleId; + private final int mapId; + private final int attemptId; private final int numMappers; private final int numPartitions; private final ShuffleClient client; @@ -68,6 +71,8 @@ public DataPushQueue( this.client = client; this.dataPusher = dataPusher; final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId); + this.mapId = mapId; + this.attemptId = attemptId; this.pushState = client.getPushState(mapKey); this.takeTaskWaitIntervalMs = conf.clientPushTakeTaskWaitIntervalMs(); this.takeTaskMaxWaitAttempts = conf.clientPushTakeTaskMaxWaitAttempts(); @@ -88,17 +93,17 @@ public ArrayList takePushTasks() 
throws IOException, InterruptedExcept // takeTaskWaitTimeMs // in last loop workerCapacity.clear(); - Map partitionLocationMap = - client.getPartitionLocation(shuffleId, numMappers, numPartitions); - if (partitionLocationMap == null) { + boolean registered = client.ensureRegistered(shuffleId, numMappers, numPartitions); + if (!registered) { tasks.addAll(workingQueue); workingQueue.clear(); } else { Iterator iterator = workingQueue.iterator(); + LocationManager locationManager = client.getLocationManager(); while (iterator.hasNext()) { PushTask task = iterator.next(); int partitionId = task.getPartitionId(); - PartitionLocation loc = partitionLocationMap.get(partitionId); + PartitionLocation loc = locationManager.getLocationOrReviveAsync(shuffleId, partitionId, mapId, attemptId, true, false); // According to CELEBORN-560, call rerun task and speculative task after LifecycleManager // handle StageEnd will return empty PartitionLocation map, here loc can be null if (loc != null) { diff --git a/client/src/main/scala/org/apache/celeborn/client/RequestLocationCallContext.scala b/client/src/main/scala/org/apache/celeborn/client/RequestLocationCallContext.scala index d306546b305..02f67df7e1f 100644 --- a/client/src/main/scala/org/apache/celeborn/client/RequestLocationCallContext.scala +++ b/client/src/main/scala/org/apache/celeborn/client/RequestLocationCallContext.scala @@ -40,7 +40,7 @@ case class ChangeLocationsCallContext( extends RequestLocationCallContext with Logging { val endedMapIds = new util.HashSet[Integer]() val newLocs = - JavaUtils.newConcurrentHashMap[Integer, (StatusCode, Boolean, PartitionLocation)]( + JavaUtils.newConcurrentHashMap[Integer, (StatusCode, Boolean, Seq[PartitionLocation])]( partitionCount) def markMapperEnd(mapId: Int): Unit = this.synchronized { @@ -55,7 +55,8 @@ case class ChangeLocationsCallContext( if (newLocs.containsKey(partitionId)) { logError(s"PartitionId $partitionId already exists!") } - newLocs.put(partitionId, (status, available, partitionLocationOpt.getOrElse(null))) + //TODO fix later +// newLocs.put(partitionId, (status, available, partitionLocationOpt.getOrElse(null))) if (newLocs.size() == partitionCount || StatusCode.SHUFFLE_NOT_REGISTERED == status || StatusCode.STAGE_ENDED == status) { diff --git a/common/src/main/java/org/apache/celeborn/common/protocol/ReviveRequest.java b/common/src/main/java/org/apache/celeborn/common/protocol/ReviveRequest.java index 0d001adfa56..e10b7174267 100644 --- a/common/src/main/java/org/apache/celeborn/common/protocol/ReviveRequest.java +++ b/common/src/main/java/org/apache/celeborn/common/protocol/ReviveRequest.java @@ -18,32 +18,64 @@ package org.apache.celeborn.common.protocol; import org.apache.celeborn.common.protocol.message.StatusCode; +import java.util.Objects; public class ReviveRequest { public int shuffleId; public int mapId; public int attemptId; public int partitionId; - public int epoch; public PartitionLocation loc; + public int clientMaxEpoch; public StatusCode cause; + public boolean urgent; public volatile int reviveStatus; public ReviveRequest( - int shuffleId, - int mapId, - int attemptId, - int partitionId, - int epoch, - PartitionLocation loc, - StatusCode cause) { + int shuffleId, + int mapId, + int attemptId, + int partitionId, + PartitionLocation loc, + StatusCode cause, + int clientMaxEpoch, + boolean urgent) { this.shuffleId = shuffleId; this.mapId = mapId; this.attemptId = attemptId; this.partitionId = partitionId; - this.epoch = epoch; this.loc = loc; + this.clientMaxEpoch = 
clientMaxEpoch; this.cause = cause; + this.urgent = urgent; reviveStatus = StatusCode.REVIVE_INITIALIZED.getValue(); } + + @Override + public String toString() { + return "ReviveRequest{" + + "shuffleId=" + shuffleId + + ", mapId=" + mapId + + ", attemptId=" + attemptId + + ", partitionId=" + partitionId + + ", loc=" + loc + + ", clientMaxEpoch=" + clientMaxEpoch + + ", cause=" + cause + + ", urgent=" + urgent + + ", reviveStatus=" + reviveStatus + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ReviveRequest that = (ReviveRequest) o; + return shuffleId == that.shuffleId && mapId == that.mapId && attemptId == that.attemptId && partitionId == that.partitionId && clientMaxEpoch == that.clientMaxEpoch && urgent == that.urgent && reviveStatus == that.reviveStatus && loc == that.loc && cause == that.cause; + } + + @Override + public int hashCode() { + return Objects.hash(shuffleId, mapId, attemptId, partitionId, loc, clientMaxEpoch, cause, urgent, reviveStatus); + } } diff --git a/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java b/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java index 7c1df8d0720..1843a8c72a0 100644 --- a/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java +++ b/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java @@ -91,7 +91,8 @@ public enum StatusCode { OPEN_STREAM_FAILED(51), SEGMENT_START_FAIL_REPLICA(52), SEGMENT_START_FAIL_PRIMARY(53), - NO_SPLIT(54); + NO_SPLIT(54), + URGENT_REVIVE(55); private final byte value; diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto index acf355756c8..469571de189 100644 --- a/common/src/main/proto/TransportMessages.proto +++ b/common/src/main/proto/TransportMessages.proto @@ -112,6 +112,7 @@ enum MessageType { REVISE_LOST_SHUFFLES = 89; REVISE_LOST_SHUFFLES_RESPONSE = 90; PUSH_MERGED_DATA_SPLIT_PARTITION_INFO = 91; + PARTITION_SPLIT_REPORT = 92; } enum StreamType { @@ -323,6 +324,7 @@ message PbRevivePartitionInfo { int32 epoch = 2; PbPartitionLocation partition = 3; int32 status = 4; + int32 clientMaxEpoch = 5; } message PbRevive { @@ -331,10 +333,16 @@ message PbRevive { repeated PbRevivePartitionInfo partitionInfo = 3; } +message PbPartitionSplitReport { + int32 shuffleId = 1; + repeated int32 mapId = 2; + repeated PbRevivePartitionInfo partitionInfo = 3; +} + message PbChangeLocationPartitionInfo { int32 partitionId = 1; int32 status = 2; - PbPartitionLocation partition = 3; + repeated PbPartitionLocation partition = 3; bool oldAvailable = 4; } diff --git a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala index 57831f213c8..79148f21d92 100644 --- a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala +++ b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala @@ -221,10 +221,12 @@ object ControlMessages extends Logging { reviveRequests.asScala.foreach { req => val partitionInfoBuilder = PbRevivePartitionInfo.newBuilder() .setPartitionId(req.partitionId) - .setEpoch(req.epoch) + .setClientMaxEpoch(req.clientMaxEpoch) .setStatus(req.cause.getValue) if (req.loc != null) { - partitionInfoBuilder.setPartition(PbSerDeUtils.toPbPartitionLocation(req.loc)) + 
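A note on the equals/hashCode pair above: they include the volatile reviveStatus field, while ReviveManager collects ReviveRequest instances into HashSets and mutates reviveStatus afterwards. An object whose hash-relevant state changes while it sits in a hash-based collection can no longer be found in it. A minimal standalone sketch of that failure mode (MutableKey is a hypothetical stand-in, not part of this patch):

import java.util.HashSet;
import java.util.Objects;
import java.util.Set;

class MutableKey {
  final int id;
  volatile int status; // mutated after insertion, like ReviveRequest.reviveStatus

  MutableKey(int id, int status) {
    this.id = id;
    this.status = status;
  }

  @Override
  public boolean equals(Object o) {
    if (!(o instanceof MutableKey)) return false;
    MutableKey that = (MutableKey) o;
    return id == that.id && status == that.status;
  }

  @Override
  public int hashCode() {
    return Objects.hash(id, status);
  }

  public static void main(String[] args) {
    Set<MutableKey> set = new HashSet<>();
    MutableKey key = new MutableKey(1, 0);
    set.add(key);
    key.status = 2; // the key now hashes to a different bucket
    System.out.println(set.contains(key)); // typically prints false
  }
}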
partitionInfoBuilder + .setEpoch(req.loc.getEpoch) + .setPartition(PbSerDeUtils.toPbPartitionLocation(req.loc)) } builder.addPartitionInfo(partitionInfoBuilder.build()) } @@ -233,6 +235,34 @@ object ControlMessages extends Logging { } } + object PartitionSplitReport { + def apply( + shuffleId: Int, + mapIds: util.Set[Integer], + reviveRequests: util.Collection[ReviveRequest]): PbPartitionSplitReport = { + val builder = PbPartitionSplitReport.newBuilder() + .setShuffleId(shuffleId) + .addAllMapId(mapIds) + + reviveRequests.asScala.foreach { req => + val partitionInfoBuilder = PbRevivePartitionInfo.newBuilder() + .setPartitionId(req.partitionId) + .setClientMaxEpoch(req.clientMaxEpoch) + .setStatus(req.cause.getValue) + + if (req.loc != null) { + partitionInfoBuilder + .setEpoch(req.loc.getEpoch) + .setPartition(PbSerDeUtils.toPbPartitionLocation(req.loc)) + } + + builder.addPartitionInfo(partitionInfoBuilder.build()) + } + + builder.build() + } + } + object PartitionSplit { def apply( shuffleId: Int, @@ -250,17 +280,18 @@ object ControlMessages extends Logging { object ChangeLocationResponse { def apply( mapIds: util.Set[Integer], - newLocs: util.Map[Integer, (StatusCode, Boolean, PartitionLocation)]) + newLocs: util.Map[Integer, (StatusCode, Boolean, Seq[PartitionLocation])]) : PbChangeLocationResponse = { val builder = PbChangeLocationResponse.newBuilder() builder.addAllEndedMapId(mapIds) - newLocs.asScala.foreach { case (partitionId, (status, available, loc)) => + newLocs.asScala.foreach { case (partitionId, (status, available, locs)) => val pbChangeLocationPartitionInfoBuilder = PbChangeLocationPartitionInfo.newBuilder() .setPartitionId(partitionId) .setStatus(status.getValue) .setOldAvailable(available) - if (loc != null) { - pbChangeLocationPartitionInfoBuilder.setPartition(PbSerDeUtils.toPbPartitionLocation(loc)) + if (locs != null) { + locs.foreach(loc => + pbChangeLocationPartitionInfoBuilder.addPartition(PbSerDeUtils.toPbPartitionLocation(loc))) } builder.addPartitionInfo(pbChangeLocationPartitionInfoBuilder.build()) } From f6142859482c219e9137f2d4c8bdddc616407715 Mon Sep 17 00:00:00 2001 From: jiang13021 Date: Wed, 14 May 2025 21:25:17 +0800 Subject: [PATCH 2/3] Modify ChangePartitionManager and LifecycleManager --- .../celeborn/client/LocationManager.java | 632 ++++++++++-------- .../apache/celeborn/client/ReviveManager.java | 125 ++-- .../celeborn/client/ShuffleClientImpl.java | 268 ++++---- .../celeborn/client/write/DataPushQueue.java | 6 +- .../client/ChangePartitionManager.scala | 292 +++++--- .../celeborn/client/LifecycleManager.scala | 224 +++++-- .../client/PartitionLocationMonitor.scala | 107 +++ .../client/PartitionSplitTimeSlidingHub.scala | 50 ++ .../client/RequestLocationCallContext.scala | 32 +- .../common/protocol/ReviveRequest.java | 63 +- .../ConcurrentSkipListMapWithTracker.java | 79 +++ .../celeborn/common/util}/TimeSlidingHub.java | 9 +- .../apache/celeborn/common/CelebornConf.scala | 48 +- .../protocol/message/ControlMessages.scala | 9 +- .../common/util/CelebornHadoopUtils.scala | 1 - ...gePartitionManagerUpdateWorkersSuite.scala | 20 +- .../LifecycleManagerCommitFilesSuite.scala | 12 +- .../LifecycleManagerDestroySlotsSuite.scala | 9 +- .../tests/client/LifecycleManagerSuite.scala | 19 + .../PartitionLocationMonitorSuite.scala | 229 +++++++ .../congestcontrol/BufferStatusHub.java | 2 + .../congestcontrol/TestTimeSlidingHub.java | 2 + .../DynamicallySplitPartitionSuite.scala | 98 +++ 23 files changed, 1679 insertions(+), 657 deletions(-) create mode 100644 
client/src/main/scala/org/apache/celeborn/client/PartitionLocationMonitor.scala create mode 100644 client/src/main/scala/org/apache/celeborn/client/PartitionSplitTimeSlidingHub.scala create mode 100644 common/src/main/java/org/apache/celeborn/common/util/ConcurrentSkipListMapWithTracker.java rename {worker/src/main/java/org/apache/celeborn/service/deploy/worker/congestcontrol => common/src/main/java/org/apache/celeborn/common/util}/TimeSlidingHub.java (94%) create mode 100644 tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/PartitionLocationMonitorSuite.scala create mode 100644 worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/DynamicallySplitPartitionSuite.scala diff --git a/client/src/main/java/org/apache/celeborn/client/LocationManager.java b/client/src/main/java/org/apache/celeborn/client/LocationManager.java index d741dbe1514..f88f69a147d 100644 --- a/client/src/main/java/org/apache/celeborn/client/LocationManager.java +++ b/client/src/main/java/org/apache/celeborn/client/LocationManager.java @@ -1,14 +1,5 @@ package org.apache.celeborn.client; -import com.google.common.annotations.VisibleForTesting; -import org.apache.celeborn.common.protocol.PartitionLocation; -import org.apache.celeborn.common.protocol.ReviveRequest; -import org.apache.celeborn.common.protocol.message.StatusCode; -import org.apache.celeborn.common.util.JavaUtils; -import org.apache.celeborn.common.util.Utils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; @@ -21,322 +12,381 @@ import java.util.concurrent.locks.ReadWriteLock; import java.util.concurrent.locks.ReentrantReadWriteLock; +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.celeborn.common.protocol.PartitionLocation; +import org.apache.celeborn.common.protocol.ReviveRequest; +import org.apache.celeborn.common.protocol.message.StatusCode; +import org.apache.celeborn.common.util.JavaUtils; + public class LocationManager { - private static final Logger logger = LoggerFactory.getLogger(LocationManager.class); - - private class PartitionLocationList { - List locations = new ArrayList<>(); - Set locationSet = new HashSet<>(); - // epoch id -> status - Map locationStatusCode = new HashMap<>(); - ReviveRequest latestReviveRequest = null; - int[] index = null; - - int used = 0; - int size = 0; - int maxEpoch = -1; - - final int shuffleId; - final int partitionId; - public PartitionLocationList(int shuffleId, int partitionId) { - this.shuffleId = shuffleId; - this.partitionId = partitionId; - } + private static final Logger logger = LoggerFactory.getLogger(LocationManager.class); - ReadWriteLock lock = new ReentrantReadWriteLock(); - Lock readLock = lock.readLock(); - Lock writeLock = lock.writeLock(); - - private void update(List newLocs) { - if (newLocs.isEmpty()) { - return; - } - newLocs.sort(Comparator.comparing(PartitionLocation::getEpoch)); - int newMaxEpoch = newLocs.get(newLocs.size() - 1).getEpoch(); - try { - writeLock.lock(); - if (newMaxEpoch <= maxEpoch) { - return; - } - int newSize = newLocs.size() + size - used; - ArrayList newLocations = new ArrayList<>(newSize); - for (PartitionLocation oldLoc : locations) { - if (locationStatusCode.remove(oldLoc.getEpoch()) != null) { - used--; - } else { - newLocations.add(oldLoc); - } - } - for (PartitionLocation l : newLocs) { - if (l.getEpoch() >= maxEpoch) { - newLocations.add(l); - } - } - size = 
newLocations.size(); - index = new int[size]; - for (int i = 0; i < size; i++) { - index[i] = i; - } - locations = newLocations; - locationSet.clear(); - locations.forEach(l -> locationSet.add(l.getEpoch())); - maxEpoch = Math.max(maxEpoch, newMaxEpoch); - logger.info("Location updated for shuffleId {}, partitionId {}, new locations: {}, maxEpoch: {}", shuffleId, partitionId, locationSet, maxEpoch); - if (latestReviveRequest != null && latestReviveRequest.clientMaxEpoch < maxEpoch) { - logger.debug("outdated latestReviveRequest {}", latestReviveRequest); - this.latestReviveRequest = null; - } - } finally { - writeLock.unlock(); - } - } + private class PartitionLocationList { + List locations = new ArrayList<>(); + Set locationSet = new HashSet<>(); + // epoch id -> status + Map locationStatusCode = new HashMap<>(); + ReviveRequest latestReviveRequest = null; + int[] index = null; - // return partitionLocation for specified mapId - // if allowSoftSplit = true, soft split location can be returned - // if liveOnly = false, non-living (soft split/hard split/push fail) location can be returned - private PartitionLocation nextLoc(int mapId, boolean allowSoftSplit, boolean liveOnly) { - try { - readLock.lock(); - int pos = mapId % size; - int idx = index[pos]; - while (locationStatusCode.get(locations.get(idx).getEpoch()) != null) { - if (allowSoftSplit && locationStatusCode.get(locations.get(idx).getEpoch()) == StatusCode.SOFT_SPLIT) { - break; - } - idx = (idx + 1) % size; - // all locations are checked - if (idx == index[pos]) { - break; - } - } - if (idx != index[pos]) { - index[pos] = idx; - } - if (locationStatusCode.get(locations.get(idx).getEpoch()) != null) { - if ((allowSoftSplit && locationStatusCode.get(locations.get(idx).getEpoch()) == StatusCode.SOFT_SPLIT) || - !liveOnly) { - return locations.get(idx); - } - return null; - } else { - return locations.get(idx); - } - } finally { - readLock.unlock(); - } - } + int used = 0; + int size = 0; + int maxEpoch = -1; - private void reviveBatch(int shuffleId, int partitionId, int mapId, int attemptId) { - ReviveRequest reviveRequest = null; - try { - writeLock.lock(); - reviveRequest = new ReviveRequest(shuffleId, mapId, attemptId, partitionId, null, StatusCode.URGENT_REVIVE, maxEpoch, true); - this.latestReviveRequest = reviveRequest; - logger.debug("in reviveBatch latestReviveRequest = {}", reviveRequest); - } finally { - writeLock.unlock(); - } - reviveManager.addRequest(reviveRequest); - } + final int shuffleId; + final int partitionId; - private void reportUnusableLocation(int shuffleId, int mapId, int attemptId, PartitionLocation loc, StatusCode reportedStatus) { - ReviveRequest reviveRequest = null; - try { - writeLock.lock(); - if (!locationSet.contains(loc.getEpoch())) { - return; - } - - StatusCode currentStatus = locationStatusCode.get(loc.getEpoch()); - if (currentStatus != reportedStatus) { - // allow normal/soft split to transition to hard split/push failure - if (currentStatus == null || currentStatus == StatusCode.SOFT_SPLIT) { - locationStatusCode.put(loc.getEpoch(), reportedStatus); - } - if (currentStatus == null) { - used++; - } - boolean urgent = ((used == size) && !hasActiveReviveRequest()); - reviveRequest = new ReviveRequest(shuffleId, mapId, attemptId, loc.getId(), loc, reportedStatus, maxEpoch, urgent); - if (urgent) { - this.latestReviveRequest = reviveRequest; - logger.debug("in reportUnusableLocation latestReviveRequest = {}", reviveRequest); - } - } - - } finally { - writeLock.unlock(); - } - if (reviveRequest 
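The nextLoc routine reformatted above starts at slot mapId % size, walks forward past epochs marked unusable, and remembers the last good slot in index[pos] so different mappers fan out across the available locations. A simplified standalone sketch of that rotation, assuming an epochs list plus an unusable flag per slot (names are illustrative, not the patch's code):

import java.util.Arrays;
import java.util.List;

// Simplified model: pick a slot for a mapper, skipping entries marked unusable,
// starting from a per-mapper offset so mappers spread across locations.
class RoundRobinPicker {
  static Integer pick(List<Integer> epochs, boolean[] unusable, int mapId) {
    int size = epochs.size();
    int start = mapId % size;
    int idx = start;
    do {
      if (!unusable[idx]) {
        return epochs.get(idx); // first usable epoch at or after the mapper's offset
      }
      idx = (idx + 1) % size;
    } while (idx != start);
    return null; // every location is currently unusable
  }

  public static void main(String[] args) {
    List<Integer> epochs = Arrays.asList(0, 1, 2, 3);
    boolean[] unusable = {false, true, false, false};
    System.out.println(pick(epochs, unusable, 5)); // mapId 5 -> start 1 (unusable) -> epoch 2
  }
}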
!= null) { - logger.info("Reported worker {}, partitionId {}, epoch {}, shuffle {} map {} attempt {} is unusable, status: {}, urgent: {}", - loc.hostAndPushPort(), loc.getId(), loc.getEpoch(), shuffleId, mapId, attemptId, reportedStatus.name(), reviveRequest.urgent); - reviveManager.addRequest(reviveRequest); - } - } + public PartitionLocationList(int shuffleId, int partitionId) { + this.shuffleId = shuffleId; + this.partitionId = partitionId; + } - private boolean newerPartitionLocationExists(int epoch) { - try { - readLock.lock(); - for (PartitionLocation loc : locations) { - if (locationStatusCode.get(loc.getEpoch()) != null && loc.getEpoch() > epoch) { - return true; - } - } - return false; - } finally { - readLock.unlock(); - } + ReadWriteLock lock = new ReentrantReadWriteLock(); + Lock readLock = lock.readLock(); + Lock writeLock = lock.writeLock(); + + private void update(List newLocs) { + if (newLocs.isEmpty()) { + return; + } + newLocs.sort(Comparator.comparing(PartitionLocation::getEpoch)); + int newMaxEpoch = newLocs.get(newLocs.size() - 1).getEpoch(); + try { + writeLock.lock(); + if (newMaxEpoch <= maxEpoch) { + return; } - - private boolean locationExists(int epoch) { - try { - readLock.lock(); - return locationSet.contains(epoch); - } finally { - readLock.unlock(); - } + int newSize = newLocs.size() + size - used; + ArrayList newLocations = new ArrayList<>(newSize); + for (PartitionLocation oldLoc : locations) { + if (locationStatusCode.remove(oldLoc.getEpoch()) != null) { + used--; + } else { + newLocations.add(oldLoc); + } } - - public boolean hasActiveReviveRequest() { - try { - readLock.lock(); - return latestReviveRequest != null && latestReviveRequest.reviveStatus == StatusCode.REVIVE_INITIALIZED.getValue(); - } finally { - readLock.unlock(); - } + for (PartitionLocation l : newLocs) { + if (l.getEpoch() >= maxEpoch) { + newLocations.add(l); + } } - - public StatusCode getLatestReviveStatus() { - try { - readLock.lock(); - if (latestReviveRequest == null) { - return StatusCode.REVIVE_INITIALIZED; - } else { - return StatusCode.fromValue(latestReviveRequest.reviveStatus); - } - } finally { - readLock.unlock(); - } + size = newLocations.size(); + index = new int[size]; + for (int i = 0; i < size; i++) { + index[i] = i; } + locations = newLocations; + locationSet.clear(); + locations.forEach(l -> locationSet.add(l.getEpoch())); + maxEpoch = Math.max(maxEpoch, newMaxEpoch); + logger.info( + "Location updated for shuffleId {}, partitionId {}, new locations: {}, maxEpoch: {}", + shuffleId, + partitionId, + locationSet, + maxEpoch); + if (latestReviveRequest != null && latestReviveRequest.clientMaxEpoch < maxEpoch) { + logger.debug("outdated latestReviveRequest {}", latestReviveRequest); + this.latestReviveRequest = null; + } + } finally { + writeLock.unlock(); + } } - final Map> reducePartitionMap = - JavaUtils.newConcurrentHashMap(); - - final ReviveManager reviveManager; - - final ShuffleClientImpl shuffleClient; - - public LocationManager(ShuffleClientImpl shuffleClient, ReviveManager reviveManager) { - this.shuffleClient = shuffleClient; - this.reviveManager = reviveManager; - } - - public void registerShuffleLocs(int shuffleId, ConcurrentHashMap> map) { - reducePartitionMap.computeIfAbsent(shuffleId, (id) -> { - ConcurrentHashMap locationMap = JavaUtils.newConcurrentHashMap(); - for (Map.Entry> e : map.entrySet()) { - int partitionId = e.getKey(); - List locs = e.getValue(); - PartitionLocationList list = new PartitionLocationList(shuffleId, partitionId); - 
list.update(locs); - locationMap.put(partitionId, list); - logger.debug("in registerShuffleLocs, shuffleId {}, partitionId {}", id, partitionId); - } - return locationMap; - }); + // return partitionLocation for specified mapId + // if allowSoftSplit = true, soft split location can be returned + // if liveOnly = false, non-living (soft split/hard split/push fail) location can be returned + private PartitionLocation nextLoc(int mapId, boolean allowSoftSplit, boolean liveOnly) { + try { + readLock.lock(); + int pos = mapId % size; + int idx = index[pos]; + while (locationStatusCode.get(locations.get(idx).getEpoch()) != null) { + if (allowSoftSplit + && locationStatusCode.get(locations.get(idx).getEpoch()) == StatusCode.SOFT_SPLIT) { + break; + } + idx = (idx + 1) % size; + // all locations are checked + if (idx == index[pos]) { + break; + } + } + if (idx != index[pos]) { + index[pos] = idx; + } + if (locationStatusCode.get(locations.get(idx).getEpoch()) != null) { + if ((allowSoftSplit + && locationStatusCode.get(locations.get(idx).getEpoch()) == StatusCode.SOFT_SPLIT) + || !liveOnly) { + return locations.get(idx); + } + return null; + } else { + return locations.get(idx); + } + } finally { + readLock.unlock(); + } } - public boolean registered(int shuffleId) { - return reducePartitionMap.containsKey(shuffleId); + private void reviveBatch(int shuffleId, int partitionId, int mapId, int attemptId) { + ReviveRequest reviveRequest = null; + try { + writeLock.lock(); + reviveRequest = + new ReviveRequest( + shuffleId, + mapId, + attemptId, + partitionId, + null, + StatusCode.URGENT_REVIVE, + maxEpoch, + true); + this.latestReviveRequest = reviveRequest; + logger.debug("in reviveBatch latestReviveRequest = {}", reviveRequest); + } finally { + writeLock.unlock(); + } + reviveManager.addRequest(reviveRequest); } - public boolean exists(int shuffleId, int partitionId) { - if (!registered(shuffleId)) { - throw new UnsupportedOperationException("unexpected! 
must ensure shuffle registered before checking partition exists "); + private void reportUnusableLocation( + int shuffleId, int mapId, int attemptId, PartitionLocation loc, StatusCode reportedStatus) { + ReviveRequest reviveRequest = null; + try { + writeLock.lock(); + if (!locationSet.contains(loc.getEpoch())) { + return; } - return reducePartitionMap.get(shuffleId).containsKey(partitionId); - } - - public StatusCode getReviveStatus(int shuffleId, int partitionId) { - PartitionLocationList locationList = reducePartitionMap.get(shuffleId).get(partitionId); - return locationList.getLatestReviveStatus(); - } - public PartitionLocation getLocationOrReviveAsync(int shuffleId, int partitionId, int mapId, int attemptId, boolean doRevive, boolean liveOnly) { - PartitionLocationList locationList = reducePartitionMap.get(shuffleId).get(partitionId); - // firstly, try to find a live partition location - PartitionLocation loc = locationList.nextLoc(mapId, false, true); - if (loc == null) { - if (doRevive && !locationList.hasActiveReviveRequest()) { - locationList.reviveBatch(shuffleId, partitionId, mapId, attemptId); - } else if (doRevive && locationList.hasActiveReviveRequest()) { - logger.debug("in getLocationOrReviveAsync, do nothing, current latestReviveRequest is {}", locationList.latestReviveRequest); - } - // can't get a live partition location, then try to find a location in soft split status - // if liveOnly = false, hard split/push fail location can be returned - loc = locationList.nextLoc(mapId, true, liveOnly); + StatusCode currentStatus = locationStatusCode.get(loc.getEpoch()); + if (currentStatus != reportedStatus) { + // allow normal/soft split to transition to hard split/push failure + if (currentStatus == null || currentStatus == StatusCode.SOFT_SPLIT) { + locationStatusCode.put(loc.getEpoch(), reportedStatus); + } + if (currentStatus == null) { + used++; + } + boolean urgent = ((used == size) && !hasActiveReviveRequest()); + reviveRequest = + new ReviveRequest( + shuffleId, mapId, attemptId, loc.getId(), loc, reportedStatus, maxEpoch, urgent); + if (urgent) { + this.latestReviveRequest = reviveRequest; + logger.debug("in reportUnusableLocation latestReviveRequest = {}", reviveRequest); + } } - return loc; + + } finally { + writeLock.unlock(); + } + if (reviveRequest != null) { + logger.info( + "Reported worker {}, partitionId {}, epoch {}, shuffle {} map {} attempt {} is unusable, status: {}, urgent: {}", + loc.hostAndPushPort(), + loc.getId(), + loc.getEpoch(), + shuffleId, + mapId, + attemptId, + reportedStatus.name(), + reviveRequest.urgent); + reviveManager.addRequest(reviveRequest); + } } - public boolean reviveSync(int shuffleId, int partitionId, int mapId, int attemptId, StatusCode cause) { - Set mapIds = new HashSet<>(); - mapIds.add(mapId); - List requests = new ArrayList<>(); - ReviveRequest request = new ReviveRequest(shuffleId, mapId, attemptId, partitionId, null, cause, 0, true); - requests.add(request); - Map results = shuffleClient.reviveBatch(shuffleId, mapIds, requests, true); - - if (shuffleClient.mapperEnded(shuffleId, mapId)) { - logger.debug( - "Revive success, but the mapper ended for shuffle {} map {} attempt {} partition {}, just return true(Assume revive successfully).", - shuffleId, - mapId, - attemptId, - partitionId); + private boolean newerPartitionLocationExists(int epoch) { + try { + readLock.lock(); + for (PartitionLocation loc : locations) { + if (locationStatusCode.get(loc.getEpoch()) != null && loc.getEpoch() > epoch) { return true; - } else { - 
return results != null - && results.containsKey(partitionId) - && results.get(partitionId) == StatusCode.SUCCESS.getValue(); + } } + return false; + } finally { + readLock.unlock(); + } } - public void updateLocation(int shuffleId, int partitionId, List newLocations) { - PartitionLocationList locationList = reducePartitionMap.get(shuffleId).get(partitionId); - locationList.update(newLocations); + private boolean locationExists(int epoch) { + try { + readLock.lock(); + return locationSet.contains(epoch); + } finally { + readLock.unlock(); + } } - public void reportUnusableLocation(int shuffleId, int mapId, int attemptId, PartitionLocation reportedPartition, StatusCode reportedStatus) { - int partitionId = reportedPartition.getId(); - PartitionLocationList locationList = reducePartitionMap.get(shuffleId).get(partitionId); - locationList.reportUnusableLocation(shuffleId, mapId, attemptId, reportedPartition, reportedStatus); + public boolean hasActiveReviveRequest() { + try { + readLock.lock(); + return latestReviveRequest != null + && latestReviveRequest.reviveStatus == StatusCode.REVIVE_INITIALIZED.getValue(); + } finally { + readLock.unlock(); + } } - public boolean newerPartitionLocationExists(int shuffleId, int partitionId, int epoch) { - PartitionLocationList locationList = reducePartitionMap.get(shuffleId).get(partitionId); - return locationList.newerPartitionLocationExists(epoch); + public StatusCode getLatestReviveStatus() { + try { + readLock.lock(); + if (latestReviveRequest == null) { + return StatusCode.REVIVE_INITIALIZED; + } else { + return StatusCode.fromValue(latestReviveRequest.reviveStatus); + } + } finally { + readLock.unlock(); + } } + } + + final Map> reducePartitionMap = + JavaUtils.newConcurrentHashMap(); + + final ReviveManager reviveManager; + + final ShuffleClientImpl shuffleClient; + + public LocationManager(ShuffleClientImpl shuffleClient, ReviveManager reviveManager) { + this.shuffleClient = shuffleClient; + this.reviveManager = reviveManager; + } + + public void registerShuffleLocs( + int shuffleId, ConcurrentHashMap> map) { + reducePartitionMap.computeIfAbsent( + shuffleId, + (id) -> { + ConcurrentHashMap locationMap = + JavaUtils.newConcurrentHashMap(); + for (Map.Entry> e : map.entrySet()) { + int partitionId = e.getKey(); + List locs = e.getValue(); + PartitionLocationList list = new PartitionLocationList(shuffleId, partitionId); + list.update(locs); + locationMap.put(partitionId, list); + logger.debug("in registerShuffleLocs, shuffleId {}, partitionId {}", id, partitionId); + } + return locationMap; + }); + } - public boolean locationExists(int shuffleId, int partitionId, int epoch) { - PartitionLocationList locationList = reducePartitionMap.get(shuffleId).get(partitionId); - return locationList.locationExists(epoch); - } + public boolean registered(int shuffleId) { + return reducePartitionMap.containsKey(shuffleId); + } - public boolean hasActiveReviveRequest(int shuffleId, int partitionId) { - PartitionLocationList locationList = reducePartitionMap.get(shuffleId).get(partitionId); - return locationList.hasActiveReviveRequest(); + public boolean exists(int shuffleId, int partitionId) { + if (!registered(shuffleId)) { + throw new UnsupportedOperationException( + "unexpected! 
shuffle must be registered before checking whether a partition exists"); } return reducePartitionMap.get(shuffleId).containsKey(partitionId); } + + public StatusCode getReviveStatus(int shuffleId, int partitionId) { + PartitionLocationList locationList = reducePartitionMap.get(shuffleId).get(partitionId); + return locationList.getLatestReviveStatus(); + } + + public PartitionLocation getLocationOrReviveAsync( + int shuffleId, + int partitionId, + int mapId, + int attemptId, + boolean doRevive, + boolean liveOnly) { + PartitionLocationList locationList = reducePartitionMap.get(shuffleId).get(partitionId); + // first, try to find a live partition location + PartitionLocation loc = locationList.nextLoc(mapId, false, true); + if (loc == null) { + if (doRevive && !locationList.hasActiveReviveRequest()) { + locationList.reviveBatch(shuffleId, partitionId, mapId, attemptId); + } else if (doRevive && locationList.hasActiveReviveRequest()) { + logger.debug( + "in getLocationOrReviveAsync, do nothing, current latestReviveRequest is {}", + locationList.latestReviveRequest); + } + // no live partition location is available, so fall back to one in soft split status + // if liveOnly = false, a hard split/push-failed location can also be returned + loc = locationList.nextLoc(mapId, true, liveOnly); } return loc; + } + + public boolean reviveSync( + int shuffleId, int partitionId, int mapId, int attemptId, StatusCode cause) { + Set mapIds = new HashSet<>(); + mapIds.add(mapId); + List requests = new ArrayList<>(); + ReviveRequest request = + new ReviveRequest(shuffleId, mapId, attemptId, partitionId, null, cause, 0, true); + requests.add(request); + Map results = shuffleClient.reviveBatch(shuffleId, mapIds, requests, true); + + if (shuffleClient.mapperEnded(shuffleId, mapId)) { + logger.debug( + "Revive success, but the mapper ended for shuffle {} map {} attempt {} partition {}, just return true (assume the revive succeeded).", + shuffleId, + mapId, + attemptId, + partitionId); + return true; + } else { + return results != null + && results.containsKey(partitionId) + && results.get(partitionId) == StatusCode.SUCCESS.getValue(); } + } + + public void updateLocation(int shuffleId, int partitionId, List newLocations) { + PartitionLocationList locationList = reducePartitionMap.get(shuffleId).get(partitionId); + locationList.update(newLocations); + } + + public void reportUnusableLocation( + int shuffleId, + int mapId, + int attemptId, + PartitionLocation reportedPartition, + StatusCode reportedStatus) { + int partitionId = reportedPartition.getId(); + PartitionLocationList locationList = reducePartitionMap.get(shuffleId).get(partitionId); + locationList.reportUnusableLocation( + shuffleId, mapId, attemptId, reportedPartition, reportedStatus); + } + + public boolean newerPartitionLocationExists(int shuffleId, int partitionId, int epoch) { + PartitionLocationList locationList = reducePartitionMap.get(shuffleId).get(partitionId); + return locationList.newerPartitionLocationExists(epoch); + } + + public boolean locationExists(int shuffleId, int partitionId, int epoch) { + PartitionLocationList locationList =
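getLocationOrReviveAsync above resolves a location in three steps: prefer a live location, otherwise fire at most one asynchronous revive, and meanwhile fall back to a soft-split location so pushes keep flowing. A compact model of that decision order (the Slot interface is a hypothetical stand-in for PartitionLocationList, not part of the patch):

class LookupPolicy {
  // simplified stand-ins for the real LocationManager state
  interface Slot {
    String live();
    String softSplit();
    boolean revivePending();
    void revive();
  }

  static String locate(Slot s, boolean doRevive) {
    String loc = s.live();                 // 1. prefer a fully usable location
    if (loc == null) {
      if (doRevive && !s.revivePending()) {
        s.revive();                        // 2. ask for a new epoch at most once
      }
      loc = s.softSplit();                 // 3. meanwhile keep pushing to a soft-split epoch
    }
    return loc;                            // may be null until the revive completes
  }

  public static void main(String[] args) {
    Slot slot = new Slot() {
      public String live() { return null; }
      public String softSplit() { return "epoch-2 (soft split)"; }
      public boolean revivePending() { return true; }
      public void revive() {}
    };
    System.out.println(locate(slot, true)); // epoch-2 (soft split)
  }
}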
reducePartitionMap.get(shuffleId).get(partitionId); + return locationList.locationExists(epoch); + } + + public boolean hasActiveReviveRequest(int shuffleId, int partitionId) { + PartitionLocationList locationList = reducePartitionMap.get(shuffleId).get(partitionId); + return locationList.hasActiveReviveRequest(); + } + + public void removeShuffle(int shuffleId) { + reducePartitionMap.remove(shuffleId); + } + + @VisibleForTesting + public StatusCode getLocationStatus(int shuffleId, int partitionId, int epochId) { + PartitionLocationList locationList = reducePartitionMap.get(shuffleId).get(partitionId); + return locationList.locationStatusCode.getOrDefault(epochId, StatusCode.SUCCESS); + } } diff --git a/client/src/main/java/org/apache/celeborn/client/ReviveManager.java b/client/src/main/java/org/apache/celeborn/client/ReviveManager.java index 888e4d12794..81ae4fabf65 100644 --- a/client/src/main/java/org/apache/celeborn/client/ReviveManager.java +++ b/client/src/main/java/org/apache/celeborn/client/ReviveManager.java @@ -45,11 +45,11 @@ class ReviveManager { private final int batchSize; ShuffleClientImpl shuffleClient; private ScheduledExecutorService batchReviveRequestScheduler = - ThreadUtils.newDaemonSingleThreadScheduledExecutor("batch-revive-scheduler"); + ThreadUtils.newDaemonSingleThreadScheduledExecutor("batch-revive-scheduler"); private ThreadPoolExecutor batchReviveRequestHandler = - ThreadUtils.newDaemonFixedThreadPool(2, "batch-revive-handler"); + ThreadUtils.newDaemonFixedThreadPool(2, "batch-revive-handler"); private ThreadPoolExecutor batchReportRequestHandler = - ThreadUtils.newDaemonFixedThreadPool(2, "batch-report-handler"); + ThreadUtils.newDaemonFixedThreadPool(2, "batch-report-handler"); public ReviveManager(ShuffleClientImpl shuffleClient, CelebornConf conf) { this.shuffleClient = shuffleClient; @@ -57,39 +57,39 @@ public ReviveManager(ShuffleClientImpl shuffleClient, CelebornConf conf) { this.batchSize = conf.clientPushReviveBatchSize(); batchReviveRequestScheduler.scheduleWithFixedDelay( - () -> { - try { - Map> urgentMap = new HashMap<>(); - Map> nonUrgentMap = new HashMap<>(); - do { - ArrayList batchRequests = new ArrayList<>(); - requestQueue.drainTo(batchRequests, batchSize); - for (ReviveRequest req : batchRequests) { - Set set = null; - if (req.urgent) { - set = urgentMap.computeIfAbsent(req.shuffleId, id -> new HashSet<>()); - } else { - set = nonUrgentMap.computeIfAbsent(req.shuffleId, id -> new HashSet<>()); - } - set.add(req); - } - if (!urgentMap.isEmpty()) { - reviveInternal(urgentMap, true); - } - if (!nonUrgentMap.isEmpty()) { - reviveInternal(nonUrgentMap, false); - } - // break the loop if remaining requests is less than half of - // `celeborn.client.push.revive.batchSize` - } while (requestQueue.size() > batchSize / 2); - } catch (Throwable e) { - logger.error("Exception when batchRevive: ", e); - throw e; + () -> { + try { + Map> urgentMap = new HashMap<>(); + Map> nonUrgentMap = new HashMap<>(); + do { + ArrayList batchRequests = new ArrayList<>(); + requestQueue.drainTo(batchRequests, batchSize); + for (ReviveRequest req : batchRequests) { + Set set = null; + if (req.urgent) { + set = urgentMap.computeIfAbsent(req.shuffleId, id -> new HashSet<>()); + } else { + set = nonUrgentMap.computeIfAbsent(req.shuffleId, id -> new HashSet<>()); + } + set.add(req); + } + if (!urgentMap.isEmpty()) { + reviveInternal(urgentMap, true); + } + if (!nonUrgentMap.isEmpty()) { + reviveInternal(nonUrgentMap, false); } - }, - interval, - interval, - 
TimeUnit.MILLISECONDS); + // break the loop if remaining requests is less than half of + // `celeborn.client.push.revive.batchSize` + } while (requestQueue.size() > batchSize / 2); + } catch (Throwable e) { + logger.error("Exception when batchRevive: ", e); + throw e; + } + }, + interval, + interval, + TimeUnit.MILLISECONDS); } public void reviveInternal(Map> shuffleMap, boolean urgent) { @@ -111,14 +111,16 @@ public void processRequests(int shuffleId, Collection requests, b Iterator iter = requests.iterator(); while (iter.hasNext()) { ReviveRequest req = iter.next(); - if ((urgent && shuffleClient.newerPartitionLocationExists(shuffleId, req.partitionId, req.clientMaxEpoch)) - || shuffleClient.mapperEnded(shuffleId, req.mapId)) { + if ((urgent + && shuffleClient.newerPartitionLocationExists( + shuffleId, req.partitionId, req.clientMaxEpoch)) + || shuffleClient.mapperEnded(shuffleId, req.mapId)) { req.reviveStatus = StatusCode.SUCCESS.getValue(); } else { filteredRequests.add(req); mapIds.add(req.mapId); if (!requestsToSend.containsKey(req.partitionId) - || requestsToSend.get(req.partitionId).clientMaxEpoch < req.clientMaxEpoch) { + || requestsToSend.get(req.partitionId).clientMaxEpoch < req.clientMaxEpoch) { requestsToSend.put(req.partitionId, req); } } @@ -126,33 +128,34 @@ public void processRequests(int shuffleId, Collection requests, b ThreadPoolExecutor handler = urgent ? batchReviveRequestHandler : batchReportRequestHandler; if (!requestsToSend.isEmpty()) { - handler.submit(() -> { - try { - // Call reviveBatch. Return null means Exception caught or - // SHUFFLE_NOT_REGISTERED - //Do not use WriterTracerHere because traceInfo is set afterward - long reviveStartTime = System.nanoTime(); - Map results = + handler.submit( + () -> { + try { + // Call reviveBatch. 
A null result means an exception was caught or + // SHUFFLE_NOT_REGISTERED was returned + // Do not use WriterTracer here because traceInfo is set afterward + long reviveStartTime = System.nanoTime(); + Map results = shuffleClient.reviveBatch(shuffleId, mapIds, requestsToSend.values(), urgent); - long reviveCostTime = System.nanoTime() - reviveStartTime; - if (results == null) { - for (ReviveRequest req : filteredRequests) { - req.reviveStatus = StatusCode.REVIVE_FAILED.getValue(); - } - } else { - for (ReviveRequest req : filteredRequests) { - if (shuffleClient.mapperEnded(shuffleId, req.mapId)) { - req.reviveStatus = StatusCode.SUCCESS.getValue(); + long reviveCostTime = System.nanoTime() - reviveStartTime; + if (results == null) { + for (ReviveRequest req : filteredRequests) { + req.reviveStatus = StatusCode.REVIVE_FAILED.getValue(); + } } else { - req.reviveStatus = results.get(req.partitionId); + for (ReviveRequest req : filteredRequests) { + if (shuffleClient.mapperEnded(shuffleId, req.mapId)) { + req.reviveStatus = StatusCode.SUCCESS.getValue(); + } else { + req.reviveStatus = results.get(req.partitionId); + } + } } + } catch (Throwable e) { + logger.error("Exception when processRequests: ", e); + throw e; } - } - } catch (Throwable e) { - logger.error("Exception when processRequests: ", e); - throw e; - } - }); + }); } } diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index 4d2bfac125d..f1041f7422e 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -301,24 +301,33 @@ private void submitRetryPushData( int remainReviveTimes, long dueTime) { long reviveWaitTime = dueTime - System.currentTimeMillis(); - long resubmitWaitTime = conf.clientRpcRequestPartitionLocationAskTimeout() .duration() .toMillis() - reviveWaitTime; + long resubmitWaitTime = + conf.clientRpcRequestPartitionLocationAskTimeout().duration().toMillis() - reviveWaitTime; final long delta = 50; long accumulatedTime = 0; - PartitionLocation loc = locationManager.getLocationOrReviveAsync(shuffleId, partitionId, mapId, attemptId, true, true); + PartitionLocation loc = + locationManager.getLocationOrReviveAsync( + shuffleId, partitionId, mapId, attemptId, true, true); while (loc == null && accumulatedTime <= reviveWaitTime) { try { Thread.sleep(delta); accumulatedTime += delta; - boolean hasActiveReviveRequest = locationManager.hasActiveReviveRequest(shuffleId, partitionId); - loc = locationManager.getLocationOrReviveAsync(shuffleId, partitionId, mapId, attemptId, false, true); + boolean hasActiveReviveRequest = + locationManager.hasActiveReviveRequest(shuffleId, partitionId); + loc = + locationManager.getLocationOrReviveAsync( + shuffleId, partitionId, mapId, attemptId, false, true); if (!hasActiveReviveRequest) { if (loc == null) { - logger.warn("There is no active revive request, however, a new location has not yet been assigned for" + - " shuffle {} map {} attempt {} partition {} batch {} ", - shuffleId, mapId, attemptId, partitionId, batchId); + logger.warn( + "There is no active revive request, however, a new location has not yet been assigned for" + + " shuffle {} map {} attempt {} partition {} batch {} ", + shuffleId, + mapId, + attemptId, + partitionId, + batchId); } break; } @@ -340,10 +349,12 @@ private void submitRetryPushData( } else if (loc == null) { StatusCode reviveStatus =
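The retry path above polls for the revived location in fixed 50 ms steps until the revive deadline passes. The same wait pattern, reduced to a generic standalone sketch (DeadlinePoller and the supplier are illustrative, not part of the patch):

import java.util.function.Supplier;

// Poll a supplier every `deltaMs` until it yields a value or `waitTimeMs` elapses.
class DeadlinePoller {
  static <T> T poll(Supplier<T> supplier, long waitTimeMs, long deltaMs)
      throws InterruptedException {
    long accumulated = 0;
    T value = supplier.get();
    while (value == null && accumulated <= waitTimeMs) {
      Thread.sleep(deltaMs);
      accumulated += deltaMs;
      value = supplier.get();
    }
    return value; // may still be null if the deadline passed first
  }

  public static void main(String[] args) throws InterruptedException {
    long start = System.currentTimeMillis();
    // becomes non-null after ~200 ms, mimicking a revive that eventually assigns a location
    Supplier<String> locate =
        () -> System.currentTimeMillis() - start > 200 ? "loc-epoch-3" : null;
    System.out.println(poll(locate, 1000, 50));
  }
}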
locationManager.getReviveStatus(shuffleId, partitionId); pushDataRpcResponseCallback.onFailure( - new CelebornIOException( - cause + " then revive but failed, revive status " - + reviveStatus - + ", old location: " + oldLoc)); + new CelebornIOException( + cause + + " then revive but failed, revive status " + + reviveStatus + + ", old location: " + + oldLoc)); } else { logger.info( "Revive for push data success, new location for shuffle {} map {} attempt {} partition {} batch {} is location {}.", @@ -393,20 +404,29 @@ private void submitRetryPushData( } private void submitRetryPushMergedData( - PushState pushState, - int shuffleId, - int mapId, - int attemptId, - ArrayList batches, - StatusCode cause, - Integer oldGroupedBatchId, - int remainReviveTimes, - long reviveResponseDueTime) { + PushState pushState, + int shuffleId, + int mapId, + int attemptId, + ArrayList batches, + StatusCode cause, + Integer oldGroupedBatchId, + int remainReviveTimes, + long reviveResponseDueTime) { List causeList = new ArrayList<>(); for (int i = 0; i < batches.size(); i++) { causeList.add(cause); } - submitRetryPushMergedData(pushState, shuffleId, mapId, attemptId, batches, causeList, oldGroupedBatchId, remainReviveTimes, reviveResponseDueTime); + submitRetryPushMergedData( + pushState, + shuffleId, + mapId, + attemptId, + batches, + causeList, + oldGroupedBatchId, + remainReviveTimes, + reviveResponseDueTime); } private void submitRetryPushMergedData( @@ -420,9 +440,8 @@ private void submitRetryPushMergedData( int remainReviveTimes, long reviveResponseDueTime) { long reviveWaitTime = reviveResponseDueTime - System.currentTimeMillis(); - long resubmitWaitTime = conf.clientRpcRequestPartitionLocationAskTimeout() - .duration() - .toMillis() - reviveWaitTime; + long resubmitWaitTime = + conf.clientRpcRequestPartitionLocationAskTimeout().duration().toMillis() - reviveWaitTime; HashMap, DataBatches> newDataBatchesMap = new HashMap<>(); ArrayList reviveFailedBatchesMap = new ArrayList<>(); @@ -435,21 +454,19 @@ private void submitRetryPushMergedData( StatusCode cause = causeList.get(i); oldAddressPair = batch.loc.hostAndPushPort(); dataBatchReviveInfos.append( - String.format( - "(batchId=%d, partitionId=%d, epochId=%d, cause=%s)", - batch.batchId, - batch.loc.getId(), - batch.loc.getEpoch(), - cause)); + String.format( + "(batchId=%d, partitionId=%d, epochId=%d, cause=%s)", + batch.batchId, batch.loc.getId(), batch.loc.getEpoch(), cause)); } - logger.error("Push merged data to {} failed for shuffle {} map {} attempt {} groupedBatch {}, split batches {}, remain revive times {}.", - oldAddressPair, - shuffleId, - mapId, - attemptId, - oldGroupedBatchId, - dataBatchReviveInfos, - remainReviveTimes); + logger.error( + "Push merged data to {} failed for shuffle {} map {} attempt {} groupedBatch {}, split batches {}, remain revive times {}.", + oldAddressPair, + shuffleId, + mapId, + attemptId, + oldGroupedBatchId, + dataBatchReviveInfos, + remainReviveTimes); final long delta = 50; long accumulatedTime = 0; @@ -459,50 +476,58 @@ private void submitRetryPushMergedData( while (index < batches.size() && accumulatedTime <= reviveWaitTime) { DataBatches.DataBatch batch = batches.get(index); int partitionId = batch.loc.getId(); - PartitionLocation loc = locationManager.getLocationOrReviveAsync(shuffleId, partitionId, mapId, attemptId, doRevive, true); + PartitionLocation loc = + locationManager.getLocationOrReviveAsync( + shuffleId, partitionId, mapId, attemptId, doRevive, true); doRevive = false; if 
(!locationManager.hasActiveReviveRequest(shuffleId, partitionId)) { if (mapperEnded(shuffleId, mapId)) { logger.debug( - "Revive for push merged data success, but the mapper already ended for shuffle {} map {} attempt {} partition {} batch {}.", - shuffleId, - mapId, - attemptId, - partitionId, - oldGroupedBatchId); + "Revive for push merged data success, but the mapper already ended for shuffle {} map {} attempt {} partition {} batch {}.", + shuffleId, + mapId, + attemptId, + partitionId, + oldGroupedBatchId); } else if (loc != null) { logger.info( - "Revive for push merged data success, new location for shuffle {} map {} attempt {} partition {} groupedBatch {} batch {} is location {}.", - shuffleId, - mapId, - attemptId, - partitionId, - oldGroupedBatchId, - batch.batchId, - loc); - DataBatches newDataBatches = newDataBatchesMap.computeIfAbsent(genAddressPair(loc), (s) -> new DataBatches()); + "Revive for push merged data success, new location for shuffle {} map {} attempt {} partition {} groupedBatch {} batch {} is location {}.", + shuffleId, + mapId, + attemptId, + partitionId, + oldGroupedBatchId, + batch.batchId, + loc); + DataBatches newDataBatches = + newDataBatchesMap.computeIfAbsent(genAddressPair(loc), (s) -> new DataBatches()); newDataBatches.addDataBatch(loc, batch.batchId, batch.body); } else { - logger.warn("There is no active revive request, however, a new location has not yet been assigned for" + - " shuffle {} map {} attempt {} partition {} groupedBatch {}", - shuffleId, mapId, attemptId, partitionId, oldGroupedBatchId); + logger.warn( + "There is no active revive request, however, a new location has not yet been assigned for" + + " shuffle {} map {} attempt {} partition {} groupedBatch {}", + shuffleId, + mapId, + attemptId, + partitionId, + oldGroupedBatchId); if (remainReviveTimes > 0) { reviveFailedBatchesMap.add(batch); reviveFailedCauseList.add(causeList.get(index)); } else { String errorMsg = - String.format( - "Revive failed while pushing merged for shuffle %d map %d attempt %d partition %d batch %d location %s.", - shuffleId, mapId, attemptId, partitionId, oldGroupedBatchId, batch.loc); + String.format( + "Revive failed while pushing merged for shuffle %d map %d attempt %d partition %d batch %d location %s.", + shuffleId, mapId, attemptId, partitionId, oldGroupedBatchId, batch.loc); StatusCode reviveStatus = locationManager.getReviveStatus(shuffleId, partitionId); pushState.exception.compareAndSet( - null, + null, + new CelebornIOException( + errorMsg, new CelebornIOException( - errorMsg, - new CelebornIOException( - causeList.get(index) - + " then revive but failed, revive status " - + reviveStatus))); + causeList.get(index) + + " then revive but failed, revive status " + + reviveStatus))); return; } } @@ -527,18 +552,18 @@ private void submitRetryPushMergedData( } else { int partitionId = batch.loc.getId(); String errorMsg = - String.format( - "Revive failed while pushing merged for shuffle %d map %d attempt %d partition %d batch %d location %s.", - shuffleId, mapId, attemptId, partitionId, oldGroupedBatchId, batch.loc); + String.format( + "Revive failed while pushing merged for shuffle %d map %d attempt %d partition %d batch %d location %s.", + shuffleId, mapId, attemptId, partitionId, oldGroupedBatchId, batch.loc); StatusCode reviveStatus = locationManager.getReviveStatus(shuffleId, partitionId); pushState.exception.compareAndSet( - null, + null, + new CelebornIOException( + errorMsg, new CelebornIOException( - errorMsg, - new CelebornIOException( - 
causeList.get(index) - + " then revive but failed, revive status " - + reviveStatus))); + causeList.get(index) + + " then revive but failed, revive status " + + reviveStatus))); return; } } @@ -644,23 +669,23 @@ public ConcurrentHashMap getPartitionLocation( int shuffleId, int numMappers, int numPartitions) throws CelebornIOException { // TODO only UT related usages, fix later return null; -// try { -// return reducePartitionMap.computeIfAbsent( -// shuffleId, -// (id) -> { -// try { -// return registerShuffle(shuffleId, numMappers, numPartitions); -// } catch (CelebornIOException e) { -// throw new RuntimeException(e); -// } -// }); -// } catch (RuntimeException e) { -// if (e.getCause() instanceof CelebornIOException) { -// throw (CelebornIOException) e.getCause(); -// } else { -// throw e; -// } -// } + // try { + // return reducePartitionMap.computeIfAbsent( + // shuffleId, + // (id) -> { + // try { + // return registerShuffle(shuffleId, numMappers, numPartitions); + // } catch (CelebornIOException e) { + // throw new RuntimeException(e); + // } + // }); + // } catch (RuntimeException e) { + // if (e.getCause() instanceof CelebornIOException) { + // throw (CelebornIOException) e.getCause(); + // } else { + // throw e; + // } + // } } @Override @@ -668,10 +693,10 @@ public boolean ensureRegistered(int shuffleId, int numMappers, int numPartitions if (!locationManager.registered(shuffleId)) { try { ConcurrentHashMap> map = - registerShuffle(shuffleId, numMappers, numPartitions); + registerShuffle(shuffleId, numMappers, numPartitions); locationManager.registerShuffleLocs(shuffleId, map); } catch (CelebornIOException e) { - //TODO log exception + // TODO log exception return false; } } @@ -750,17 +775,19 @@ private ConcurrentHashMap> registerShuffleInter PbRegisterShuffleResponse response = callable.call(); StatusCode respStatus = StatusCode.fromValue(response.getStatus()); if (StatusCode.SUCCESS.equals(respStatus)) { - ConcurrentHashMap> result = JavaUtils.newConcurrentHashMap(); + ConcurrentHashMap> result = + JavaUtils.newConcurrentHashMap(); for (int i = 0; i < response.getPartitionLocationsList().size(); i++) { Tuple2, List> locations = - PbSerDeUtils.fromPbPackedPartitionLocationsPair( + PbSerDeUtils.fromPbPackedPartitionLocationsPair( response.getPackedPartitionLocationsPair()); for (PartitionLocation location : locations._1) { pushExcludedWorkers.remove(location.hostAndPushPort()); if (location.hasPeer()) { pushExcludedWorkers.remove(location.getPeer().hostAndPushPort()); } - List list = result.computeIfAbsent(location.getId(), x -> new ArrayList<>()); + List list = + result.computeIfAbsent(location.getId(), x -> new ArrayList<>()); list.add(location); result.put(location.getId(), list); } @@ -841,8 +868,7 @@ protected void limitZeroInFlight(String mapKey, PushState pushState) throws IOEx * @param epoch The epoch of revive. * @return whether newer partition location exists in local cache. 
*/ - boolean newerPartitionLocationExists( - int shuffleId, int partitionId, int epoch) { + boolean newerPartitionLocationExists(int shuffleId, int partitionId, int epoch) { return locationManager.newerPartitionLocationExists(shuffleId, partitionId, epoch); } @@ -880,12 +906,14 @@ Map reviveBatch( try { PbChangeLocationResponse response; if (urgent) { - response = lifecycleManagerRef.askSync( + response = + lifecycleManagerRef.askSync( Revive$.MODULE$.apply(shuffleId, mapIds, requests), conf.clientRpcRequestPartitionLocationAskTimeout(), ClassTag$.MODULE$.apply(PbChangeLocationResponse.class)); } else { - response = lifecycleManagerRef.askSync( + response = + lifecycleManagerRef.askSync( PartitionSplitReport$.MODULE$.apply(shuffleId, mapIds, requests), conf.clientRpcRequestPartitionLocationAskTimeout(), ClassTag$.MODULE$.apply(PbChangeLocationResponse.class)); @@ -1004,13 +1032,13 @@ public int pushOrMergeData( // register shuffle will return an empty location map, client need revive for a new location. if (!locationManager.exists(shuffleId, partitionId)) { if (!locationManager.reviveSync( - shuffleId, - partitionId, - mapId, - attemptId, - StatusCode.PUSH_DATA_FAIL_NON_CRITICAL_CAUSE_PRIMARY)) { + shuffleId, + partitionId, + mapId, + attemptId, + StatusCode.PUSH_DATA_FAIL_NON_CRITICAL_CAUSE_PRIMARY)) { throw new CelebornIOException( - String.format("Revive for shuffle %s partition %d failed.", shuffleId, partitionId)); + String.format("Revive for shuffle %s partition %d failed.", shuffleId, partitionId)); } } @@ -1028,7 +1056,9 @@ public int pushOrMergeData( return 0; } - PartitionLocation loc = locationManager.getLocationOrReviveAsync(shuffleId, partitionId, mapId, attemptId, false, false); + PartitionLocation loc = + locationManager.getLocationOrReviveAsync( + shuffleId, partitionId, mapId, attemptId, false, false); if (loc == null) { throw new CelebornIOException( String.format( @@ -1124,7 +1154,8 @@ public void onSuccess(ByteBuffer response) { attemptId, partitionId, nextBatchId); - locationManager.reportUnusableLocation(shuffleId, mapId, attemptId, latest, StatusCode.SOFT_SPLIT); + locationManager.reportUnusableLocation( + shuffleId, mapId, attemptId, latest, StatusCode.SOFT_SPLIT); pushState.onSuccess(latest.hostAndPushPort()); pushState.removeBatch(nextBatchId, latest.hostAndPushPort()); callback.onSuccess(response); @@ -1142,7 +1173,8 @@ public void onSuccess(ByteBuffer response) { pushState.addFailedBatch( latest.getUniqueId(), new PushFailedBatch(mapId, attemptId, nextBatchId)); } - locationManager.reportUnusableLocation(shuffleId, mapId, attemptId, latest, StatusCode.HARD_SPLIT); + locationManager.reportUnusableLocation( + shuffleId, mapId, attemptId, latest, StatusCode.HARD_SPLIT); long dueTime = System.currentTimeMillis() + conf.clientRpcRequestPartitionLocationAskTimeout() @@ -1530,9 +1562,11 @@ public void onSuccess(ByteBuffer response) { partitionIds[partitionIndex], StatusCode.fromValue(statusCodeList.get(i).byteValue()))); if (statusCodeList.get(i) == StatusCode.SOFT_SPLIT.getValue()) { - locationManager.reportUnusableLocation(shuffleId, mapId, attemptId, currentBatch.loc, StatusCode.SOFT_SPLIT); + locationManager.reportUnusableLocation( + shuffleId, mapId, attemptId, currentBatch.loc, StatusCode.SOFT_SPLIT); } else { - locationManager.reportUnusableLocation(shuffleId, mapId, attemptId, currentBatch.loc, StatusCode.HARD_SPLIT); + locationManager.reportUnusableLocation( + shuffleId, mapId, attemptId, currentBatch.loc, StatusCode.HARD_SPLIT); 
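The two branches above encode the split policy: a soft split still counts as a delivered push, so the location is only reported as unusable for future batches, while a hard split also requires the batch to be resubmitted to a new epoch. That policy, condensed into a standalone sketch (types simplified, not the patch's code):

import java.util.ArrayList;
import java.util.List;

class SplitPolicy {
  enum Status { SOFT_SPLIT, HARD_SPLIT }

  // Returns the batches that must be pushed again; soft-split batches are already delivered.
  static List<Integer> handle(List<Integer> batchIds, List<Status> statuses) {
    List<Integer> resubmit = new ArrayList<>();
    for (int i = 0; i < batchIds.size(); i++) {
      // both cases mark the location unusable for future pushes (reportUnusableLocation)
      if (statuses.get(i) == Status.HARD_SPLIT) {
        resubmit.add(batchIds.get(i)); // data did not land; push to the new epoch later
      }
    }
    return resubmit;
  }

  public static void main(String[] args) {
    System.out.println(handle(List.of(1, 2), List.of(Status.SOFT_SPLIT, Status.HARD_SPLIT))); // [2]
  }
}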
batchesNeedResubmit.add(batches.get(partitionIndex)); causeList.add(StatusCode.HARD_SPLIT); } @@ -1553,7 +1587,8 @@ public void onSuccess(ByteBuffer response) { batchesNeedResubmit = batches; for (int i = 0; i < numBatches; i++) { causeList.add(StatusCode.HARD_SPLIT); - locationManager.reportUnusableLocation(shuffleId, mapId, attemptId, batches.get(i).loc, StatusCode.HARD_SPLIT); + locationManager.reportUnusableLocation( + shuffleId, mapId, attemptId, batches.get(i).loc, StatusCode.HARD_SPLIT); } logger.info( "Push merged data to {} hard split required for shuffle {} map {} attempt {} partition {} groupedBatch {} batch {}.", @@ -1663,7 +1698,8 @@ public void onFailure(Throwable e) { e); if (!mapperEnded(shuffleId, mapId)) { for (DataBatches.DataBatch batch : batches) { - locationManager.reportUnusableLocation(shuffleId, mapId, attemptId, batch.loc, cause); + locationManager.reportUnusableLocation( + shuffleId, mapId, attemptId, batch.loc, cause); } pushDataRetryPool.submit( () -> diff --git a/client/src/main/java/org/apache/celeborn/client/write/DataPushQueue.java b/client/src/main/java/org/apache/celeborn/client/write/DataPushQueue.java index 7df04a9158d..461aba24032 100644 --- a/client/src/main/java/org/apache/celeborn/client/write/DataPushQueue.java +++ b/client/src/main/java/org/apache/celeborn/client/write/DataPushQueue.java @@ -23,10 +23,10 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; -import org.apache.celeborn.client.LocationManager; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.celeborn.client.LocationManager; import org.apache.celeborn.client.ShuffleClient; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.protocol.PartitionLocation; @@ -103,7 +103,9 @@ public ArrayList takePushTasks() throws IOException, InterruptedExcept while (iterator.hasNext()) { PushTask task = iterator.next(); int partitionId = task.getPartitionId(); - PartitionLocation loc = locationManager.getLocationOrReviveAsync(shuffleId, partitionId, mapId, attemptId, true, false); + PartitionLocation loc = + locationManager.getLocationOrReviveAsync( + shuffleId, partitionId, mapId, attemptId, true, false); // According to CELEBORN-560, call rerun task and speculative task after LifecycleManager // handle StageEnd will return empty PartitionLocation map, here loc can be null if (loc != null) { diff --git a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala index 5b8e4ce6602..3ebdfa8df5b 100644 --- a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala @@ -36,7 +36,8 @@ case class ChangePartitionRequest( context: RequestLocationCallContext, shuffleId: Int, partitionId: Int, - epoch: Int, + clientMaxEpoch: Int, + targetEpoch: Int, oldPartition: PartitionLocation, causes: Option[StatusCode]) @@ -54,9 +55,12 @@ class ChangePartitionManager( private val locks = JavaUtils.newConcurrentHashMap[Int, Array[AnyRef]]() private val lockBucketSize = conf.batchHandleChangePartitionBuckets - // shuffleId -> set of partition id + // shuffleId -> (partitionId -> maxTargetEpoch) + // maxTargetEpoch must be a java.lang.Integer because entries are removed by + // setting the value to null, which Scala's Int cannot represent private val inBatchPartitions = - JavaUtils.newConcurrentHashMap[Int,
ConcurrentHashMap.KeySetView[Int, java.lang.Boolean]]() + JavaUtils.newConcurrentHashMap[Int, ConcurrentHashMap[Integer, Integer]]() + private val asyncSplitPartitionEnabled = conf.asyncSplitPartitionEnabled private val batchHandleChangePartitionEnabled = conf.batchHandleChangePartitionEnabled private val batchHandleChangePartitionExecutors = ThreadUtils.newDaemonCachedThreadPool( @@ -90,21 +94,23 @@ class ChangePartitionManager( batchHandleChangePartitionExecutors.submit { new Runnable { override def run(): Unit = { - val distinctPartitions = { - val requestSet = inBatchPartitions.get(shuffleId) - val locksForShuffle = locks.computeIfAbsent(shuffleId, locksRegisterFunc) - requests.asScala.map { case (partitionId, request) => + val locksForShuffle = locks.computeIfAbsent(shuffleId, locksRegisterFunc) + // For each partition, only the request with the largest targetEpoch needs to be handled + val distinctPartitions = requests.asScala.collect { + case (partitionId, requestsForPartition) => locksForShuffle(partitionId % locksForShuffle.length).synchronized { - if (!requestSet.contains(partitionId) && requests.containsKey( partitionId)) { - requestSet.add(partitionId) - Some(request.asScala.toArray.maxBy(_.epoch)) - } else { - None + requestsForPartition.asScala.maxBy(_.targetEpoch) match { + case request if request.targetEpoch > inBatchPartitions.get( shuffleId).getOrDefault(partitionId, -1) => + inBatchPartitions.get(shuffleId).put( partitionId, request.targetEpoch) + Some(request) + case _ => None } } - }.filter(_.isDefined).map(_.get).toArray - } + }.flatten.toArray if (distinctPartitions.nonEmpty) { handleRequestPartitions( shuffleId, @@ -143,9 +149,9 @@ } private val inBatchShuffleIdRegisterFunc = - new util.function.Function[Int, ConcurrentHashMap.KeySetView[Int, java.lang.Boolean]]() { - override def apply(s: Int): ConcurrentHashMap.KeySetView[Int, java.lang.Boolean] = - ConcurrentHashMap.newKeySet[Int]() + new util.function.Function[Int, ConcurrentHashMap[Integer, Integer]]() { + override def apply(s: Int): ConcurrentHashMap[Integer, Integer] = + JavaUtils.newConcurrentHashMap() private val locksRegisterFunc = new util.function.Function[Int, Array[AnyRef]] { @@ -154,20 +160,72 @@ } } + def reportAndSplitPartitionIfNeeded( + shuffleId: Int, + partitionId: Int, + oldEpoch: Int, + oldPartition: PartitionLocation, + clientMaxEpoch: Int, + currentMaxEpochId: Int, + cause: Option[StatusCode], + requests: ConcurrentHashMap[Integer, util.Set[ChangePartitionRequest]]): Unit = { + val changed = + lifecycleManager.reportPartitionSplitOrRevived(shuffleId, partitionId, oldEpoch, cause) + val nextReserveSlotCount = + if (changed) { + lifecycleManager.getNextReserveSlotCount(shuffleId, partitionId) + } else { + 0 + } + if (nextReserveSlotCount > 0) { + logInfo( + s"Reserve slot for shuffleId: $shuffleId, partitionId: $partitionId, oldEpoch: $oldEpoch, clientMaxEpoch: $clientMaxEpoch," + + s" next reserve count: $nextReserveSlotCount, currentMaxEpochId: $currentMaxEpochId, targetEpochId: ${currentMaxEpochId + nextReserveSlotCount}") + val preAllocatePartitionRequest = ChangePartitionRequest( + null, + shuffleId, + partitionId, + clientMaxEpoch, + targetEpoch = currentMaxEpochId + nextReserveSlotCount, + oldPartition, + cause) + val locksForShuffle = locks.computeIfAbsent(shuffleId, locksRegisterFunc) + locksForShuffle(partitionId % locksForShuffle.length).synchronized { + if (requests.containsKey(partitionId)) {
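The collect above keeps, for each partition, only the request with the highest targetEpoch and drops it when an equal or newer epoch is already in flight according to inBatchPartitions. The same deduplication rule as a standalone sketch (EpochDedup is hypothetical, written in Java for brevity):

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

class EpochDedup {
  // partitionId -> highest targetEpoch already being handled
  static final Map<Integer, Integer> inBatch = new ConcurrentHashMap<>();

  // Pick the request (represented here by its targetEpoch) worth handling, if any.
  static Optional<Integer> select(int partitionId, List<Integer> targetEpochs) {
    int max = targetEpochs.stream().max(Integer::compare).orElseThrow();
    if (max > inBatch.getOrDefault(partitionId, -1)) {
      inBatch.put(partitionId, max); // mark as in flight so duplicates are dropped
      return Optional.of(max);
    }
    return Optional.empty(); // an equal-or-newer request is already being processed
  }

  public static void main(String[] args) {
    System.out.println(select(7, List.of(3, 5))); // Optional[5]
    System.out.println(select(7, List.of(4)));    // Optional.empty
  }
}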
logDebug(s"[handleRequestPartitionLocation] For shuffle: $shuffleId, request for same " + + s"partition: $partitionId-$oldEpoch exists, register context.") + requests.get(partitionId).add(preAllocatePartitionRequest) + return + } else { + val set = new util.HashSet[ChangePartitionRequest]() + set.add(preAllocatePartitionRequest) + requests.put(partitionId, set) + logInfo(s"[handleRequestPartition][PreAllocate] for shuffleId: $shuffleId, partitionId: $partitionId, " + + s"oldEpoch: $oldEpoch, clientMaxEpoch: $clientMaxEpoch, oldHost: ${if (oldPartition == null) null + else oldPartition.getHost}, cause: $cause") + } + } + } + } + def handleRequestPartitionLocation( context: RequestLocationCallContext, shuffleId: Int, partitionId: Int, oldEpoch: Int, oldPartition: PartitionLocation, + clientMaxEpoch: Int, + reportOnly: Boolean, cause: Option[StatusCode] = None, isSegmentGranularityVisible: Boolean): Unit = { + val currentMaxEpochId = lifecycleManager.getCurrentMaxPartitionEpochId(shuffleId, partitionId) val changePartition = ChangePartitionRequest( context, shuffleId, partitionId, - oldEpoch, + clientMaxEpoch, + targetEpoch = currentMaxEpochId + 1, oldPartition, cause) // check if there exists request for the partition, if do just register @@ -178,28 +236,64 @@ class ChangePartitionManager( shuffleId, oldPartition, cause) - - val locksForShuffle = locks.computeIfAbsent(shuffleId, locksRegisterFunc) - locksForShuffle(partitionId % locksForShuffle.length).synchronized { - if (requests.containsKey(partitionId)) { - logDebug(s"[handleRequestPartitionLocation] For shuffle: $shuffleId, request for same " + - s"partition: $partitionId-$oldEpoch exists, register context.") - requests.get(partitionId).add(changePartition) - return - } else { - getLatestPartition(shuffleId, partitionId, oldEpoch).foreach { latestLoc => - context.reply( - partitionId, - StatusCode.SUCCESS, - Some(latestLoc), - lifecycleManager.workerStatusTracker.workerAvailableByLocation(oldPartition)) - logDebug(s"[handleRequestPartitionLocation]: For shuffle: $shuffleId," + - s" old partition: $partitionId-$oldEpoch, new partition: $latestLoc found, return it") + if (asyncSplitPartitionEnabled && batchHandleChangePartitionEnabled) { + reportAndSplitPartitionIfNeeded( + shuffleId, + partitionId, + oldEpoch, + oldPartition, + clientMaxEpoch, + currentMaxEpochId, + cause, + requests) + } + if (reportOnly) { + // if reportOnly is false, we will exclude failed workers in handleRequestPartitions + if (cause.isDefined) { + lifecycleManager.workerStatusTracker.excludeWorkerFromPartition( + shuffleId, + oldPartition, + cause.get) + } + val latestLocations = getLatestPartition(shuffleId, partitionId, clientMaxEpoch) + if (latestLocations.isDefined) { + logDebug( + s"New partition found, old partition $partitionId-$oldEpoch-$clientMaxEpoch return it." 
+ + s" shuffleId: $shuffleId ${latestLocations.get}") + } + context.reply( + partitionId, + StatusCode.SUCCESS, + latestLocations, + lifecycleManager.workerStatusTracker.workerAvailableByLocation(oldPartition)) + return + } else { + val locksForShuffle = locks.computeIfAbsent(shuffleId, locksRegisterFunc) + locksForShuffle(partitionId % locksForShuffle.length).synchronized { + if (requests.containsKey(partitionId)) { + logDebug(s"[handleRequestPartitionLocation] For shuffle: $shuffleId, request for same " + + s"partition: $partitionId-$oldEpoch exists, register context.") + requests.get(partitionId).add(changePartition) return + } else { + getLatestPartition(shuffleId, partitionId, oldEpoch).foreach { latestLoc => + context.reply( + partitionId, + StatusCode.SUCCESS, + Some(latestLoc), + lifecycleManager.workerStatusTracker.workerAvailableByLocation(oldPartition)) + logDebug(s"[handleRequestPartitionLocation]: For shuffle: $shuffleId," + + s" old partition: $partitionId-$oldEpoch, new partition: $latestLoc found, return it") + return + } + val set = new util.HashSet[ChangePartitionRequest]() + set.add(changePartition) + requests.put(partitionId, set) + logInfo( + s"[handleRequestPartition][ChangePartition] for shuffleId: $shuffleId, partitionId: $partitionId, " + + s"oldEpoch: $oldEpoch, clientMaxEpoch: $clientMaxEpoch, oldHost: ${if (oldPartition == null) null + else oldPartition.getHost}, cause: $cause") } - val set = new util.HashSet[ChangePartitionRequest]() - set.add(changePartition) - requests.put(partitionId, set) } } if (!batchHandleChangePartitionEnabled) { @@ -210,12 +304,12 @@ class ChangePartitionManager( private def getLatestPartition( shuffleId: Int, partitionId: Int, - epoch: Int): Option[PartitionLocation] = { - val map = lifecycleManager.latestPartitionLocation.get(shuffleId) - if (map != null) { - val loc = map.get(partitionId) - if (loc != null && loc.getEpoch > epoch) { - return Some(loc) + clientMaxEpoch: Int): Option[Seq[PartitionLocation]] = { + val map = lifecycleManager.partitionLocationMonitors.getOrDefault(shuffleId, null) + if (map != null && map.getOrDefault(partitionId, null) != null) { + val activeLocations = map.get(partitionId).getActiveLocations(clientMaxEpoch) + if (activeLocations.nonEmpty) { + return Some(activeLocations) } } None @@ -228,9 +322,9 @@ class ChangePartitionManager( val requestsMap = changePartitionRequests.get(shuffleId) val changes = changePartitions.map { change => - s"${change.shuffleId}-${change.partitionId}-${change.epoch}" + s"${change.shuffleId}-${change.partitionId}-${change.clientMaxEpoch}-${change.targetEpoch}" }.mkString("[", ",", "]") - logWarning(s"Batch handle change partition for $changes") + logInfo(s"Batch handle change partition for $changes") // Exclude all failed workers if (changePartitions.exists(_.causes.isDefined) && !testRetryRevive) { @@ -244,43 +338,77 @@ class ChangePartitionManager( // remove together to reduce lock time def replySuccess(locations: Array[PartitionLocation]): Unit = { + val partitionsMap = locations.groupBy(_.getId) val locksForShuffle = locks.computeIfAbsent(shuffleId, locksRegisterFunc) - locations.map { location => - locksForShuffle(location.getId % locksForShuffle.length).synchronized { + partitionsMap.map { case (partitionId, locations) => + locksForShuffle(partitionId % locksForShuffle.length).synchronized { + val largestEpochId = locations.maxBy(_.getEpoch).getEpoch + var stillProcessing = false if (batchHandleChangePartitionEnabled) { - 
inBatchPartitions.get(shuffleId).remove(location.getId) + inBatchPartitions.get(shuffleId).computeIfPresent( + partitionId, + (_, v) => + if (v <= largestEpochId) { + null + } else { + stillProcessing = true + v + }) } // Here one partition id can be remove more than once, // so need to filter null result before reply. - location -> Option(requestsMap.remove(location.getId)) + if (stillProcessing) { + None + } else { + Option(requestsMap.remove(partitionId)) + } } - }.foreach { case (newLocation, requests) => + }.foreach { requests => requests.map(_.asScala.toList.foreach(req => - req.context.reply( - req.partitionId, - StatusCode.SUCCESS, - Option(newLocation), - lifecycleManager.workerStatusTracker.workerAvailableByLocation(req.oldPartition)))) + if (req.context != null) { // only urgent request has context + req.context.reply( + req.partitionId, + StatusCode.SUCCESS, + getLatestPartition(shuffleId, req.partitionId, req.clientMaxEpoch), + lifecycleManager.workerStatusTracker.workerAvailableByLocation(req.oldPartition)) + })) } } // remove together to reduce lock time def replyFailure(status: StatusCode): Unit = { - changePartitions.map { changePartition => - val locksForShuffle = locks.computeIfAbsent(shuffleId, locksRegisterFunc) - locksForShuffle(changePartition.partitionId % locksForShuffle.length).synchronized { - if (batchHandleChangePartitionEnabled) { - inBatchPartitions.get(shuffleId).remove(changePartition.partitionId) + val locksForShuffle = locks.computeIfAbsent(shuffleId, locksRegisterFunc) + changePartitions.groupBy(_.partitionId).map { + case (partitionId, requests) => + locksForShuffle(partitionId % locksForShuffle.length).synchronized { + val largestEpochId = requests.maxBy(_.targetEpoch).targetEpoch + var stillProcessing = false + if (batchHandleChangePartitionEnabled) { + inBatchPartitions.get(shuffleId).computeIfPresent( + partitionId, + (_, v) => + if (v <= largestEpochId) { + null + } else { + stillProcessing = true + v + }) + } + if (stillProcessing) { + None + } else { + Option(requestsMap.remove(partitionId)) + } } - Option(requestsMap.remove(changePartition.partitionId)) - } }.foreach { requests => requests.map(_.asScala.toList.foreach(req => - req.context.reply( - req.partitionId, - status, - None, - lifecycleManager.workerStatusTracker.workerAvailableByLocation(req.oldPartition)))) + if (req.context != null) { // only urgent request has context + req.context.reply( + req.partitionId, + status, + getLatestPartition(shuffleId, req.partitionId, req.clientMaxEpoch), + lifecycleManager.workerStatusTracker.workerAvailableByLocation(req.oldPartition)) + })) } } @@ -370,6 +498,7 @@ class ChangePartitionManager( newlyAllocatedLocations, isSegmentGranularityVisible = isSegmentGranularityVisible)) { logError(s"[Update partition] failed for $shuffleId.") + // TODO: if partial success, maybe we could reply partial success. replyFailure(StatusCode.RESERVE_SLOTS_FAILED) return } @@ -389,14 +518,14 @@ class ChangePartitionManager( }) partitionLocationInfo.addPrimaryPartitions(primaryLocations) partitionLocationInfo.addReplicaPartitions(replicaLocations) - lifecycleManager.updateLatestPartitionLocations(shuffleId, primaryLocations) + lifecycleManager.addNewPartitionLocations(shuffleId, primaryLocations) // partition location can be null when call reserveSlotsWithRetry(). 
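+    // addNewPartitionLocations above also registers each new epoch with its
+    // PartitionLocationMonitor so later revive requests see all active locations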
val locations = (primaryLocations.asScala ++ replicaLocations.asScala.map(_.getPeer)) .distinct.filter(_ != null) if (locations.nonEmpty) { val changes = locations.map { partition => - s"(partition ${partition.getId} epoch from ${partition.getEpoch - 1} to ${partition.getEpoch})" + s"(partition ${partition.getId} epoch changes to ${partition.getEpoch}, new host to ${partition.getHost})" }.mkString("[", ", ", "]") logInfo(s"[Update partition] success for " + s"shuffle $shuffleId, succeed partitions: " + @@ -414,12 +543,21 @@ class ChangePartitionManager( changePartitionRequests: List[ChangePartitionRequest], candidates: List[WorkerInfo]): WorkerResource = { val slots = new WorkerResource() - changePartitionRequests.foreach { partition => - lifecycleManager.allocateFromCandidates( - partition.partitionId, - partition.epoch, - candidates, - slots) + changePartitionRequests.foreach { request => + val epochIds = lifecycleManager.allocateEpochIdsAndUpdateCurrentMaxEpoch( + request.shuffleId, + request.partitionId, + request.targetEpoch) + logInfo(s"allocate for shuffleId: ${request.shuffleId}, " + + s"partitionId ${request.partitionId}, epochIds: ${epochIds.mkString("(", ", ", ")")}") + epochIds.foreach { + epochId => + lifecycleManager.allocateFromCandidates( + request.partitionId, + epochId, + candidates, + slots) + } } slots } diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index 9b2ad93a6fa..de14957b9a6 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -47,7 +47,7 @@ import org.apache.celeborn.common.meta.{ApplicationMeta, ShufflePartitionLocatio import org.apache.celeborn.common.metrics.source.Role import org.apache.celeborn.common.network.protocol.{SerdeVersion, TransportMessagesHelper} import org.apache.celeborn.common.network.sasl.registration.RegistrationInfo -import org.apache.celeborn.common.protocol._ +import org.apache.celeborn.common.protocol.{PbPartitionSplitReport, _} import org.apache.celeborn.common.protocol.RpcNameConstants.WORKER_EP import org.apache.celeborn.common.protocol.message.ControlMessages._ import org.apache.celeborn.common.protocol.message.StatusCode @@ -94,9 +94,9 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends val shuffleFallbackCounts = JavaUtils.newConcurrentHashMap[String, java.lang.Long]() // maintain each shuffle's map relation of WorkerInfo and partition location val shuffleAllocatedWorkers = new ShuffleAllocatedWorkers - // shuffle id -> (partitionId -> newest PartitionLocation) - val latestPartitionLocation = - JavaUtils.newConcurrentHashMap[Int, ConcurrentHashMap[Int, PartitionLocation]]() + // shuffle id -> (partitionId -> PartitionLocationMonitor) + val partitionLocationMonitors = + JavaUtils.newConcurrentHashMap[Int, ConcurrentHashMap[Int, PartitionLocationMonitor]]() private val userIdentifier: UserIdentifier = IdentityProvider.instantiate(conf).provide() private val availableStorageTypes = conf.availableStorageTypes // app shuffle id -> LinkedHashMap of (app shuffle identifier, (shuffle id, fetch status)) @@ -126,6 +126,14 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends private val mockDestroyFailure = conf.testMockDestroySlotsFailure private val authEnabled = conf.authEnabledOnClient private var applicationMeta: ApplicationMeta = _ + + // shuffleId -> 
(partitionId -> maxEpoch)
+  val currentMaxPartitionEpoch =
+    JavaUtils.newConcurrentHashMap[Int, ConcurrentHashMap[Int, Integer]]()
+  // shuffleId -> number of map tasks; the effective max active location count
+  // must not exceed this value
+  private val shuffleId2NumMappers = JavaUtils.newConcurrentHashMap[Int, Int]()
+  val maxActiveLocation = conf.clientMaxActiveLocation
+
  @VisibleForTesting
  def workerSnapshots(shuffleId: Int): util.Map[String, ShufflePartitionLocationInfo] =
    shuffleAllocatedWorkers.get(shuffleId)
@@ -134,21 +142,62 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends

  def getUnregisterShuffleTime(): ConcurrentHashMap[Int, Long] =
    unregisterShuffleTime

-  val newMapFunc: function.Function[Int, ConcurrentHashMap[Int, PartitionLocation]] =
-    new util.function.Function[Int, ConcurrentHashMap[Int, PartitionLocation]]() {
-      override def apply(s: Int): ConcurrentHashMap[Int, PartitionLocation] = {
-        JavaUtils.newConcurrentHashMap[Int, PartitionLocation]()
+  val newMapFunc: function.Function[Int, ConcurrentHashMap[Int, PartitionLocationMonitor]] =
+    new util.function.Function[Int, ConcurrentHashMap[Int, PartitionLocationMonitor]]() {
+      override def apply(s: Int): ConcurrentHashMap[Int, PartitionLocationMonitor] = {
+        JavaUtils.newConcurrentHashMap[Int, PartitionLocationMonitor]()
      }
    }

-  def updateLatestPartitionLocations(
+  def addNewPartitionLocations(
      shuffleId: Int,
      locations: util.List[PartitionLocation]): Unit = {
-    val map = latestPartitionLocation.computeIfAbsent(shuffleId, newMapFunc)
-    locations.asScala.foreach(location => map.put(location.getId, location))
+    val map = partitionLocationMonitors.computeIfAbsent(shuffleId, newMapFunc)
+    locations.asScala.foreach {
+      loc =>
+        map
+          .computeIfAbsent(
+            loc.getId,
+            _ =>
+              new PartitionLocationMonitor(
+                shuffleId,
+                loc.getId,
+                conf,
+                if (maxActiveLocation <= 0) {
+                  shuffleId2NumMappers.get(shuffleId)
+                } else {
+                  maxActiveLocation
+                }))
+          .addActiveLocationEpoch(loc)
+    }
    invalidateLatestMaxLocsCache(shuffleId)
  }

+  def initCurrentMaxPartitionEpoch(
+      shuffleId: Int,
+      locations: util.List[PartitionLocation]): Unit = {
+    val partitionMap = currentMaxPartitionEpoch.computeIfAbsent(
+      shuffleId,
+      _ => JavaUtils.newConcurrentHashMap[Int, Integer])
+    locations.asScala.foreach { loc =>
+      // The mapped value is null on first insert, so guard before taking the max
+      partitionMap.compute(
+        loc.getId,
+        (_, v) => if (v == null) loc.getEpoch else Math.max(v, loc.getEpoch))
+    }
+  }
+
+  def reportPartitionSplitOrRevived(
+      shuffleId: Int,
+      partitionId: Int,
+      oldEpoch: Int,
+      cause: Option[StatusCode] = None): Boolean = {
+    partitionLocationMonitors.get(shuffleId).get(partitionId).receivePartitionSplitOrRevived(
+      oldEpoch,
+      cause)
+  }
+
+  def getNextReserveSlotCount(shuffleId: Int, partitionId: Int): Int = {
+    partitionLocationMonitors.get(shuffleId).get(partitionId).nextReserveSlotCount
+  }

  case class RegisterCallContext(context: RpcCallContext, partitionId: Int = -1) {
    def reply(response: PbRegisterShuffleResponse) = {
      context.reply(response)
@@ -368,26 +417,33 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
      partitionId,
      isSegmentGranularityVisible)

+    case pb: PbPartitionSplitReport =>
+      val shuffleId = pb.getShuffleId
+      val mapIds = pb.getMapIdList
+      val partitionInfos = pb.getPartitionInfoList
+
+      val (partitionIds, epochs, oldPartitions, causes, clientMaxEpochs) =
+        extractPartitionInfo(partitionInfos)
+      logDebug(
+        s"Received PartitionSplitReport request from ${context.senderAddress}, number of partitions ${partitionIds.size()}")
+      handleRevive(
+        context,
+        shuffleId,
+        mapIds,
+        partitionIds,
epochs, + oldPartitions, + causes, + clientMaxEpochs, + true) + case pb: PbRevive => val shuffleId = pb.getShuffleId val mapIds = pb.getMapIdList val partitionInfos = pb.getPartitionInfoList - val partitionIds = new util.ArrayList[Integer]() - val epochs = new util.ArrayList[Integer]() - val oldPartitions = new util.ArrayList[PartitionLocation]() - val causes = new util.ArrayList[StatusCode]() - (0 until partitionInfos.size()).foreach { idx => - val info = partitionInfos.get(idx) - partitionIds.add(info.getPartitionId) - epochs.add(info.getEpoch) - if (info.hasPartition) { - oldPartitions.add(PbSerDeUtils.fromPbPartitionLocation(info.getPartition)) - } else { - oldPartitions.add(null) - } - causes.add(StatusCode.fromValue(info.getStatus)) - } + val (partitionIds, epochs, oldPartitions, causes, clientMaxEpochs) = + extractPartitionInfo(partitionInfos) logDebug(s"Received Revive request, number of partitions ${partitionIds.size()}") handleRevive( context, @@ -396,7 +452,9 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends partitionIds, epochs, oldPartitions, - causes) + causes, + clientMaxEpochs, + false) case pb: PbPartitionSplit => val shuffleId = pb.getShuffleId @@ -411,7 +469,9 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends partitionId, epoch, oldPartition, - isSegmentGranularityVisible = commitManager.isSegmentGranularityVisible(shuffleId)) + isSegmentGranularityVisible = commitManager.isSegmentGranularityVisible(shuffleId), + clientMaxEpoch = epoch, + reportOnly = false) case MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId, pushFailedBatch) => logTrace(s"Received MapperEnd TaskEnd request, " + @@ -625,7 +685,9 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends partitionId, -1, null, - isSegmentGranularityVisible = commitManager.isSegmentGranularityVisible(shuffleId)) + isSegmentGranularityVisible = commitManager.isSegmentGranularityVisible(shuffleId), + clientMaxEpoch = -1, + reportOnly = false) } } @@ -730,7 +792,6 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends shuffleId, candidatesWorkers, slots, - updateEpoch = false, isSegmentGranularityVisible) // If reserve slots failed, clear allocated resources, reply ReserveSlotFailed and return. 
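The revive path above funnels urgent revives and report-only split reports into the same epoch bookkeeping. A minimal, self-contained sketch (hypothetical names, not part of the patch itself) of the contract that allocateEpochIdsAndUpdateCurrentMaxEpoch further down implements: each (shuffleId, partitionId) tracks a monotonically increasing max epoch, and a request targeting targetEpoch is granted exactly the not-yet-covered range, so overlapping requests never allocate the same epoch twice:

    object EpochAllocationSketch {
      // (shuffleId, partitionId) -> current max allocated epoch
      private val maxEpoch = scala.collection.mutable.Map.empty[(Int, Int), Int]

      // Grant the uncovered range (currentMax, targetEpoch]; empty if already covered.
      def allocate(shuffleId: Int, partitionId: Int, targetEpoch: Int): Seq[Int] =
        maxEpoch.synchronized {
          val key = (shuffleId, partitionId)
          val current = maxEpoch.getOrElse(key, -1)
          if (targetEpoch <= current) {
            Seq.empty // another request already covered this range
          } else {
            maxEpoch(key) = targetEpoch
            (current + 1) to targetEpoch
          }
        }

      def main(args: Array[String]): Unit = {
        println(allocate(0, 0, 0)) // Range 0 to 0: the initial location
        println(allocate(0, 0, 3)) // Range 1 to 3: pre-allocate three more epochs
        println(allocate(0, 0, 2)) // empty: the range is already covered
      }
    }

The real implementation achieves the same effect lock-free with ConcurrentHashMap.compute, and the allocateEpochIdsAndUpdateCurrentMaxEpoch test added to LifecycleManagerSuite below exercises exactly these three cases.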
@@ -742,12 +803,14 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends logDebug(s"ReserveSlots for $shuffleId success with details:$slots!") } // Forth, register shuffle success, update status + shuffleId2NumMappers.put(shuffleId, numMappers) val allocatedWorkers = JavaUtils.newConcurrentHashMap[String, ShufflePartitionLocationInfo]() slots.asScala.foreach { case (workerInfo, (primaryLocations, replicaLocations)) => val partitionLocationInfo = new ShufflePartitionLocationInfo(workerInfo) partitionLocationInfo.addPrimaryPartitions(primaryLocations) - updateLatestPartitionLocations(shuffleId, primaryLocations) + addNewPartitionLocations(shuffleId, primaryLocations) + initCurrentMaxPartitionEpoch(shuffleId, primaryLocations) partitionLocationInfo.addReplicaPartitions(replicaLocations) allocatedWorkers.put(workerInfo.toUniqueId, partitionLocationInfo) } @@ -774,7 +837,9 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends partitionIds: util.List[Integer], oldEpochs: util.List[Integer], oldPartitions: util.List[PartitionLocation], - causes: util.List[StatusCode]): Unit = { + causes: util.List[StatusCode], + clientMaxEpochs: util.List[Integer], + reportOnly: Boolean): Unit = { val contextWrapper = ChangeLocationsCallContext(context, partitionIds.size()) // If shuffle not registered, reply ShuffleNotRegistered and return @@ -814,11 +879,43 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends partitionIds.get(idx), oldEpochs.get(idx), oldPartitions.get(idx), + clientMaxEpochs.get(idx), + reportOnly, Some(causes.get(idx)), commitManager.isSegmentGranularityVisible(shuffleId)) } } + def extractPartitionInfo(partitionInfos: util.List[PbRevivePartitionInfo]): ( + util.List[Integer], + util.List[Integer], + util.List[PartitionLocation], + util.List[StatusCode], + util.List[Integer]) = { + val partitionIds = new util.ArrayList[Integer]() + val epochs = new util.ArrayList[Integer]() + val oldPartitions = new util.ArrayList[PartitionLocation]() + val causes = new util.ArrayList[StatusCode]() + val clientMaxEpochs = new util.ArrayList[Integer]() + + (0 until partitionInfos.size()).foreach { idx => + val info = partitionInfos.get(idx) + partitionIds.add(info.getPartitionId) + epochs.add(info.getEpoch) + + if (info.hasPartition) { + oldPartitions.add(PbSerDeUtils.fromPbPartitionLocation(info.getPartition)) + } else { + oldPartitions.add(null) + } + + causes.add(StatusCode.fromValue(info.getStatus.toByte)) + clientMaxEpochs.add(info.getClientMaxEpoch) + } + + (partitionIds, epochs, oldPartitions, causes, clientMaxEpochs) + } + private def handleMapperEnd( context: RpcCallContext, shuffleId: Int, @@ -1089,6 +1186,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends partitionLocationInfo.removeAllReplicaPartitions() } } + partitionLocationMonitors.get(shuffleId).forEach { + case (partitionId, partitionLocationMonitor) => + logInfo(s"Partition Location Monitor summary for shuffleId: ${shuffleId}, partitionId: ${partitionId}, ${partitionLocationMonitor.report}") + } } private def handleMapPartitionEnd( @@ -1395,7 +1496,6 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends shuffleId: Int, candidates: util.HashSet[WorkerInfo], slots: WorkerResource, - updateEpoch: Boolean = true, isSegmentGranularityVisible: Boolean = false): Boolean = { var requestSlots = slots val reserveSlotsMaxRetries = conf.clientReserveSlotsMaxRetries @@ -1445,8 +1545,7 @@ class 
LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends // duplicated with existing partition locations. requestSlots = reallocateSlotsFromCandidates( failedPartitionLocations.values.toList, - retryCandidates.asScala.toList, - updateEpoch) + retryCandidates.asScala.toList) requestSlots.asScala.foreach { case (workerInfo, (retryPrimaryLocs, retryReplicaLocs)) => val (primaryPartitionLocations, replicaPartitionLocations) = @@ -1484,13 +1583,13 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends /** * Allocate a new primary/replica PartitionLocation pair from the current WorkerInfo list. * - * @param oldEpochId Current partition reduce location last epoch id + * @param epochId Target partition reduce location epoch id * @param candidates WorkerInfo list can be used to offer worker slots * @param slots Current WorkerResource */ def allocateFromCandidates( id: Int, - oldEpochId: Int, + epochId: Int, candidates: List[WorkerInfo], slots: WorkerResource, updateEpoch: Boolean = true): Unit = { @@ -1502,7 +1601,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends val primaryIndex = Random.nextInt(candidates.size) val primaryLocation = new PartitionLocation( id, - if (updateEpoch) oldEpochId + 1 else oldEpochId, + epochId, candidates(primaryIndex).host, candidates(primaryIndex).rpcPort, candidates(primaryIndex).pushPort, @@ -1522,7 +1621,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends } val replicaLocation = new PartitionLocation( id, - if (updateEpoch) oldEpochId + 1 else oldEpochId, + epochId, candidates(replicaIndex).host, candidates(replicaIndex).rpcPort, candidates(replicaIndex).pushPort, @@ -1541,11 +1640,15 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends private def reallocateSlotsFromCandidates( oldPartitions: List[PartitionLocation], - candidates: List[WorkerInfo], - updateEpoch: Boolean = true): WorkerResource = { + candidates: List[WorkerInfo]): WorkerResource = { val slots = new WorkerResource() oldPartitions.foreach { partition => - allocateFromCandidates(partition.getId, partition.getEpoch, candidates, slots, updateEpoch) + allocateFromCandidates( + partition.getId, + partition.getEpoch, + candidates, + slots, + updateEpoch = false) } slots } @@ -1667,7 +1770,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends registeredShuffle.remove(shuffleId) registeringShuffleRequest.remove(shuffleId) shuffleAllocatedWorkers.remove(shuffleId) - latestPartitionLocation.remove(shuffleId) + partitionLocationMonitors.remove(shuffleId) commitManager.removeExpiredShuffle(shuffleId) changePartitionManager.removeExpiredShuffle(shuffleId) if (!batchRemoveExpiredShufflesEnabled) { @@ -1813,7 +1916,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends // Once a partition is released, it will be never needed anymore def releasePartition(shuffleId: Int, partitionId: Int): Unit = { commitManager.releasePartitionResource(shuffleId, partitionId) - val partitionLocation = latestPartitionLocation.get(shuffleId) + val partitionLocation = partitionLocationMonitors.get(shuffleId) if (partitionLocation != null) { partitionLocation.remove(partitionId) } @@ -1879,6 +1982,39 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends celebornSkewShuffleCheckCallback = Some(callback) } + def getCurrentMaxPartitionEpochId(shuffleId: Int, partitionId: Int): Int = { + 
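+    // Returns -1 when no epoch has been allocated for this partition yet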
currentMaxPartitionEpoch.computeIfAbsent( + shuffleId, + _ => JavaUtils.newConcurrentHashMap[Int, Integer]).getOrDefault(partitionId, -1) + } + + def allocateEpochIdsAndUpdateCurrentMaxEpoch( + shuffleId: Int, + partitionId: Int, + targetEpoch: Int): Array[Int] = { + var range = (-1, -1) + currentMaxPartitionEpoch.computeIfAbsent( + shuffleId, + _ => JavaUtils.newConcurrentHashMap[Int, Integer]).compute( + partitionId, + (_, currentMaxEpoch) => { + if (currentMaxEpoch == null) { + range = (0, targetEpoch) + targetEpoch + } else { + range = (currentMaxEpoch + 1, targetEpoch) + math.max(currentMaxEpoch, targetEpoch) + } + }) + logDebug( + s"allocateEpochIdsAndUpdateCurrentMaxEpoch, shuffleId ${shuffleId}, partitionId ${partitionId}, from ${range._1} to ${range._2}") + if (range._1 <= range._2) { + (range._1 to range._2).toArray + } else { + Array.empty + } + } + // Initialize at the end of LifecycleManager construction. initialize() diff --git a/client/src/main/scala/org/apache/celeborn/client/PartitionLocationMonitor.scala b/client/src/main/scala/org/apache/celeborn/client/PartitionLocationMonitor.scala new file mode 100644 index 00000000000..6cd2ce7cae7 --- /dev/null +++ b/client/src/main/scala/org/apache/celeborn/client/PartitionLocationMonitor.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.celeborn.client
+
+import java.util.concurrent.ConcurrentHashMap
+
+import scala.collection.JavaConverters._
+import scala.math.ceil
+
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.internal.Logging
+import org.apache.celeborn.common.network.util.ByteUnit
+import org.apache.celeborn.common.protocol.PartitionLocation
+import org.apache.celeborn.common.protocol.message.StatusCode
+import org.apache.celeborn.common.util.ConcurrentSkipListMapWithTracker
+
+class PartitionLocationMonitor(
+    shuffleId: Int,
+    partitionId: Int,
+    conf: CelebornConf,
+    maxActiveLocation: Int) extends Logging {
+
+  private val timeWindowsInSecs = conf.clientActiveFullLocationTimeWindowSecs
+  private val intervalPerBucketInMills = conf.clientActiveFullLocationIntervalPerBucketMs
+  private val expectedWorkerSpeedMBPerSec = conf.clientExpectedWorkerSpeedMBPerSecond
+  private lazy val activeFullLocationHub =
+    new PartitionSplitTimeSlidingHub(timeWindowsInSecs.toInt, intervalPerBucketInMills.toInt)
+  // epochId -> partitionLocation
+  private val activeLocationEpochs = new ConcurrentSkipListMapWithTracker[Int, PartitionLocation]()
+  private val softSplitEpochIds: java.util.Set[Int] = ConcurrentHashMap.newKeySet()
+  private val exceptionEpochIds: java.util.Set[Int] = ConcurrentHashMap.newKeySet()
+  private val softSplitSizeMB =
+    ByteUnit.BYTE.convertTo(conf.shufflePartitionSplitThreshold, ByteUnit.MiB)
+  private val hardSplitSizeMB = softSplitSizeMB * 3 // TODO: @漠云 revisit this factor later
+
+  def addActiveLocationEpoch(partitionLocation: PartitionLocation): Unit = {
+    activeLocationEpochs.put(partitionLocation.getEpoch, partitionLocation)
+  }
+
+  def getActiveLocations(clientMaxEpoch: Int): Seq[PartitionLocation] = {
+    activeLocationEpochs.tailMap(clientMaxEpoch + 1).values().asScala.toSeq
+  }
+
+  // Returns true if the active full location hub changed
+  def receivePartitionSplitOrRevived(epoch: Int, cause: Option[StatusCode]): Boolean = {
+    var changed = false
+    if (cause.isDefined) {
+      if (cause.get == StatusCode.SOFT_SPLIT) {
+        if (activeLocationEpochs.remove(epoch) || exceptionEpochIds.remove(epoch)) {
+          activeFullLocationHub.add(new PartitionSplitNode(softSplitSizeMB))
+          softSplitEpochIds.add(epoch)
+          changed = true
+        }
+      } else if (cause.get == StatusCode.HARD_SPLIT) {
+        if (activeLocationEpochs.remove(epoch) || exceptionEpochIds.remove(epoch)) {
+          activeFullLocationHub.add(new PartitionSplitNode(hardSplitSizeMB))
+          changed = true
+        } else if (softSplitEpochIds.remove(epoch)) {
+          // The soft split portion was already counted; only add the difference
+          activeFullLocationHub.add(new PartitionSplitNode(hardSplitSizeMB - softSplitSizeMB))
+          changed = true
+        }
+      } else {
+        if (activeLocationEpochs.remove(epoch)) {
+          exceptionEpochIds.add(epoch)
+          changed = true
+        }
+      }
+    }
+    if (changed) {
+      logInfo(
+        s"Receive shuffleId: $shuffleId, partitionId: $partitionId, epochId: $epoch, cause: $cause")
+    }
+    changed
+  }
+
+  def activeLocationCount: Int = {
+    activeLocationEpochs.size
+  }
+
+  def nextReserveSlotCount: Int = {
+    val pushSpeed = activeFullLocationHub.getActiveFullLocationSizeMBPerSec()
+    Math.max(
+      Math.min(
+        maxActiveLocation,
+        ceil(pushSpeed.toDouble / expectedWorkerSpeedMBPerSec)).toInt - activeLocationCount,
+      0)
+  }
+
+  def report: String = {
+    activeLocationEpochs.report()
+  }
+}
diff --git a/client/src/main/scala/org/apache/celeborn/client/PartitionSplitTimeSlidingHub.scala b/client/src/main/scala/org/apache/celeborn/client/PartitionSplitTimeSlidingHub.scala
new file mode 100644
index 00000000000..bc801efdc0b
---
/dev/null +++ b/client/src/main/scala/org/apache/celeborn/client/PartitionSplitTimeSlidingHub.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.client + +import org.apache.celeborn.common.util.TimeSlidingHub + +class PartitionSplitNode(var value: Long) extends TimeSlidingHub.TimeSlidingNode { + + override def combineNode(node: TimeSlidingHub.TimeSlidingNode): Unit = { + value += node.asInstanceOf[PartitionSplitNode].value + } + + /** Minus the value from the {@param node}. */ + override def separateNode(node: TimeSlidingHub.TimeSlidingNode): Unit = { + value -= node.asInstanceOf[PartitionSplitNode].value + } + + override def clone: PartitionSplitNode = new PartitionSplitNode(value) +} + +class PartitionSplitTimeSlidingHub(timeWindowsInSecs: Int, intervalPerBucketInMills: Int) + extends TimeSlidingHub[PartitionSplitNode](timeWindowsInSecs, intervalPerBucketInMills) { + override protected def newEmptyNode(): PartitionSplitNode = { + new PartitionSplitNode(0) + } + + def getActiveFullLocationSizeMBPerSec(): Int = { + val currentSizeMB = sum().getLeft.value + if (currentSizeMB > 0) { + currentSizeMB.toInt / timeWindowsInSecs + } else { + 0 + } + } +} diff --git a/client/src/main/scala/org/apache/celeborn/client/RequestLocationCallContext.scala b/client/src/main/scala/org/apache/celeborn/client/RequestLocationCallContext.scala index 02f67df7e1f..79bc175eac2 100644 --- a/client/src/main/scala/org/apache/celeborn/client/RequestLocationCallContext.scala +++ b/client/src/main/scala/org/apache/celeborn/client/RequestLocationCallContext.scala @@ -18,6 +18,7 @@ package org.apache.celeborn.client import java.util +import java.util.concurrent.atomic.LongAdder import org.apache.celeborn.common.internal.Logging import org.apache.celeborn.common.protocol.PartitionLocation @@ -30,7 +31,7 @@ trait RequestLocationCallContext { def reply( partitionId: Int, status: StatusCode, - partitionLocationOpt: Option[PartitionLocation], + partitionLocationsOpt: Option[Seq[PartitionLocation]], available: Boolean): Unit } @@ -42,6 +43,7 @@ case class ChangeLocationsCallContext( val newLocs = JavaUtils.newConcurrentHashMap[Integer, (StatusCode, Boolean, Seq[PartitionLocation])]( partitionCount) + private val requestCount = new LongAdder def markMapperEnd(mapId: Int): Unit = this.synchronized { endedMapIds.add(mapId) @@ -50,15 +52,21 @@ case class ChangeLocationsCallContext( override def reply( partitionId: Int, status: StatusCode, - partitionLocationOpt: Option[PartitionLocation], + partitionLocationsOpt: Option[Seq[PartitionLocation]], available: Boolean): Unit = this.synchronized { - if (newLocs.containsKey(partitionId)) { - logError(s"PartitionId $partitionId already exists!") + if (newLocs.containsKey(partitionId) && 
newLocs.get(partitionId)._3 != null && newLocs.get( + partitionId)._3.length != partitionLocationsOpt.getOrElse(Seq.empty).length) { + logError(s"PartitionId $partitionId already exists! " + + s"${newLocs.get(partitionId)._3.length} != ${partitionLocationsOpt.getOrElse(Seq.empty).length}") } - //TODO fix later -// newLocs.put(partitionId, (status, available, partitionLocationOpt.getOrElse(null))) - - if (newLocs.size() == partitionCount || StatusCode.SHUFFLE_NOT_REGISTERED == status + if (partitionLocationsOpt.isDefined && (status == StatusCode.RESERVE_SLOTS_FAILED || status == StatusCode.SLOT_NOT_AVAILABLE)) { + newLocs.put(partitionId, (StatusCode.SUCCESS, available, partitionLocationsOpt.orNull)) + } else { + newLocs.put(partitionId, (status, available, partitionLocationsOpt.orNull)) + } + requestCount.increment() + if (requestCount.intValue() == partitionCount + || StatusCode.SHUFFLE_NOT_REGISTERED == status || StatusCode.STAGE_ENDED == status) { context.reply(ChangeLocationResponse(endedMapIds, newLocs)) } @@ -69,11 +77,11 @@ case class ApplyNewLocationCallContext(context: RpcCallContext) extends RequestL override def reply( partitionId: Int, status: StatusCode, - partitionLocationOpt: Option[PartitionLocation], + partitionLocationsOpt: Option[Seq[PartitionLocation]], available: Boolean): Unit = { - partitionLocationOpt match { - case Some(partitionLocation) => - context.reply(RegisterShuffleResponse(status, Array(partitionLocation))) + partitionLocationsOpt match { + case Some(partitionLocations) => + context.reply(RegisterShuffleResponse(status, partitionLocations.toArray)) case None => context.reply(RegisterShuffleResponse(status, Array.empty)) } } diff --git a/common/src/main/java/org/apache/celeborn/common/protocol/ReviveRequest.java b/common/src/main/java/org/apache/celeborn/common/protocol/ReviveRequest.java index e10b7174267..cd19edcc0f6 100644 --- a/common/src/main/java/org/apache/celeborn/common/protocol/ReviveRequest.java +++ b/common/src/main/java/org/apache/celeborn/common/protocol/ReviveRequest.java @@ -17,9 +17,10 @@ package org.apache.celeborn.common.protocol; -import org.apache.celeborn.common.protocol.message.StatusCode; import java.util.Objects; +import org.apache.celeborn.common.protocol.message.StatusCode; + public class ReviveRequest { public int shuffleId; public int mapId; @@ -32,14 +33,14 @@ public class ReviveRequest { public volatile int reviveStatus; public ReviveRequest( - int shuffleId, - int mapId, - int attemptId, - int partitionId, - PartitionLocation loc, - StatusCode cause, - int clientMaxEpoch, - boolean urgent) { + int shuffleId, + int mapId, + int attemptId, + int partitionId, + PartitionLocation loc, + StatusCode cause, + int clientMaxEpoch, + boolean urgent) { this.shuffleId = shuffleId; this.mapId = mapId; this.attemptId = attemptId; @@ -53,17 +54,26 @@ public ReviveRequest( @Override public String toString() { - return "ReviveRequest{" + - "shuffleId=" + shuffleId + - ", mapId=" + mapId + - ", attemptId=" + attemptId + - ", partitionId=" + partitionId + - ", loc=" + loc + - ", clientMaxEpoch=" + clientMaxEpoch + - ", cause=" + cause + - ", urgent=" + urgent + - ", reviveStatus=" + reviveStatus + - '}'; + return "ReviveRequest{" + + "shuffleId=" + + shuffleId + + ", mapId=" + + mapId + + ", attemptId=" + + attemptId + + ", partitionId=" + + partitionId + + ", loc=" + + loc + + ", clientMaxEpoch=" + + clientMaxEpoch + + ", cause=" + + cause + + ", urgent=" + + urgent + + ", reviveStatus=" + + reviveStatus + + '}'; } @Override @@ -71,11 +81,20 @@ 
public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; ReviveRequest that = (ReviveRequest) o; - return shuffleId == that.shuffleId && mapId == that.mapId && attemptId == that.attemptId && partitionId == that.partitionId && clientMaxEpoch == that.clientMaxEpoch && urgent == that.urgent && reviveStatus == that.reviveStatus && loc == that.loc && cause == that.cause; + return shuffleId == that.shuffleId + && mapId == that.mapId + && attemptId == that.attemptId + && partitionId == that.partitionId + && clientMaxEpoch == that.clientMaxEpoch + && urgent == that.urgent + && reviveStatus == that.reviveStatus + && loc == that.loc + && cause == that.cause; } @Override public int hashCode() { - return Objects.hash(shuffleId, mapId, attemptId, partitionId, loc, clientMaxEpoch, cause, urgent, reviveStatus); + return Objects.hash( + shuffleId, mapId, attemptId, partitionId, loc, clientMaxEpoch, cause, urgent, reviveStatus); } } diff --git a/common/src/main/java/org/apache/celeborn/common/util/ConcurrentSkipListMapWithTracker.java b/common/src/main/java/org/apache/celeborn/common/util/ConcurrentSkipListMapWithTracker.java new file mode 100644 index 00000000000..026cf0f60ba --- /dev/null +++ b/common/src/main/java/org/apache/celeborn/common/util/ConcurrentSkipListMapWithTracker.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.celeborn.common.util;
+
+import java.util.concurrent.ConcurrentNavigableMap;
+import java.util.concurrent.ConcurrentSkipListMap;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+public class ConcurrentSkipListMapWithTracker<K, V> {
+  private final ConcurrentSkipListMap<K, V> map = new ConcurrentSkipListMap<>();
+  private final long firstTime = System.nanoTime();
+  private long lastTime = firstTime;
+  private int maxCount = 1;
+  private long totalCountAndTime = 0;
+
+  public V put(K key, V value) {
+    return map.computeIfAbsent(
+        key,
+        (v) -> {
+          recordChange();
+          return value;
+        });
+  }
+
+  public boolean remove(K key) {
+    AtomicBoolean exists = new AtomicBoolean(false);
+    map.computeIfPresent(
+        key,
+        (k, v) -> {
+          recordChange();
+          exists.set(true);
+          return null;
+        });
+    return exists.get();
+  }
+
+  public ConcurrentNavigableMap<K, V> tailMap(K key) {
+    return map.tailMap(key);
+  }
+
+  public int size() {
+    return map.size();
+  }
+
+  private void recordChange() {
+    int currentCount = map.size();
+    if (currentCount > maxCount) {
+      maxCount = currentCount;
+    }
+    long currentTime = System.nanoTime();
+    long duration = currentTime - lastTime;
+    lastTime = currentTime;
+    totalCountAndTime += currentCount * TimeUnit.NANOSECONDS.toMillis(duration);
+  }
+
+  public String report() {
+    recordChange();
+    double averageSize =
+        (double) totalCountAndTime / TimeUnit.NANOSECONDS.toMillis(lastTime - firstTime);
+    return String.format(
+        "maxActiveLocationsCount: %d, avgActiveLocationsCount: %.2f", maxCount, averageSize);
+  }
+}
diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/congestcontrol/TimeSlidingHub.java b/common/src/main/java/org/apache/celeborn/common/util/TimeSlidingHub.java
similarity index 94%
rename from worker/src/main/java/org/apache/celeborn/service/deploy/worker/congestcontrol/TimeSlidingHub.java
rename to common/src/main/java/org/apache/celeborn/common/util/TimeSlidingHub.java
index d744f583e95..01eb90b2f69 100644
--- a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/congestcontrol/TimeSlidingHub.java
+++ b/common/src/main/java/org/apache/celeborn/common/util/TimeSlidingHub.java
@@ -15,7 +15,7 @@
 * limitations under the License.
 */

-package org.apache.celeborn.service.deploy.worker.congestcontrol;
+package org.apache.celeborn.common.util;

 import java.util.Iterator;
 import java.util.concurrent.LinkedBlockingDeque;
@@ -51,7 +51,7 @@ public interface TimeSlidingNode extends Cloneable {
 }

  // 1 second.
-  protected final int intervalPerBucketInMills = 1000;
+  protected final int intervalPerBucketInMills;
  protected final int timeWindowsInMills;
  private final int maxQueueSize;
  private Pair<T, Integer> sumInfo;
@@ -59,7 +59,12 @@ public interface TimeSlidingNode extends Cloneable {
  private final LinkedBlockingDeque<Pair<Long, T>> _deque;

  public TimeSlidingHub(int timeWindowsInSecs) {
+    this(timeWindowsInSecs, 1000); // intervalPerBucketInMills defaults to 1 second
+  }
+
+  public TimeSlidingHub(int timeWindowsInSecs, int intervalPerBucketInMills) {
    this._deque = new LinkedBlockingDeque<>();
+    this.intervalPerBucketInMills = intervalPerBucketInMills;
    this.maxQueueSize = timeWindowsInSecs * 1000 / intervalPerBucketInMills;
    this.timeWindowsInMills = maxQueueSize * intervalPerBucketInMills;
    this.sumInfo = Pair.of(newEmptyNode(), 0);
diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 1be7a15e693..8a7edf873c6 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -29,6 +29,7 @@ import scala.util.matching.Regex

 import io.netty.channel.epoll.Epoll

+import org.apache.celeborn.common.CelebornConf.{CLIENT_ACTIVE_FULL_LOCATION_INTERVAL_PER_BUCKET, CLIENT_ACTIVE_FULL_LOCATION_TIME_WINDOW, CLIENT_ASYNC_SPLIT_PARTITION_ENABLED, CLIENT_EXPECTED_WORKER_SPEED_MB_PER_SECOND}
 import org.apache.celeborn.common.authentication.AnonymousAuthenticationProviderImpl
 import org.apache.celeborn.common.identity.{DefaultIdentityProvider, HadoopBasedIdentityProvider, IdentityProvider}
 import org.apache.celeborn.common.internal.Logging
@@ -1050,7 +1051,11 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
    get(CLIENT_PUSH_SENDBUFFERPOOL_CHECKEXPIREINTERVAL)
  def clientAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean =
    get(CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ_ENABLED)
-
+  def clientActiveFullLocationTimeWindowSecs: Long = get(CLIENT_ACTIVE_FULL_LOCATION_TIME_WINDOW)
+  def clientActiveFullLocationIntervalPerBucketMs: Long =
+    get(CLIENT_ACTIVE_FULL_LOCATION_INTERVAL_PER_BUCKET)
+  def clientExpectedWorkerSpeedMBPerSecond: Int = get(CLIENT_EXPECTED_WORKER_SPEED_MB_PER_SECOND)
+  def clientMaxActiveLocation: Int = get(CLIENT_MAX_ACTIVE_LOCATION)
  // //////////////////////////////////////////////////////
  //                 Client Shuffle                       //
  // //////////////////////////////////////////////////////
@@ -1111,6 +1116,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
  def registerShuffleFilterExcludedWorkerEnabled: Boolean =
    get(REGISTER_SHUFFLE_FILTER_EXCLUDED_WORKER_ENABLED)
  def reviseLostShufflesEnabled: Boolean = get(REVISE_LOST_SHUFFLES_ENABLED)
+  def asyncSplitPartitionEnabled: Boolean = get(CLIENT_ASYNC_SPLIT_PARTITION_ENABLED)

  // //////////////////////////////////////////////////////
  //                      Worker                          //
@@ -6019,6 +6025,14 @@ object CelebornConf extends Logging {
      .booleanConf
      .createWithDefault(false)

+  val CLIENT_ASYNC_SPLIT_PARTITION_ENABLED: ConfigEntry[Boolean] =
+    buildConf("celeborn.client.async.split.partition.enabled")
+      .categories("client")
+      .version("0.6.0")
+      .doc("When enabled, ChangePartitionManager asynchronously pre-allocates new partition locations based on the observed partition split rate.")
+      .booleanConf
+      .createWithDefault(false)
+
  val NETWORK_IO_SASL_TIMEOUT: ConfigEntry[Long] =
    buildConf("celeborn.<module>.io.saslTimeout")
      .categories("network")
@@ -6123,6 +6137,38 @@ object CelebornConf
extends Logging {
      .booleanConf
      .createWithDefault(false)

+  val CLIENT_ACTIVE_FULL_LOCATION_TIME_WINDOW: ConfigEntry[Long] =
+    buildConf("celeborn.client.active.full.location.time.window")
+      .categories("client")
+      .version("0.6.0")
+      .doc("The sliding time window over which partition split events are aggregated to estimate the partition split rate.")
+      .timeConf(TimeUnit.SECONDS)
+      .createWithDefaultString("180s")
+
+  val CLIENT_ACTIVE_FULL_LOCATION_INTERVAL_PER_BUCKET: ConfigEntry[Long] =
+    buildConf("celeborn.client.active.full.location.interval.per.bucket")
+      .categories("client")
+      .version("0.6.0")
+      .doc("The bucket interval of the sliding time window used to aggregate partition split events.")
+      .timeConf(TimeUnit.MILLISECONDS)
+      .createWithDefaultString("10s")
+
+  val CLIENT_EXPECTED_WORKER_SPEED_MB_PER_SECOND: ConfigEntry[Int] =
+    buildConf("celeborn.client.expected.worker.speed.mb.per.second")
+      .categories("client")
+      .version("0.6.0")
+      .doc("The expected write throughput of a single worker in MB/s, used to estimate how many active locations a partition needs.")
+      .intConf
+      .createWithDefault(10)
+
+  val CLIENT_MAX_ACTIVE_LOCATION: ConfigEntry[Int] =
+    buildConf("celeborn.client.max.active.location")
+      .categories("client")
+      .version("0.6.0")
+      .doc("The max number of active locations per partition. If set to -1, it defaults to the number of mappers.")
+      .intConf
+      .createWithDefault(-1)
+
  // SSL Configs
  val SSL_ENABLED: ConfigEntry[Boolean] =
diff --git a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index 79148f21d92..7c7f1464c8b 100644
--- a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -237,9 +237,9 @@ object ControlMessages extends Logging {

  object PartitionSplitReport {
    def apply(
-      shuffleId: Int,
-      mapIds: util.Set[Integer],
-      reviveRequests: util.Collection[ReviveRequest]): PbPartitionSplitReport = {
+        shuffleId: Int,
+        mapIds: util.Set[Integer],
+        reviveRequests: util.Collection[ReviveRequest]): PbPartitionSplitReport = {
      val builder = PbPartitionSplitReport.newBuilder()
        .setShuffleId(shuffleId)
        .addAllMapId(mapIds)
@@ -291,7 +291,8 @@ object ControlMessages extends Logging {
      .setOldAvailable(available)
    if (locs != null) {
      locs.foreach(loc =>
-        pbChangeLocationPartitionInfoBuilder.addPartition(PbSerDeUtils.toPbPartitionLocation(loc)))
+        pbChangeLocationPartitionInfoBuilder.addPartition(
+          PbSerDeUtils.toPbPartitionLocation(loc)))
    }
    builder.addPartitionInfo(pbChangeLocationPartitionInfoBuilder.build())
  }
diff --git a/common/src/main/scala/org/apache/celeborn/common/util/CelebornHadoopUtils.scala b/common/src/main/scala/org/apache/celeborn/common/util/CelebornHadoopUtils.scala
index 58fde690bfd..07eaa404067 100644
--- a/common/src/main/scala/org/apache/celeborn/common/util/CelebornHadoopUtils.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/util/CelebornHadoopUtils.scala
@@ -25,7 +25,6 @@ import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.hadoop.security.UserGroupInformation

 import org.apache.celeborn.common.CelebornConf
-import org.apache.celeborn.common.CelebornConf.{OSS_ACCESS_KEY, OSS_SECRET_KEY}
 import org.apache.celeborn.common.exception.CelebornException
 import org.apache.celeborn.common.internal.Logging
 import org.apache.celeborn.common.protocol.StorageInfo
diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ChangePartitionManagerUpdateWorkersSuite.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ChangePartitionManagerUpdateWorkersSuite.scala index 653773fd791..1f813ddda12 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ChangePartitionManagerUpdateWorkersSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ChangePartitionManagerUpdateWorkersSuite.scala @@ -76,8 +76,7 @@ class ChangePartitionManagerUpdateWorkersSuite extends WithShuffleClientSuite val reserveSlotsSuccess = lifecycleManager.reserveSlotsWithRetry( shuffleId, new util.HashSet(res.workerResource.keySet()), - res.workerResource, - updateEpoch = false) + res.workerResource) if (reserveSlotsSuccess) { val allocatedWorkers = @@ -88,7 +87,7 @@ class ChangePartitionManagerUpdateWorkersSuite extends WithShuffleClientSuite partitionLocationInfo.addPrimaryPartitions(primaryLocations) partitionLocationInfo.addReplicaPartitions(replicaLocations) allocatedWorkers.put(workerInfo.toUniqueId, partitionLocationInfo) - lifecycleManager.updateLatestPartitionLocations(shuffleId, primaryLocations) + lifecycleManager.addNewPartitionLocations(shuffleId, primaryLocations) } lifecycleManager.shuffleAllocatedWorkers.put(shuffleId, allocatedWorkers) } @@ -104,6 +103,7 @@ class ChangePartitionManagerUpdateWorkersSuite extends WithShuffleClientSuite shuffleId, partitionId, -1, + 0, null, None) changePartitionManager.changePartitionRequests.computeIfAbsent( @@ -151,8 +151,7 @@ class ChangePartitionManagerUpdateWorkersSuite extends WithShuffleClientSuite val reserveSlotsSuccess = lifecycleManager.reserveSlotsWithRetry( shuffleId, new util.HashSet(res.workerResource.keySet()), - res.workerResource, - updateEpoch = false) + res.workerResource) if (reserveSlotsSuccess) { val allocatedWorkers = @@ -163,7 +162,7 @@ class ChangePartitionManagerUpdateWorkersSuite extends WithShuffleClientSuite partitionLocationInfo.addPrimaryPartitions(primaryLocations) partitionLocationInfo.addReplicaPartitions(replicaLocations) allocatedWorkers.put(workerInfo.toUniqueId, partitionLocationInfo) - lifecycleManager.updateLatestPartitionLocations(shuffleId, primaryLocations) + lifecycleManager.addNewPartitionLocations(shuffleId, primaryLocations) } lifecycleManager.shuffleAllocatedWorkers.put(shuffleId, allocatedWorkers) } @@ -206,6 +205,7 @@ class ChangePartitionManagerUpdateWorkersSuite extends WithShuffleClientSuite shuffleId, partitionId, -1, + 0, null, None) changePartitionManager.changePartitionRequests.computeIfAbsent( @@ -259,8 +259,7 @@ class ChangePartitionManagerUpdateWorkersSuite extends WithShuffleClientSuite val reserveSlotsSuccess = lifecycleManager.reserveSlotsWithRetry( shuffleId, new util.HashSet(res.workerResource.keySet()), - res.workerResource, - updateEpoch = false) + res.workerResource) if (reserveSlotsSuccess) { val allocatedWorkers = @@ -286,6 +285,7 @@ class ChangePartitionManagerUpdateWorkersSuite extends WithShuffleClientSuite shuffleId, partitionId, -1, + 0, null, None) changePartitionManager.changePartitionRequests.computeIfAbsent( @@ -330,8 +330,7 @@ class ChangePartitionManagerUpdateWorkersSuite extends WithShuffleClientSuite val reserveSlotsSuccess = lifecycleManager.reserveSlotsWithRetry( shuffleId, new util.HashSet(res.workerResource.keySet()), - res.workerResource, - updateEpoch = false) + res.workerResource) if (reserveSlotsSuccess) { val allocatedWorkers = @@ -358,6 +357,7 @@ class ChangePartitionManagerUpdateWorkersSuite extends WithShuffleClientSuite shuffleId, partitionId, -1, + 0, null, None) 
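+      // -1 is the clientMaxEpoch and the added 0 is the new targetEpoch field
+      // of ChangePartitionRequest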
changePartitionManager.changePartitionRequests.computeIfAbsent( diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala index 39b54b5f638..0c28bb33fed 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala @@ -67,8 +67,7 @@ class LifecycleManagerCommitFilesSuite extends WithShuffleClientSuite with MiniC lifecycleManager.reserveSlotsWithRetry( shuffleId, new util.HashSet(res.workerResource.keySet()), - res.workerResource, - updateEpoch = false) + res.workerResource) lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false) 0 until 10 foreach { partitionId => @@ -123,8 +122,7 @@ class LifecycleManagerCommitFilesSuite extends WithShuffleClientSuite with MiniC lifecycleManager.reserveSlotsWithRetry( shuffleId, new util.HashSet(res.workerResource.keySet()), - res.workerResource, - updateEpoch = false) + res.workerResource) lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false) 0 until 10 foreach { partitionId => @@ -193,8 +191,7 @@ class LifecycleManagerCommitFilesSuite extends WithShuffleClientSuite with MiniC lifecycleManager.reserveSlotsWithRetry( shuffleId, new util.HashSet(res.workerResource.keySet()), - res.workerResource, - updateEpoch = false) + res.workerResource) lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false) 0 until 1000 foreach { partitionId => @@ -253,8 +250,7 @@ class LifecycleManagerCommitFilesSuite extends WithShuffleClientSuite with MiniC lifecycleManager.reserveSlotsWithRetry( shuffleId, new util.HashSet(res.workerResource.keySet()), - res.workerResource, - updateEpoch = false) + res.workerResource) lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false) diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerDestroySlotsSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerDestroySlotsSuite.scala index e8326820209..f0a0b75245e 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerDestroySlotsSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerDestroySlotsSuite.scala @@ -64,8 +64,7 @@ class LifecycleManagerDestroySlotsSuite extends WithShuffleClientSuite with Mini lifecycleManager.reserveSlotsWithRetry( shuffleId, new util.HashSet(res.workerResource.keySet()), - res.workerResource, - updateEpoch = false) + res.workerResource) val slotsToDestroy = new WorkerResource val destroyWorkers = workerInfos.keySet.take(2) @@ -106,8 +105,7 @@ class LifecycleManagerDestroySlotsSuite extends WithShuffleClientSuite with Mini lifecycleManager.reserveSlotsWithRetry( shuffleId, new util.HashSet(res.workerResource.keySet()), - res.workerResource, - updateEpoch = false) + res.workerResource) val slotsToDestroy = new WorkerResource val destroyWorkers = workerInfos.keySet.take(2) @@ -148,8 +146,7 @@ class LifecycleManagerDestroySlotsSuite extends WithShuffleClientSuite with Mini lifecycleManager.reserveSlotsWithRetry( shuffleId, new util.HashSet(res.workerResource.keySet()), - res.workerResource, - updateEpoch = false) + res.workerResource) val slotsToDestroy = new WorkerResource val destroyWorkers = workerInfos.keySet.take(2) diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSuite.scala index 87b05c76630..ad6069911a4 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSuite.scala @@ -99,6 +99,25 @@ class LifecycleManagerSuite extends WithShuffleClientSuite with MiniClusterFeatu lifecycleManager.stop() } + test("allocateEpochIdsAndUpdateCurrentMaxEpoch") { + val celebornConf = new CelebornConf() + val lifecycleManager = new LifecycleManager(s"app-${System.currentTimeMillis()}", celebornConf) + val shuffleId = 0 + val partitionId = 0 + val r1 = lifecycleManager.allocateEpochIdsAndUpdateCurrentMaxEpoch(shuffleId, partitionId, 0) + assert(r1.length == 1) + assert(r1(0) == 0) + val r2 = lifecycleManager.allocateEpochIdsAndUpdateCurrentMaxEpoch(shuffleId, partitionId, 3) + assert(r2.length == 3) + assert(r2(0) == 1) + assert(r2(1) == 2) + assert(r2(2) == 3) + val r3 = lifecycleManager.allocateEpochIdsAndUpdateCurrentMaxEpoch(shuffleId, partitionId, 2) + assert(r3.length == 0) + val r4 = lifecycleManager.allocateEpochIdsAndUpdateCurrentMaxEpoch(shuffleId, partitionId, 3) + assert(r4.length == 0) + } + override def afterAll(): Unit = { logInfo("all test complete , stop celeborn mini cluster") shutdownMiniCluster() diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/PartitionLocationMonitorSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/PartitionLocationMonitorSuite.scala new file mode 100644 index 00000000000..1a658ab200f --- /dev/null +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/PartitionLocationMonitorSuite.scala @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.celeborn.tests.client
+
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.celeborn.CelebornFunSuite
+import org.apache.celeborn.client.PartitionLocationMonitor
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.protocol.PartitionLocation
+import org.apache.celeborn.common.protocol.message.StatusCode
+
+class PartitionLocationMonitorSuite extends CelebornFunSuite {
+  private def newPartitionLocation(epochId: Int): PartitionLocation = {
+    new PartitionLocation(0, epochId, "localhost", 1, 1, 1, 1, PartitionLocation.Mode.PRIMARY)
+  }
+
+  test("quickly reserve slots") {
+    val conf: CelebornConf = new CelebornConf()
+    val partitionLocationMonitor = new PartitionLocationMonitor(0, 0, conf, 10)
+    val epochGenerator = new AtomicInteger
+    val epochIds = new ArrayBuffer[Int]
+    // start with 1 slot
+    val epochId = epochGenerator.incrementAndGet()
+    epochIds += epochId
+    partitionLocationMonitor.addActiveLocationEpoch(newPartitionLocation(epochId))
+    assert(partitionLocationMonitor.activeLocationCount == 1)
+
+    // 1 slot is split, nextReserveSlotCount should be 1
+    epochIds.foreach {
+      epochId =>
+        partitionLocationMonitor.receivePartitionSplitOrRevived(
+          epochId,
+          Some(StatusCode.SOFT_SPLIT))
+    }
+    assert(partitionLocationMonitor.activeLocationCount == 0)
+    var nextReserveSlotCount = partitionLocationMonitor.nextReserveSlotCount
+    assert(nextReserveSlotCount == 1)
+
+    // reserve 1 new slot
+    for (i <- 0 until nextReserveSlotCount) {
+      val epochId = epochGenerator.incrementAndGet()
+      partitionLocationMonitor.addActiveLocationEpoch(newPartitionLocation(epochId))
+      epochIds += epochId
+    }
+    assert(partitionLocationMonitor.activeLocationCount == 1)
+
+    // 1 new slot is split, nextReserveSlotCount should be 2
+    epochIds.foreach { epochId =>
+      partitionLocationMonitor.receivePartitionSplitOrRevived(epochId, Some(StatusCode.SOFT_SPLIT))
+    }
+    assert(partitionLocationMonitor.activeLocationCount == 0)
+    nextReserveSlotCount = partitionLocationMonitor.nextReserveSlotCount
+    assert(nextReserveSlotCount == 2)
+
+    // reserve 2 new slots
+    for (i <- 0 until nextReserveSlotCount) {
+      val epochId = epochGenerator.incrementAndGet()
+      partitionLocationMonitor.addActiveLocationEpoch(newPartitionLocation(epochId))
+      epochIds += epochId
+    }
+    assert(partitionLocationMonitor.activeLocationCount == 2)
+
+    // 2 new slots are split, nextReserveSlotCount should be 3
+    epochIds.foreach { epochId =>
+      partitionLocationMonitor.receivePartitionSplitOrRevived(epochId, Some(StatusCode.SOFT_SPLIT))
+    }
+    assert(partitionLocationMonitor.activeLocationCount == 0)
+    nextReserveSlotCount = partitionLocationMonitor.nextReserveSlotCount
+    assert(nextReserveSlotCount == 3)
+  }
+
+  test("slowly reserve slots") {
+    val conf: CelebornConf = new CelebornConf()
+      .set(CelebornConf.CLIENT_ACTIVE_FULL_LOCATION_TIME_WINDOW.key, "2s")
+      .set(CelebornConf.CLIENT_ACTIVE_FULL_LOCATION_INTERVAL_PER_BUCKET.key, "1s")
+      .set(CelebornConf.CLIENT_EXPECTED_WORKER_SPEED_MB_PER_SECOND.key, "1024")
+    val partitionLocationMonitor = new PartitionLocationMonitor(0, 0, conf, 10)
+    val epochGenerator = new AtomicInteger
+    val epochIds = new ArrayBuffer[Int]
+
+    // reserve 10 new slots
+    for (i <- 0 until 10) {
+      val epochId = epochGenerator.incrementAndGet()
+      partitionLocationMonitor.addActiveLocationEpoch(newPartitionLocation(epochId))
+      epochIds += epochId
+    }
+
+    epochIds.slice(0, 5).foreach { epochId =>
+      partitionLocationMonitor.receivePartitionSplitOrRevived(epochId, Some(StatusCode.SOFT_SPLIT))
+    }
+    assert(partitionLocationMonitor.activeLocationCount == 5)
+
+    // sleep timeWindowsInSecs = 2s; we expect all active full locations to be expired
+    Thread.sleep(2 * 1000L)
+    assert(partitionLocationMonitor.activeLocationCount == 5)
+
+    // 3 slots are split, nextReserveSlotCount should be 0
+    epochIds.slice(5, 8).foreach { epochId =>
+      partitionLocationMonitor.receivePartitionSplitOrRevived(epochId, Some(StatusCode.SOFT_SPLIT))
+    }
+    assert(partitionLocationMonitor.activeLocationCount == 2)
+    assert(partitionLocationMonitor.nextReserveSlotCount == 0)
+  }
+
+  test("max active location limit") {
+    val conf: CelebornConf = new CelebornConf()
+    val partitionLocationMonitor = new PartitionLocationMonitor(0, 0, conf, 3)
+    val epochGenerator = new AtomicInteger
+    val epochIds = new ArrayBuffer[Int]
+
+    // start with 1 slot
+    val epochId = epochGenerator.incrementAndGet()
+    epochIds += epochId
+    partitionLocationMonitor.addActiveLocationEpoch(newPartitionLocation(epochId))
+    assert(partitionLocationMonitor.activeLocationCount == 1)
+
+    // 1 slot is split, nextReserveSlotCount should be 1
+    epochIds.foreach {
+      epochId =>
+        partitionLocationMonitor.receivePartitionSplitOrRevived(
+          epochId,
+          Some(StatusCode.SOFT_SPLIT))
+    }
+    assert(partitionLocationMonitor.activeLocationCount == 0)
+    var nextReserveSlotCount = partitionLocationMonitor.nextReserveSlotCount
+    assert(nextReserveSlotCount == 1)
+
+    // reserve 1 new slot
+    for (i <- 0 until nextReserveSlotCount) {
+      val epochId = epochGenerator.incrementAndGet()
+      partitionLocationMonitor.addActiveLocationEpoch(newPartitionLocation(epochId))
+      epochIds += epochId
+    }
+    assert(partitionLocationMonitor.activeLocationCount == 1)
+
+    // 1 new slot is split, nextReserveSlotCount should be 2
+    epochIds.foreach { epochId =>
+      partitionLocationMonitor.receivePartitionSplitOrRevived(epochId, Some(StatusCode.SOFT_SPLIT))
+    }
+    assert(partitionLocationMonitor.activeLocationCount == 0)
+    nextReserveSlotCount = partitionLocationMonitor.nextReserveSlotCount
+    assert(nextReserveSlotCount == 2)
+
+    // reserve 2 new slots
+    for (i <- 0 until nextReserveSlotCount) {
+      val epochId = epochGenerator.incrementAndGet()
+      partitionLocationMonitor.addActiveLocationEpoch(newPartitionLocation(epochId))
+      epochIds += epochId
+    }
+    assert(partitionLocationMonitor.activeLocationCount == 2)
+
+    // 2 new slots are split; nextReserveSlotCount would be 4,
+    // but the max active location limit is 3, so nextReserveSlotCount should be 3
+    epochIds.foreach { epochId =>
+      partitionLocationMonitor.receivePartitionSplitOrRevived(epochId, Some(StatusCode.SOFT_SPLIT))
+    }
+    assert(partitionLocationMonitor.activeLocationCount == 0)
+    nextReserveSlotCount = partitionLocationMonitor.nextReserveSlotCount
+    assert(nextReserveSlotCount == 3)
+  }
+
+  test("soft split then hard split and failed") {
+    val conf: CelebornConf = new CelebornConf()
+    val partitionLocationMonitor = new PartitionLocationMonitor(0, 0, conf, 10)
+    val epochGenerator = new AtomicInteger
+    val epochIds = new ArrayBuffer[Int]
+
+    // start with 1 slot
+    val epochId = epochGenerator.incrementAndGet()
+    epochIds += epochId
+    partitionLocationMonitor.addActiveLocationEpoch(newPartitionLocation(epochId))
+    assert(partitionLocationMonitor.activeLocationCount == 1)
+
+    // 1 slot is split, nextReserveSlotCount should be 1
+    epochIds.foreach {
+      epochId =>
+        partitionLocationMonitor.receivePartitionSplitOrRevived(
+          epochId,
+          Some(StatusCode.SOFT_SPLIT))
+    }
+    assert(partitionLocationMonitor.activeLocationCount == 0)
+    var nextReserveSlotCount = partitionLocationMonitor.nextReserveSlotCount
+    assert(nextReserveSlotCount == 1)
+
+    // reserve 1 new slot
+    for (i <- 0 until nextReserveSlotCount) {
+      val epochId = epochGenerator.incrementAndGet()
+      partitionLocationMonitor.addActiveLocationEpoch(newPartitionLocation(epochId))
+      epochIds += epochId
+    }
+    assert(partitionLocationMonitor.activeLocationCount == 1)
+
+    // epoch 1 is hard split; activeLocationCount stays 1 (only epoch 2 is active),
+    // so nextReserveSlotCount should be 1
+    partitionLocationMonitor.receivePartitionSplitOrRevived(1, Some(StatusCode.HARD_SPLIT))
+    assert(partitionLocationMonitor.activeLocationCount == 1)
+    assert(partitionLocationMonitor.nextReserveSlotCount == 1)
+
+    // reserve 1 new slot
+    for (i <- 0 until nextReserveSlotCount) {
+      val epochId = epochGenerator.incrementAndGet()
+      partitionLocationMonitor.addActiveLocationEpoch(newPartitionLocation(epochId))
+      epochIds += epochId
+    }
+    assert(partitionLocationMonitor.activeLocationCount == 2)
+    partitionLocationMonitor.receivePartitionSplitOrRevived(
+      2,
+      Some(StatusCode.PUSH_DATA_CONNECTION_EXCEPTION_PRIMARY))
+    assert(partitionLocationMonitor.activeLocationCount == 1)
+    assert(partitionLocationMonitor.nextReserveSlotCount == 1)
+  }
+}
diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/congestcontrol/BufferStatusHub.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/congestcontrol/BufferStatusHub.java index c674673be4c..a8b607bf25e 100644 --- a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/congestcontrol/BufferStatusHub.java +++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/congestcontrol/BufferStatusHub.java @@ -21,6 +21,8 @@ import org.apache.commons.lang3.tuple.Pair; +import org.apache.celeborn.common.util.TimeSlidingHub; + public class BufferStatusHub extends TimeSlidingHub { public static class BufferStatusNode implements TimeSlidingHub.TimeSlidingNode { diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/congestcontrol/TestTimeSlidingHub.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/congestcontrol/TestTimeSlidingHub.java index 2bca305f1fe..b6739857c5a 100644 --- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/congestcontrol/TestTimeSlidingHub.java +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/congestcontrol/TestTimeSlidingHub.java @@ -22,6 +22,8 @@ import org.junit.Assert; import org.junit.Test; +import org.apache.celeborn.common.util.TimeSlidingHub; + public class TestTimeSlidingHub { private static class DummyTimeSlidingHub
diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/DynamicallySplitPartitionSuite.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/DynamicallySplitPartitionSuite.scala
new file mode 100644
index 00000000000..c9ca9fb0074
--- /dev/null
+++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/DynamicallySplitPartitionSuite.scala
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.service.deploy.cluster
+
+import java.nio.charset.StandardCharsets
+import java.util.concurrent.{Executors, TimeUnit}
+
+import org.apache.commons.lang3.RandomStringUtils
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.celeborn.client.{LifecycleManager, ShuffleClientImpl}
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.identity.UserIdentifier
+import org.apache.celeborn.common.internal.Logging
+import org.apache.celeborn.service.deploy.MiniClusterFeature
+
+class DynamicallySplitPartitionSuite extends AnyFunSuite
+  with Logging with MiniClusterFeature with BeforeAndAfterAll {
+  val masterPort = 19099
+
+  override def beforeAll(): Unit = {
+    val masterConf = Map(
+      "celeborn.master.host" -> "localhost",
+      "celeborn.master.port" -> masterPort.toString)
+    val workerConf = Map(
+      "celeborn.master.endpoints" -> s"localhost:$masterPort",
+      "celeborn.worker.flusher.buffer.size" -> "0")
+
+    logInfo("test initialized , setup Celeborn mini cluster")
+    setupMiniClusterWithRandomPorts(masterConf, workerConf, 5)
+  }
+
+  class PushTask(shuffleClient: ShuffleClientImpl, shuffleId: Int, mapId: Int, mapNum: Int)
+    extends Runnable {
+    override def run(): Unit = {
+      val startTime = System.nanoTime()
+      var i = 0
+      // 512 KB/s per task, 4 tasks, 60 s
+      // with a 10 MB split threshold, expect a soft split roughly every 5 seconds
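+      // Each push below is 64 KB followed by a 125 ms sleep: 8 pushes/s,
+      // i.e. 512 KB/s per task and ~2 MB/s across the 4 tasks, so the 10 MB
+      // split threshold configured in the test should be crossed roughly
+      // every 5 seconds over the 60 s run (60 * 8 = 480 pushes per task).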
+      while (i < 60 * 8) {
+        // randomAlphanumeric keeps each payload exactly 64 KB in UTF-8,
+        // so one task pushes 8 * 64 KB = 512 KB per second
+        val DATA = RandomStringUtils.randomAlphanumeric(64 * 1024).getBytes(StandardCharsets.UTF_8)
+        shuffleClient.pushData(shuffleId, mapId, 0, 0, DATA, 0, DATA.length, mapNum, 1)
+        i += 1
+        Thread.sleep(125)
+      }
+      shuffleClient.mapperEnd(shuffleId, mapId, 0, mapNum)
+      val endTime = System.nanoTime()
+      logInfo(
+        s"PushTask $mapId finished, cost ${TimeUnit.NANOSECONDS.toSeconds(endTime - startTime)} s")
+    }
+  }
+
+  test("dynamically split partition") {
+    val appId = s"dynamically-split-partition-test-${System.currentTimeMillis()}"
+    val clientConf = new CelebornConf()
+      .set(CelebornConf.MASTER_ENDPOINTS.key, s"localhost:$masterPort")
+      .set(CelebornConf.SHUFFLE_PARTITION_SPLIT_THRESHOLD.key, "10M")
+      .set(CelebornConf.SHUFFLE_PARTITION_SPLIT_MODE.key, "HARD")
+      .set(CelebornConf.CLIENT_ASYNC_SPLIT_PARTITION_ENABLED.key, "true")
+      .set(CelebornConf.CLIENT_ACTIVE_FULL_LOCATION_TIME_WINDOW.key, "10s")
+      .set(CelebornConf.CLIENT_ACTIVE_FULL_LOCATION_INTERVAL_PER_BUCKET.key, "1s")
+      .set(CelebornConf.CLIENT_EXPECTED_WORKER_SPEED_MB_PER_SECOND.key, "1")
+    val lifecycleManager = new LifecycleManager(appId, clientConf)
+    val shuffleClient = new ShuffleClientImpl(appId, clientConf, UserIdentifier("mock", "mock"))
+    shuffleClient.setupLifecycleManagerRef(lifecycleManager.self)
+
+    val executor = Executors.newFixedThreadPool(4)
+    val task0 = new PushTask(shuffleClient, 0, 0, 4)
+    val task1 = new PushTask(shuffleClient, 0, 1, 4)
+    val task2 = new PushTask(shuffleClient, 0, 2, 4)
+    val task3 = new PushTask(shuffleClient, 0, 3, 4)
+    val startTime = System.nanoTime()
+    executor.submit(task0)
+    executor.submit(task1)
+    executor.submit(task2)
+    executor.submit(task3)
+    executor.shutdown()
+    executor.awaitTermination(120, TimeUnit.SECONDS)
+    val endTime = System.nanoTime()
+    assert(TimeUnit.NANOSECONDS.toSeconds(endTime - startTime) < 100)
+  }
+}

From fc6246c8b9669f51a0e83422dcb1274a10c79974 Mon Sep 17 00:00:00 2001
From: jiang13021
Date: Thu, 15 May 2025 14:49:55 +0800
Subject: [PATCH 3/3] remove getPartitionLocation and fix ut

---
 .../celeborn/client/DummyShuffleClient.java   |  6 ---
 .../apache/celeborn/client/ReviveManager.java |  2 -
 .../apache/celeborn/client/ShuffleClient.java |  4 --
 .../celeborn/client/ShuffleClientImpl.java    | 47 ++++-------------
 .../LifecycleManagerReserveSlotsSuite.scala   | 52 ++++++++++++-----
 .../DynamicallySplitPartitionSuite.scala      | 14 ++---
 .../cluster/PushMergedDataSplitSuite.scala    | 36 +++++++++----
 7 files changed, 79 insertions(+), 82 deletions(-)

diff --git a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java index 6b3673b1843..d2f8b1988ce 100644 --- a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java @@ -170,12 +170,6 @@ public PartitionLocation registerMapPartitionTask( return null; } - @Override - public ConcurrentHashMap getPartitionLocation( - int shuffleId, int numMappers, int numPartitions) { - return reducePartitionMap.get(shuffleId); - } - @Override public PushState getPushState(String mapKey) { return new PushState(conf); } diff --git a/client/src/main/java/org/apache/celeborn/client/ReviveManager.java b/client/src/main/java/org/apache/celeborn/client/ReviveManager.java index 81ae4fabf65..45ce158c49d 100644 --- a/client/src/main/java/org/apache/celeborn/client/ReviveManager.java +++ b/client/src/main/java/org/apache/celeborn/client/ReviveManager.java @@ -134,10
+134,8 @@ public void processRequests(int shuffleId, Collection requests, b // Call reviveBatch. Return null means Exception caught or // SHUFFLE_NOT_REGISTERED // Do not use WriterTracerHere because traceInfo is set afterward - long reviveStartTime = System.nanoTime(); Map results = shuffleClient.reviveBatch(shuffleId, mapIds, requestsToSend.values(), urgent); - long reviveCostTime = System.nanoTime() - reviveStartTime; if (results == null) { for (ReviveRequest req : filteredRequests) { req.reviveStatus = StatusCode.REVIVE_FAILED.getValue(); diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java index 3ef5a3548ee..e1e2ba1e4e1 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java @@ -22,7 +22,6 @@ import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.LongAdder; import java.util.function.BiFunction; @@ -280,9 +279,6 @@ public abstract CelebornInputStream readPartition( public abstract PartitionLocation registerMapPartitionTask( int shuffleId, int numMappers, int mapId, int attemptId, int partitionId) throws IOException; - public abstract ConcurrentHashMap getPartitionLocation( - int shuffleId, int numMappers, int numPartitions) throws CelebornIOException; - public boolean ensureRegistered(int shuffleId, int numMappers, int numPartitions) { return false; } diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index f1041f7422e..5602d425a73 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -664,30 +664,6 @@ public PartitionLocation registerMapPartitionTask( return partitionLocationMap.get(partitionId).get(0); } - @Override - public ConcurrentHashMap getPartitionLocation( - int shuffleId, int numMappers, int numPartitions) throws CelebornIOException { - // TODO only UT related usages, fix later - return null; - // try { - // return reducePartitionMap.computeIfAbsent( - // shuffleId, - // (id) -> { - // try { - // return registerShuffle(shuffleId, numMappers, numPartitions); - // } catch (CelebornIOException e) { - // throw new RuntimeException(e); - // } - // }); - // } catch (RuntimeException e) { - // if (e.getCause() instanceof CelebornIOException) { - // throw (CelebornIOException) e.getCause(); - // } else { - // throw e; - // } - // } - } - @Override public boolean ensureRegistered(int shuffleId, int numMappers, int numPartitions) { if (!locationManager.registered(shuffleId)) { @@ -777,20 +753,15 @@ private ConcurrentHashMap> registerShuffleInter if (StatusCode.SUCCESS.equals(respStatus)) { ConcurrentHashMap> result = JavaUtils.newConcurrentHashMap(); - for (int i = 0; i < response.getPartitionLocationsList().size(); i++) { - Tuple2, List> locations = - PbSerDeUtils.fromPbPackedPartitionLocationsPair( - response.getPackedPartitionLocationsPair()); - for (PartitionLocation location : locations._1) { - pushExcludedWorkers.remove(location.hostAndPushPort()); - if (location.hasPeer()) { - pushExcludedWorkers.remove(location.getPeer().hostAndPushPort()); - } - List list = - result.computeIfAbsent(location.getId(), x -> new ArrayList<>()); - list.add(location); - result.put(location.getId(), 
list); + Tuple2, List> locations = + PbSerDeUtils.fromPbPackedPartitionLocationsPair( + response.getPackedPartitionLocationsPair()); + for (PartitionLocation location : locations._1) { + pushExcludedWorkers.remove(location.hostAndPushPort()); + if (location.hasPeer()) { + pushExcludedWorkers.remove(location.getPeer().hostAndPushPort()); } + result.computeIfAbsent(location.getId(), x -> new ArrayList<>()).add(location); } return result; } else if (StatusCode.SLOT_NOT_AVAILABLE.equals(respStatus)) { @@ -928,6 +899,7 @@ Map reviveBatch( PbChangeLocationPartitionInfo partitionInfo = response.getPartitionInfo(i); int partitionId = partitionInfo.getPartitionId(); int statusCode = partitionInfo.getStatus(); + results.put(partitionId, statusCode); if (partitionInfo.getOldAvailable()) { PartitionLocation oldLoc = oldLocMap.get(partitionId); // Currently, revive only check if main location available, here won't remove peer loc. @@ -959,7 +931,6 @@ Map reviveBatch( logger.error("SHUFFLE_NOT_REGISTERED!"); return null; } - results.put(partitionId, statusCode); } return results; diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerReserveSlotsSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerReserveSlotsSuite.scala index ee9317a5e03..08e02bdf389 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerReserveSlotsSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerReserveSlotsSuite.scala @@ -81,17 +81,21 @@ class LifecycleManagerReserveSlotsSuite extends AnyFunSuite PARTITION_NUM) // find the worker that has at least 2 partitions - val partitionLocationMap1 = - shuffleClient1.getPartitionLocation(SHUFFLE_ID, MAP_NUM, PARTITION_NUM) + val locationManager1 = shuffleClient1.getLocationManager val worker2PartitionIds = mutable.Map.empty[WorkerInfo, ArrayBuffer[Int]] - for (partitionId <- 0 until PARTITION_NUM) { - val partitionLocation = partitionLocationMap1.get(partitionId) - assert(partitionLocation.getEpoch == 0) + (0 until PARTITION_NUM).foreach(partitionId => { + val partitionLocation = locationManager1.getLocationOrReviveAsync( + SHUFFLE_ID, + partitionId, + MAP_ID, + ATTEMPT_ID, + false, + false) worker2PartitionIds .getOrElseUpdate(partitionLocation.getWorker, ArrayBuffer.empty) .append(partitionId) - } + }) val partitions = worker2PartitionIds.values.filter(_.size >= 2).head assert(partitions.length >= 2) @@ -152,23 +156,47 @@ class LifecycleManagerReserveSlotsSuite extends AnyFunSuite } assert( - partitionLocationMap1.get(partitions(0)).getEpoch > 0 + locationManager1.getLocationOrReviveAsync( + SHUFFLE_ID, + 0, + MAP_ID, + ATTEMPT_ID, + false, + false).getEpoch > 0 ) // means partition(0) will be split // push merged data, we expect that partition(0) will be split, while partition(1) will not be split shuffleClient1.pushMergedData(SHUFFLE_ID, MAP_ID, ATTEMPT_ID) shuffleClient1.mapperEnd(SHUFFLE_ID, MAP_ID, ATTEMPT_ID, MAP_NUM) // partition(1) will not be split - assert(partitionLocationMap1.get(partitions(1)).getEpoch == 0) + assert(locationManager1.getLocationOrReviveAsync( + SHUFFLE_ID, + 1, + MAP_ID, + ATTEMPT_ID, + false, + false).getEpoch == 0) val shuffleClient2 = new ShuffleClientImpl(APP, clientConf, UserIdentifier("mock", "mock")) shuffleClient2.setupLifecycleManagerRef(lifecycleManager.self) - val partitionLocationMap2 = - shuffleClient2.getPartitionLocation(SHUFFLE_ID, MAP_NUM, PARTITION_NUM) + 
shuffleClient2.ensureRegistered(SHUFFLE_ID, MAP_NUM, PARTITION_NUM) + val locationManager2 = shuffleClient2.getLocationManager // lifecycleManager response with the latest epoch(epoch of partition(0) is larger than 0 caused by split) - assert(partitionLocationMap2.get(partitions(0)).getEpoch > 0) + assert(locationManager2.getLocationOrReviveAsync( + SHUFFLE_ID, + 0, + MAP_ID, + ATTEMPT_ID, + false, + false).getEpoch > 0) // epoch of partition(1) is 0 without split - assert(partitionLocationMap2.get(partitions(1)).getEpoch == 0) + assert(locationManager2.getLocationOrReviveAsync( + SHUFFLE_ID, + 1, + MAP_ID, + ATTEMPT_ID, + false, + false).getEpoch == 0) } } diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/DynamicallySplitPartitionSuite.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/DynamicallySplitPartitionSuite.scala index c9ca9fb0074..f566974bea1 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/DynamicallySplitPartitionSuite.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/DynamicallySplitPartitionSuite.scala @@ -32,18 +32,14 @@ import org.apache.celeborn.service.deploy.MiniClusterFeature class DynamicallySplitPartitionSuite extends AnyFunSuite with Logging with MiniClusterFeature with BeforeAndAfterAll { - val masterPort = 19099 + var masterEndpoint = "" override def beforeAll(): Unit = { - val masterConf = Map( - "celeborn.master.host" -> "localhost", - "celeborn.master.port" -> masterPort.toString) - val workerConf = Map( - "celeborn.master.endpoints" -> s"localhost:$masterPort", - "celeborn.worker.flusher.buffer.size" -> "0") + val conf = Map("celeborn.worker.flusher.buffer.size" -> "0") logInfo("test initialized , setup Celeborn mini cluster") - setupMiniClusterWithRandomPorts(masterConf, workerConf, 5) + val (master, _) = setupMiniClusterWithRandomPorts(conf, conf, 5) + masterEndpoint = master.conf.get(CelebornConf.MASTER_ENDPOINTS.key) } class PushTask(shuffleClient: ShuffleClientImpl, shuffleId: Int, mapId: Int, mapNum: Int) @@ -69,7 +65,7 @@ class DynamicallySplitPartitionSuite extends AnyFunSuite test("dynamically split partition") { val appId = s"dynamically-split-partition-test-${System.currentTimeMillis()}" val clientConf = new CelebornConf() - .set(CelebornConf.MASTER_ENDPOINTS.key, s"localhost:$masterPort") + .set(CelebornConf.MASTER_ENDPOINTS.key, masterEndpoint) .set(CelebornConf.SHUFFLE_PARTITION_SPLIT_THRESHOLD.key, "10M") .set(CelebornConf.SHUFFLE_PARTITION_SPLIT_MODE.key, "HARD") .set(CelebornConf.CLIENT_ASYNC_SPLIT_PARTITION_ENABLED.key, "true") diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/PushMergedDataSplitSuite.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/PushMergedDataSplitSuite.scala index cdd0e758e6f..767c4930765 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/PushMergedDataSplitSuite.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/PushMergedDataSplitSuite.scala @@ -83,15 +83,20 @@ class PushMergedDataSplitSuite extends AnyFunSuite PARTITION_NUM) // find the worker that has at least 2 partitions - val partitionLocationMap = - shuffleClient.getPartitionLocation(SHUFFLE_ID, MAP_NUM, PARTITION_NUM) + val locationManager = shuffleClient.getLocationManager val worker2PartitionIds = mutable.Map.empty[WorkerInfo, ArrayBuffer[Int]] - for (partitionId <- 0 until PARTITION_NUM) { - val partitionLocation = partitionLocationMap.get(partitionId) 
+ (0 until PARTITION_NUM).foreach(partitionId => { + val partitionLocation = locationManager.getLocationOrReviveAsync( + SHUFFLE_ID, + partitionId, + MAP_ID, + ATTEMPT_ID, + false, + false) worker2PartitionIds .getOrElseUpdate(partitionLocation.getWorker, ArrayBuffer.empty) .append(partitionId) - } + }) val partitions = worker2PartitionIds.values.filter(_.size >= 2).head assert(partitions.length >= 2) @@ -147,16 +152,25 @@ class PushMergedDataSplitSuite extends AnyFunSuite PARTITION_NUM) Thread.sleep(5 * 1000) // wait for flush } - assert( - partitionLocationMap.get(partitions(0)).getEpoch > 0 - ) // means partition(0) will be split + assert(locationManager.getLocationOrReviveAsync( + SHUFFLE_ID, + 0, + MAP_ID, + ATTEMPT_ID, + false, + false).getEpoch > 0) // means partition(0) will be split // push merged data, we expect that partition(0) will be split, while partition(1) will not be split shuffleClient.pushMergedData(SHUFFLE_ID, MAP_ID, ATTEMPT_ID) shuffleClient.mapperEnd(SHUFFLE_ID, MAP_ID, ATTEMPT_ID, MAP_NUM) - assert( - partitionLocationMap.get(partitions(1)).getEpoch == 0 - ) // means partition(1) will not be split + assert(locationManager.getLocationOrReviveAsync( + SHUFFLE_ID, + 1, + MAP_ID, + ATTEMPT_ID, + false, + false).getEpoch == 0) // means partition(1) will not be split } } }
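---

A note on the reserve-count policy the new suites pin down: the assertions in PartitionLocationMonitorSuite imply a simple growth rule for nextReserveSlotCount on the soft-split path. The sketch below is a minimal, self-contained model inferred from those assertions only; the names (ReserveCountModel, addInitialLocation, reserve, retireAll) are illustrative and are not Celeborn APIs, and the real PartitionLocationMonitor additionally weighs the sliding time window and the expected worker speed (see the "slowly reserve slots" case), which this model deliberately ignores.

// Model of the growth policy implied by "quickly reserve slots": once a
// partition has no writable location left, ask for one slot more than the
// previous reservation, capped by the per-partition maximum.
class ReserveCountModel(maxActiveLocations: Int) {
  private var active = 0      // currently writable locations
  private var lastReserve = 0 // size of the previous reservation batch

  // The location handed out at registration time; not a revive reservation.
  def addInitialLocation(): Unit = active += 1

  def reserve(n: Int): Unit = { active += n; lastReserve = n }

  // A SOFT_SPLIT retires a location; this model only tracks the count.
  def retireAll(): Unit = active = 0

  def activeLocationCount: Int = active

  def nextReserveSlotCount: Int =
    if (active > 0) 0 else math.min(lastReserve + 1, maxActiveLocations)
}

object ReserveCountModelCheck {
  def main(args: Array[String]): Unit = {
    val m = new ReserveCountModel(maxActiveLocations = 10)
    m.addInitialLocation()              // start with 1 slot
    m.retireAll()                       // the slot soft-splits
    assert(m.nextReserveSlotCount == 1) // matches the suite: 1
    m.reserve(1); m.retireAll()
    assert(m.nextReserveSlotCount == 2) // matches the suite: 2
    m.reserve(2); m.retireAll()
    assert(m.nextReserveSlotCount == 3) // matches the suite: 3

    val capped = new ReserveCountModel(maxActiveLocations = 3)
    capped.reserve(2); capped.retireAll()
    assert(capped.nextReserveSlotCount == 3) // min(2 + 1, 3)
  }
}

Under this model the "quickly reserve slots" sequence reproduces 1, 2, 3, and the cap reproduces the "max active location limit" case; the hard-split and connection-failure cases need the StatusCode-aware bookkeeping that only the real monitor has.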