diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 65aa7c815fc42..c25c24cf61139 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -169,21 +169,28 @@ public long getInitialSize() { * Free the memory used by pointer array. */ public void freeMemory() { - if (consumer != null) { - if (array != null) { - consumer.freeArray(array); + LongArray arrayToFree = null; + try { + synchronized (this) { + arrayToFree = array; + + // Set the array to null instead of allocating a new array. Allocating an array could have + // triggered another spill and this method already is called from UnsafeExternalSorter when + // spilling. Attempting to allocate while spilling is dangerous, as we could be holding onto + // a large partially complete allocation, which may prevent other memory from being + // allocated. + // Instead we will allocate the new array when it is necessary. + array = null; + usableCapacity = 0; + + pos = 0; + nullBoundaryPos = 0; + } + } finally { + if (consumer != null && arrayToFree != null) { + consumer.freeArray(arrayToFree); } - - // Set the array to null instead of allocating a new array. Allocating an array could have - // triggered another spill and this method already is called from UnsafeExternalSorter when - // spilling. Attempting to allocate while spilling is dangerous, as we could be holding onto - // a large partially complete allocation, which may prevent other memory from being allocated. - // Instead we will allocate the new array when it is necessary. - array = null; - usableCapacity = 0; } - pos = 0; - nullBoundaryPos = 0; } /** @@ -227,8 +234,10 @@ public void expandPointerArray(LongArray newArray) { pos * 8L); consumer.freeArray(array); } - array = newArray; - usableCapacity = getUsableCapacity(); + synchronized (this) { + array = newArray; + usableCapacity = getUsableCapacity(); + } } /** diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index a612824fed498..2cfd9e1676e3d 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -19,6 +19,7 @@ import java.nio.charset.StandardCharsets; import java.util.Arrays; +import java.util.concurrent.CountDownLatch; import org.apache.spark.unsafe.array.LongArray; import org.junit.jupiter.api.Assertions; @@ -191,4 +192,46 @@ public int compare( assertEquals(0L, memoryManager.cleanUpAllAllocatedMemory()); } + @Test + public void testThreadSafety() throws InterruptedException { + final TestMemoryManager memoryManager =new TestMemoryManager( + new SparkConf().set(package$.MODULE$.MEMORY_OFFHEAP_ENABLED(), false)); + final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0); + final TestMemoryConsumer consumer = new TestMemoryConsumer(taskMemoryManager); + UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, + taskMemoryManager, + mock(RecordComparator.class), + mock(PrefixComparator.class), + 100, + shouldUseRadixSort()); + + int timeout = 10000; + long start = System.currentTimeMillis(); + while (true) { + sorter.freeMemory(); + CountDownLatch downLatch = new CountDownLatch(1); + Thread thread1 = new Thread(() -> { + sorter.expandPointerArray(consumer.allocateArray(2000)); + downLatch.countDown(); + }); + Thread thread2 = new Thread(() -> { + sorter.freeMemory(); + try { + downLatch.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + sorter.freeMemory(); + }); + thread1.start(); + thread2.start(); + thread1.join(); + thread2.join(); + Assertions.assertEquals(0, memoryManager.getExecutionMemoryUsageForTask(0)); + if (System.currentTimeMillis() - start > timeout) { + break; + } + } + } + }