Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 97 additions & 93 deletions spark/src/main/scala/org/apache/comet/CometExecIterator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,25 @@ import org.apache.comet.Tracing.withTrace
import org.apache.comet.vector.NativeUtil

/**
* An iterator class used to execute Comet native query. It takes an input iterator which comes
* from Comet Scan and is expected to produce batches of Arrow Arrays. During consuming this
* iterator, it will consume input iterator and pass Arrow Arrays to Comet native engine by
* addresses. Even after the end of input iterator, this iterator still possibly continues
* executing native query as there might be blocking operators such as Sort, Aggregate. The API
* `hasNext` can be used to check if it is the end of this iterator (i.e. the native query is
* done).
* Iterator for calling native `executePlan` and importing the resulting arrays via Arrow FFI.
*
* @param id
* Unique identifier for this execution context (used for native plan tracking)
* @param inputs
* The input iterators producing sequence of batches of Arrow Arrays.
* Sequence of input ColumnarBatch iterators from upstream Spark operators
* @param numOutputCols
* Number of columns in the output schema
* @param protobufQueryPlan
* The serialized bytes of Spark execution plan.
* Serialized Spark physical plan as protocol buffer bytes
* @param nativeMetrics
* Metrics collection node for native execution statistics
* @param numParts
* The number of partitions.
* Total number of partitions in the query
* @param partitionIndex
* The index of the partition.
* Zero-based index of the partition this iterator processes
*
* @see
* [[org.apache.comet.vector.NativeUtil]] for Arrow array import/export utilities
*/
class CometExecIterator(
val id: Long,
Expand Down Expand Up @@ -109,31 +112,60 @@ class CometExecIterator(
private var currentBatch: ColumnarBatch = null
private var closed: Boolean = false

private def getMemoryLimitPerTask(conf: SparkConf): Long = {
val numCores = numDriverOrExecutorCores(conf).toFloat
val maxMemory = CometSparkSessionExtensions.getCometMemoryOverhead(conf)
val coresPerTask = conf.get("spark.task.cpus", "1").toFloat
// example 16GB maxMemory * 16 cores with 4 cores per task results
// in memory_limit_per_task = 16 GB * 4 / 16 = 16 GB / 4 = 4GB
val limit = (maxMemory.toFloat * coresPerTask / numCores).toLong
logInfo(
s"Calculated per-task memory limit of $limit ($maxMemory * $coresPerTask / $numCores)")
limit
/**
* Checks if there are more batches available from the native execution engine.
*
* @return
* true if more batches are available, false if execution is complete
*/
override def hasNext: Boolean = {
if (closed) return false

if (nextBatch.isDefined) {
return true
}

// Close previous batch if any.
// This is to guarantee safety at the native side before we overwrite the buffer memory
// shared across batches in the native side.
if (prevBatch != null) {
prevBatch.close()
prevBatch = null
}

nextBatch = getNextBatch

if (nextBatch.isEmpty) {
close()
false
} else {
true
}
}

private def numDriverOrExecutorCores(conf: SparkConf): Int = {
def convertToInt(threads: String): Int = {
if (threads == "*") Runtime.getRuntime.availableProcessors() else threads.toInt
/**
* Returns the next ColumnarBatch from native execution.
*
* @return
* ColumnarBatch containing Arrow arrays with transferred ownership
* @throws NoSuchElementException
* if no more elements are available (call hasNext first)
*/
override def next(): ColumnarBatch = {
if (currentBatch != null) {
// Eagerly release Arrow Arrays in the previous batch
currentBatch.close()
currentBatch = null
}
val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r
val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r
val master = conf.get("spark.master")
master match {
case "local" => 1
case LOCAL_N_REGEX(threads) => convertToInt(threads)
case LOCAL_N_FAILURES_REGEX(threads, _) => convertToInt(threads)
case _ => conf.get("spark.executor.cores", "1").toInt

if (nextBatch.isEmpty && !hasNext) {
throw new NoSuchElementException("No more element")
}

currentBatch = nextBatch.get
prevBatch = currentBatch
nextBatch = None
currentBatch
}

private def getNextBatch: Option[ColumnarBatch] = {
Expand Down Expand Up @@ -185,48 +217,12 @@ class CometExecIterator(
}
}

override def hasNext: Boolean = {
if (closed) return false

if (nextBatch.isDefined) {
return true
}

// Close previous batch if any.
// This is to guarantee safety at the native side before we overwrite the buffer memory
// shared across batches in the native side.
if (prevBatch != null) {
prevBatch.close()
prevBatch = null
}

nextBatch = getNextBatch

if (nextBatch.isEmpty) {
close()
false
} else {
true
}
}

override def next(): ColumnarBatch = {
if (currentBatch != null) {
// Eagerly release Arrow Arrays in the previous batch
currentBatch.close()
currentBatch = null
}

if (nextBatch.isEmpty && !hasNext) {
throw new NoSuchElementException("No more element")
}

currentBatch = nextBatch.get
prevBatch = currentBatch
nextBatch = None
currentBatch
}

/**
* Releases all resources associated with this iterator including native memory and JNI handles.
*
* Note that this method can be called both from CometExecIterotor as well as from other Spark
* threads, so needs be a synchronized method.
*/
def close(): Unit = synchronized {
if (!closed) {
if (currentBatch != null) {
Expand All @@ -240,29 +236,37 @@ class CometExecIterator(
traceMemoryUsage()
}

// The allocator thoughts the exported ArrowArray and ArrowSchema structs are not released,
// so it will report:
// Caused by: java.lang.IllegalStateException: Memory was leaked by query.
// Memory leaked: (516) Allocator(ROOT) 0/516/808/9223372036854775807 (res/actual/peak/limit)
// Suspect this seems a false positive leak, because there is no reported memory leak at JVM
// when profiling. `allocator` reports a leak because it calculates the accumulated number
// of memory allocated for ArrowArray and ArrowSchema. But these exported ones will be
// released in native side later.
// More to clarify it. For ArrowArray and ArrowSchema, Arrow will put a release field into the
// memory region which is a callback function pointer (C function) that could be called to
// release these structs in native code too. Once we wrap their memory addresses at native
// side using FFI ArrowArray and ArrowSchema, and drop them later, the callback function will
// be called to release the memory.
// But at JVM, the allocator doesn't know about this fact so it still keeps the accumulated
// number.
// Tried to manually do `release` and `close` that can make the allocator happy, but it will
// cause JVM runtime failure.

// allocator.close()
Comment on lines -243 to -261
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment refers to an allocator that no longer exists so it doesn't seem very useful to keep around

closed = true
}
}

private def getMemoryLimitPerTask(conf: SparkConf): Long = {
val numCores = numDriverOrExecutorCores(conf).toFloat
val maxMemory = CometSparkSessionExtensions.getCometMemoryOverhead(conf)
val coresPerTask = conf.get("spark.task.cpus", "1").toFloat
// example 16GB maxMemory * 16 cores with 4 cores per task results
// in memory_limit_per_task = 16 GB * 4 / 16 = 16 GB / 4 = 4GB
val limit = (maxMemory.toFloat * coresPerTask / numCores).toLong
logInfo(
s"Calculated per-task memory limit of $limit ($maxMemory * $coresPerTask / $numCores)")
limit
}

private def numDriverOrExecutorCores(conf: SparkConf): Int = {
def convertToInt(threads: String): Int = {
if (threads == "*") Runtime.getRuntime.availableProcessors() else threads.toInt
}
val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r
val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r
val master = conf.get("spark.master")
master match {
case "local" => 1
case LOCAL_N_REGEX(threads) => convertToInt(threads)
case LOCAL_N_FAILURES_REGEX(threads, _) => convertToInt(threads)
case _ => conf.get("spark.executor.cores", "1").toInt
}
}

private def traceMemoryUsage(): Unit = {
nativeLib.logMemoryUsage("jvm_heapUsed", memoryMXBean.getHeapMemoryUsage.getUsed)
val totalTaskMemory = cometTaskMemoryManager.internal.getMemoryConsumptionForThisTask
Expand Down
Loading