Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -169,21 +169,28 @@ public long getInitialSize() {
* Free the memory used by pointer array.
*/
public void freeMemory() {
if (consumer != null) {
if (array != null) {
consumer.freeArray(array);
LongArray arrayToFree = null;
try {
synchronized (this) {
arrayToFree = array;

// Set the array to null instead of allocating a new array. Allocating an array could have
// triggered another spill and this method already is called from UnsafeExternalSorter when
// spilling. Attempting to allocate while spilling is dangerous, as we could be holding onto
// a large partially complete allocation, which may prevent other memory from being
// allocated.
// Instead we will allocate the new array when it is necessary.
array = null;
usableCapacity = 0;

pos = 0;
nullBoundaryPos = 0;
}
} finally {
if (consumer != null && arrayToFree != null) {
consumer.freeArray(arrayToFree);
}

// Set the array to null instead of allocating a new array. Allocating an array could have
// triggered another spill and this method already is called from UnsafeExternalSorter when
// spilling. Attempting to allocate while spilling is dangerous, as we could be holding onto
// a large partially complete allocation, which may prevent other memory from being allocated.
// Instead we will allocate the new array when it is necessary.
array = null;
usableCapacity = 0;
}
pos = 0;
nullBoundaryPos = 0;
}

/**
Expand Down Expand Up @@ -227,8 +234,10 @@ public void expandPointerArray(LongArray newArray) {
pos * 8L);
consumer.freeArray(array);
}
array = newArray;
usableCapacity = getUsableCapacity();
synchronized (this) {
array = newArray;
usableCapacity = getUsableCapacity();
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.concurrent.CountDownLatch;

import org.apache.spark.unsafe.array.LongArray;
import org.junit.jupiter.api.Assertions;
Expand Down Expand Up @@ -191,4 +192,46 @@ public int compare(
assertEquals(0L, memoryManager.cleanUpAllAllocatedMemory());
}

@Test
public void testThreadSafety() throws InterruptedException {
final TestMemoryManager memoryManager =new TestMemoryManager(
new SparkConf().set(package$.MODULE$.MEMORY_OFFHEAP_ENABLED(), false));
final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
final TestMemoryConsumer consumer = new TestMemoryConsumer(taskMemoryManager);
UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer,
taskMemoryManager,
mock(RecordComparator.class),
mock(PrefixComparator.class),
100,
shouldUseRadixSort());

int timeout = 10000;
long start = System.currentTimeMillis();
while (true) {
sorter.freeMemory();
CountDownLatch downLatch = new CountDownLatch(1);
Thread thread1 = new Thread(() -> {
sorter.expandPointerArray(consumer.allocateArray(2000));
downLatch.countDown();
});
Thread thread2 = new Thread(() -> {
sorter.freeMemory();
try {
downLatch.await();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
sorter.freeMemory();
});
thread1.start();
thread2.start();
thread1.join();
thread2.join();
Assertions.assertEquals(0, memoryManager.getExecutionMemoryUsageForTask(0));
if (System.currentTimeMillis() - start > timeout) {
break;
}
}
}

}