Skip to content

Commit dac1065

Browse files
authored
[#2411] fix(spark): Spill memory corresponding to successfully sent blocks (#2415)
### What changes were proposed in this pull request? As title ### Why are the changes needed? Before this PR, the Spark client spilled more memory than it actually released, because the spilled size was computed from all blocks in an event rather than only the blocks that were successfully sent. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? UT
1 parent 3874f15 commit dac1065

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,7 @@ public CompletableFuture<Long> send(AddBlockEvent event) {
100100
putFailedBlockSendTracker(
101101
taskToFailedBlockSendTracker, taskId, result.getFailedBlockSendTracker());
102102
} finally {
103-
Set<Long> succeedBlockIds =
104-
result.getSuccessBlockIds() == null
105-
? Collections.emptySet()
106-
: result.getSuccessBlockIds();
103+
Set<Long> succeedBlockIds = getSucceedBlockIds(result);
107104
for (ShuffleBlockInfo block : shuffleBlockInfoList) {
108105
block.executeCompletionCallback(succeedBlockIds.contains(block.getBlockId()));
109106
}
@@ -114,7 +111,9 @@ public CompletableFuture<Long> send(AddBlockEvent event) {
114111
runnable.run();
115112
}
116113
}
114+
Set<Long> succeedBlockIds = getSucceedBlockIds(result);
117115
return shuffleBlockInfoList.stream()
116+
.filter(x -> succeedBlockIds.contains(x.getBlockId()))
118117
.map(x -> x.getFreeMemory())
119118
.reduce((a, b) -> a + b)
120119
.get();
@@ -127,6 +126,13 @@ public CompletableFuture<Long> send(AddBlockEvent event) {
127126
});
128127
}
129128

129+
private Set<Long> getSucceedBlockIds(SendShuffleDataResult result) {
130+
if (result == null || result.getSuccessBlockIds() == null) {
131+
return Collections.emptySet();
132+
}
133+
return result.getSuccessBlockIds();
134+
}
135+
130136
private synchronized void putBlockId(
131137
Map<String, Set<Long>> taskToBlockIds, String taskAttemptId, Set<Long> blockIds) {
132138
if (blockIds == null || blockIds.isEmpty()) {

client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,12 @@ public void spillPartial() {
487487
List<AddBlockEvent> events = wbm.buildBlockEvents(blocks);
488488
for (AddBlockEvent event : events) {
489489
event.getProcessedCallbackChain().stream().forEach(x -> x.run());
490-
sum += event.getShuffleDataInfoList().stream().mapToLong(x -> x.getFreeMemory()).sum();
490+
// simulate: the block for partition 2 send failed
491+
sum +=
492+
event.getShuffleDataInfoList().stream()
493+
.filter(x -> x.getPartitionId() <= 1)
494+
.mapToLong(x -> x.getFreeMemory())
495+
.sum();
491496
}
492497
return Arrays.asList(CompletableFuture.completedFuture(sum));
493498
};
@@ -502,10 +507,11 @@ public void spillPartial() {
502507
wbm.addRecord(1, testKey, testValue);
503508
wbm.addRecord(1, testKey, testValue);
504509
wbm.addRecord(1, testKey, testValue);
510+
wbm.addRecord(2, testKey, testValue);
505511

506512
long releasedSize = wbm.spill(1000, wbm);
507513
assertEquals(64, releasedSize);
508-
assertEquals(96, wbm.getUsedBytes());
514+
assertEquals(128, wbm.getUsedBytes());
509515
assertEquals(0, wbm.getBuffers().keySet().toArray()[0]);
510516
}
511517

0 commit comments

Comments
 (0)