diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index c7504f3079..a6393156ec 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -35,22 +35,25 @@ import org.apache.comet.Tracing.withTrace import org.apache.comet.vector.NativeUtil /** - * An iterator class used to execute Comet native query. It takes an input iterator which comes - * from Comet Scan and is expected to produce batches of Arrow Arrays. During consuming this - * iterator, it will consume input iterator and pass Arrow Arrays to Comet native engine by - * addresses. Even after the end of input iterator, this iterator still possibly continues - * executing native query as there might be blocking operators such as Sort, Aggregate. The API - * `hasNext` can be used to check if it is the end of this iterator (i.e. the native query is - * done). + * Iterator for calling native `executePlan` and importing the resulting arrays via Arrow FFI. * + * @param id + * Unique identifier for this execution context (used for native plan tracking) * @param inputs - * The input iterators producing sequence of batches of Arrow Arrays. + * Sequence of input ColumnarBatch iterators from upstream Spark operators + * @param numOutputCols + * Number of columns in the output schema * @param protobufQueryPlan - * The serialized bytes of Spark execution plan. + * Serialized Spark physical plan as protocol buffer bytes + * @param nativeMetrics + * Metrics collection node for native execution statistics * @param numParts - * The number of partitions. + * Total number of partitions in the query * @param partitionIndex - * The index of the partition. 
+ * Zero-based index of the partition this iterator processes + * + * @see + * [[org.apache.comet.vector.NativeUtil]] for Arrow array import/export utilities */ class CometExecIterator( val id: Long, @@ -109,31 +112,60 @@ class CometExecIterator( private var currentBatch: ColumnarBatch = null private var closed: Boolean = false - private def getMemoryLimitPerTask(conf: SparkConf): Long = { - val numCores = numDriverOrExecutorCores(conf).toFloat - val maxMemory = CometSparkSessionExtensions.getCometMemoryOverhead(conf) - val coresPerTask = conf.get("spark.task.cpus", "1").toFloat - // example 16GB maxMemory * 16 cores with 4 cores per task results - // in memory_limit_per_task = 16 GB * 4 / 16 = 16 GB / 4 = 4GB - val limit = (maxMemory.toFloat * coresPerTask / numCores).toLong - logInfo( - s"Calculated per-task memory limit of $limit ($maxMemory * $coresPerTask / $numCores)") - limit + /** + * Checks if there are more batches available from the native execution engine. + * + * @return + * true if more batches are available, false if execution is complete + */ + override def hasNext: Boolean = { + if (closed) return false + + if (nextBatch.isDefined) { + return true + } + + // Close previous batch if any. + // This is to guarantee safety at the native side before we overwrite the buffer memory + // shared across batches in the native side. + if (prevBatch != null) { + prevBatch.close() + prevBatch = null + } + + nextBatch = getNextBatch + + if (nextBatch.isEmpty) { + close() + false + } else { + true + } } - private def numDriverOrExecutorCores(conf: SparkConf): Int = { - def convertToInt(threads: String): Int = { - if (threads == "*") Runtime.getRuntime.availableProcessors() else threads.toInt + /** + * Returns the next ColumnarBatch from native execution. 
+ * + * @return + * ColumnarBatch containing Arrow arrays with transferred ownership + * @throws NoSuchElementException + * if no more elements are available (call hasNext first) + */ + override def next(): ColumnarBatch = { + if (currentBatch != null) { + // Eagerly release Arrow Arrays in the previous batch + currentBatch.close() + currentBatch = null } - val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r - val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r - val master = conf.get("spark.master") - master match { - case "local" => 1 - case LOCAL_N_REGEX(threads) => convertToInt(threads) - case LOCAL_N_FAILURES_REGEX(threads, _) => convertToInt(threads) - case _ => conf.get("spark.executor.cores", "1").toInt + + if (nextBatch.isEmpty && !hasNext) { + throw new NoSuchElementException("No more element") } + + currentBatch = nextBatch.get + prevBatch = currentBatch + nextBatch = None + currentBatch } private def getNextBatch: Option[ColumnarBatch] = { @@ -185,48 +217,12 @@ class CometExecIterator( } } - override def hasNext: Boolean = { - if (closed) return false - - if (nextBatch.isDefined) { - return true - } - - // Close previous batch if any. - // This is to guarantee safety at the native side before we overwrite the buffer memory - // shared across batches in the native side. - if (prevBatch != null) { - prevBatch.close() - prevBatch = null - } - - nextBatch = getNextBatch - - if (nextBatch.isEmpty) { - close() - false - } else { - true - } - } - - override def next(): ColumnarBatch = { - if (currentBatch != null) { - // Eagerly release Arrow Arrays in the previous batch - currentBatch.close() - currentBatch = null - } - - if (nextBatch.isEmpty && !hasNext) { - throw new NoSuchElementException("No more element") - } - - currentBatch = nextBatch.get - prevBatch = currentBatch - nextBatch = None - currentBatch - } - + /** + * Releases all resources associated with this iterator including native memory and JNI handles. 
+ * + * Note that this method can be called both from CometExecIterator as well as from other Spark + * threads, so it needs to be a synchronized method. + */ def close(): Unit = synchronized { if (!closed) { if (currentBatch != null) { @@ -240,29 +236,37 @@ class CometExecIterator( traceMemoryUsage() } - // The allocator thoughts the exported ArrowArray and ArrowSchema structs are not released, - // so it will report: - // Caused by: java.lang.IllegalStateException: Memory was leaked by query. - // Memory leaked: (516) Allocator(ROOT) 0/516/808/9223372036854775807 (res/actual/peak/limit) - // Suspect this seems a false positive leak, because there is no reported memory leak at JVM - // when profiling. `allocator` reports a leak because it calculates the accumulated number - // of memory allocated for ArrowArray and ArrowSchema. But these exported ones will be - // released in native side later. - // More to clarify it. For ArrowArray and ArrowSchema, Arrow will put a release field into the - // memory region which is a callback function pointer (C function) that could be called to - // release these structs in native code too. Once we wrap their memory addresses at native - // side using FFI ArrowArray and ArrowSchema, and drop them later, the callback function will - // be called to release the memory. - // But at JVM, the allocator doesn't know about this fact so it still keeps the accumulated - // number. - // Tried to manually do `release` and `close` that can make the allocator happy, but it will - // cause JVM runtime failure. 
- - // allocator.close() closed = true } } + private def getMemoryLimitPerTask(conf: SparkConf): Long = { + val numCores = numDriverOrExecutorCores(conf).toFloat + val maxMemory = CometSparkSessionExtensions.getCometMemoryOverhead(conf) + val coresPerTask = conf.get("spark.task.cpus", "1").toFloat + // example 16GB maxMemory * 16 cores with 4 cores per task results + // in memory_limit_per_task = 16 GB * 4 / 16 = 16 GB / 4 = 4GB + val limit = (maxMemory.toFloat * coresPerTask / numCores).toLong + logInfo( + s"Calculated per-task memory limit of $limit ($maxMemory * $coresPerTask / $numCores)") + limit + } + + private def numDriverOrExecutorCores(conf: SparkConf): Int = { + def convertToInt(threads: String): Int = { + if (threads == "*") Runtime.getRuntime.availableProcessors() else threads.toInt + } + val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r + val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r + val master = conf.get("spark.master") + master match { + case "local" => 1 + case LOCAL_N_REGEX(threads) => convertToInt(threads) + case LOCAL_N_FAILURES_REGEX(threads, _) => convertToInt(threads) + case _ => conf.get("spark.executor.cores", "1").toInt + } + } + private def traceMemoryUsage(): Unit = { nativeLib.logMemoryUsage("jvm_heapUsed", memoryMXBean.getHeapMemoryUsage.getUsed) val totalTaskMemory = cometTaskMemoryManager.internal.getMemoryConsumptionForThisTask