1717
1818//! Define JNI APIs which can be called from Java/Scala.
1919
20- use super :: { serde, utils:: SparkArrowConvert , CometMemoryPool } ;
20+ use super :: { serde, utils:: SparkArrowConvert } ;
2121use arrow:: array:: RecordBatch ;
2222use arrow:: datatypes:: DataType as ArrowDataType ;
23- use datafusion:: execution:: memory_pool:: {
24- FairSpillPool , GreedyMemoryPool , MemoryPool , TrackConsumersPool , UnboundedMemoryPool ,
25- } ;
23+ use datafusion:: execution:: memory_pool:: MemoryPool ;
2624use datafusion:: {
2725 execution:: { disk_manager:: DiskManagerConfig , runtime_env:: RuntimeEnv } ,
2826 physical_plan:: { display:: DisplayableExecutionPlan , SendableRecordBatchStream } ,
@@ -40,7 +38,7 @@ use jni::{
4038} ;
4139use std:: path:: PathBuf ;
4240use std:: time:: { Duration , Instant } ;
43- use std:: { collections :: HashMap , sync:: Arc , task:: Poll } ;
41+ use std:: { sync:: Arc , task:: Poll } ;
4442
4543use crate :: {
4644 errors:: { try_unwrap_or_throw, CometError , CometResult } ,
@@ -60,16 +58,16 @@ use jni::{
6058 objects:: GlobalRef ,
6159 sys:: { jboolean, jdouble, jintArray, jobjectArray, jstring} ,
6260} ;
63- use std:: num:: NonZeroUsize ;
64- use std:: sync:: Mutex ;
6561use tokio:: runtime:: Runtime ;
6662
67- use crate :: execution:: fair_memory_pool:: CometFairMemoryPool ;
63+ use crate :: execution:: memory_pools:: {
64+ create_memory_pool, handle_task_shared_pool_release, parse_memory_pool_config, MemoryPoolConfig ,
65+ } ;
6866use crate :: execution:: operators:: ScanExec ;
6967use crate :: execution:: shuffle:: { read_ipc_compressed, CompressionCodec } ;
7068use crate :: execution:: spark_plan:: SparkPlan ;
7169use log:: info;
72- use once_cell:: sync:: { Lazy , OnceCell } ;
70+ use once_cell:: sync:: Lazy ;
7371
7472static TOKIO_RUNTIME : Lazy < Runtime > = Lazy :: new ( || {
7573 let mut builder = tokio:: runtime:: Builder :: new_multi_thread ( ) ;
@@ -130,51 +128,6 @@ struct ExecutionContext {
130128 pub memory_pool_config : MemoryPoolConfig ,
131129}
132130
133- #[ derive( PartialEq , Eq ) ]
134- enum MemoryPoolType {
135- Unified ,
136- FairUnified ,
137- Greedy ,
138- FairSpill ,
139- GreedyTaskShared ,
140- FairSpillTaskShared ,
141- GreedyGlobal ,
142- FairSpillGlobal ,
143- Unbounded ,
144- }
145-
146- struct MemoryPoolConfig {
147- pool_type : MemoryPoolType ,
148- pool_size : usize ,
149- }
150-
151- impl MemoryPoolConfig {
152- fn new ( pool_type : MemoryPoolType , pool_size : usize ) -> Self {
153- Self {
154- pool_type,
155- pool_size,
156- }
157- }
158- }
159-
160- /// The per-task memory pools keyed by task attempt id.
161- static TASK_SHARED_MEMORY_POOLS : Lazy < Mutex < HashMap < i64 , PerTaskMemoryPool > > > =
162- Lazy :: new ( || Mutex :: new ( HashMap :: new ( ) ) ) ;
163-
164- struct PerTaskMemoryPool {
165- memory_pool : Arc < dyn MemoryPool > ,
166- num_plans : usize ,
167- }
168-
169- impl PerTaskMemoryPool {
170- fn new ( memory_pool : Arc < dyn MemoryPool > ) -> Self {
171- Self {
172- memory_pool,
173- num_plans : 0 ,
174- }
175- }
176- }
177-
178131/// Accept serialized query plan and return the address of the native query plan.
179132/// # Safety
180133/// This function is inherently unsafe since it deals with raw pointers passed from JNI.
@@ -321,134 +274,6 @@ fn prepare_datafusion_session_context(
321274 Ok ( session_ctx)
322275}
323276
324- fn parse_memory_pool_config (
325- off_heap_mode : bool ,
326- memory_pool_type : String ,
327- memory_limit : i64 ,
328- memory_limit_per_task : i64 ,
329- ) -> CometResult < MemoryPoolConfig > {
330- let pool_size = memory_limit as usize ;
331- let memory_pool_config = if off_heap_mode {
332- match memory_pool_type. as_str ( ) {
333- "fair_unified" => MemoryPoolConfig :: new ( MemoryPoolType :: FairUnified , pool_size) ,
334- "default" | "unified" => {
335- // the `unified` memory pool interacts with Spark's memory pool to allocate
336- // memory therefore does not need a size to be explicitly set. The pool size
337- // shared with Spark is set by `spark.memory.offHeap.size`.
338- MemoryPoolConfig :: new ( MemoryPoolType :: Unified , 0 )
339- }
340- _ => {
341- return Err ( CometError :: Config ( format ! (
342- "Unsupported memory pool type for off-heap mode: {}" ,
343- memory_pool_type
344- ) ) )
345- }
346- }
347- } else {
348- // Use the memory pool from DF
349- let pool_size_per_task = memory_limit_per_task as usize ;
350- match memory_pool_type. as_str ( ) {
351- "fair_spill_task_shared" => {
352- MemoryPoolConfig :: new ( MemoryPoolType :: FairSpillTaskShared , pool_size_per_task)
353- }
354- "default" | "greedy_task_shared" => {
355- MemoryPoolConfig :: new ( MemoryPoolType :: GreedyTaskShared , pool_size_per_task)
356- }
357- "fair_spill_global" => {
358- MemoryPoolConfig :: new ( MemoryPoolType :: FairSpillGlobal , pool_size)
359- }
360- "greedy_global" => MemoryPoolConfig :: new ( MemoryPoolType :: GreedyGlobal , pool_size) ,
361- "fair_spill" => MemoryPoolConfig :: new ( MemoryPoolType :: FairSpill , pool_size_per_task) ,
362- "greedy" => MemoryPoolConfig :: new ( MemoryPoolType :: Greedy , pool_size_per_task) ,
363- "unbounded" => MemoryPoolConfig :: new ( MemoryPoolType :: Unbounded , 0 ) ,
364- _ => {
365- return Err ( CometError :: Config ( format ! (
366- "Unsupported memory pool type for on-heap mode: {}" ,
367- memory_pool_type
368- ) ) )
369- }
370- }
371- } ;
372- Ok ( memory_pool_config)
373- }
374-
375- fn create_memory_pool (
376- memory_pool_config : & MemoryPoolConfig ,
377- comet_task_memory_manager : Arc < GlobalRef > ,
378- task_attempt_id : i64 ,
379- ) -> Arc < dyn MemoryPool > {
380- const NUM_TRACKED_CONSUMERS : usize = 10 ;
381- match memory_pool_config. pool_type {
382- MemoryPoolType :: Unified => {
383- // Set Comet memory pool for native
384- let memory_pool = CometMemoryPool :: new ( comet_task_memory_manager) ;
385- Arc :: new ( TrackConsumersPool :: new (
386- memory_pool,
387- NonZeroUsize :: new ( NUM_TRACKED_CONSUMERS ) . unwrap ( ) ,
388- ) )
389- }
390- MemoryPoolType :: FairUnified => {
391- // Set Comet fair memory pool for native
392- let memory_pool =
393- CometFairMemoryPool :: new ( comet_task_memory_manager, memory_pool_config. pool_size ) ;
394- Arc :: new ( TrackConsumersPool :: new (
395- memory_pool,
396- NonZeroUsize :: new ( NUM_TRACKED_CONSUMERS ) . unwrap ( ) ,
397- ) )
398- }
399- MemoryPoolType :: Greedy => Arc :: new ( TrackConsumersPool :: new (
400- GreedyMemoryPool :: new ( memory_pool_config. pool_size ) ,
401- NonZeroUsize :: new ( NUM_TRACKED_CONSUMERS ) . unwrap ( ) ,
402- ) ) ,
403- MemoryPoolType :: FairSpill => Arc :: new ( TrackConsumersPool :: new (
404- FairSpillPool :: new ( memory_pool_config. pool_size ) ,
405- NonZeroUsize :: new ( NUM_TRACKED_CONSUMERS ) . unwrap ( ) ,
406- ) ) ,
407- MemoryPoolType :: GreedyGlobal => {
408- static GLOBAL_MEMORY_POOL_GREEDY : OnceCell < Arc < dyn MemoryPool > > = OnceCell :: new ( ) ;
409- let memory_pool = GLOBAL_MEMORY_POOL_GREEDY . get_or_init ( || {
410- Arc :: new ( TrackConsumersPool :: new (
411- GreedyMemoryPool :: new ( memory_pool_config. pool_size ) ,
412- NonZeroUsize :: new ( NUM_TRACKED_CONSUMERS ) . unwrap ( ) ,
413- ) )
414- } ) ;
415- Arc :: clone ( memory_pool)
416- }
417- MemoryPoolType :: FairSpillGlobal => {
418- static GLOBAL_MEMORY_POOL_FAIR : OnceCell < Arc < dyn MemoryPool > > = OnceCell :: new ( ) ;
419- let memory_pool = GLOBAL_MEMORY_POOL_FAIR . get_or_init ( || {
420- Arc :: new ( TrackConsumersPool :: new (
421- FairSpillPool :: new ( memory_pool_config. pool_size ) ,
422- NonZeroUsize :: new ( NUM_TRACKED_CONSUMERS ) . unwrap ( ) ,
423- ) )
424- } ) ;
425- Arc :: clone ( memory_pool)
426- }
427- MemoryPoolType :: GreedyTaskShared | MemoryPoolType :: FairSpillTaskShared => {
428- let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS . lock ( ) . unwrap ( ) ;
429- let per_task_memory_pool =
430- memory_pool_map. entry ( task_attempt_id) . or_insert_with ( || {
431- let pool: Arc < dyn MemoryPool > =
432- if memory_pool_config. pool_type == MemoryPoolType :: GreedyTaskShared {
433- Arc :: new ( TrackConsumersPool :: new (
434- GreedyMemoryPool :: new ( memory_pool_config. pool_size ) ,
435- NonZeroUsize :: new ( NUM_TRACKED_CONSUMERS ) . unwrap ( ) ,
436- ) )
437- } else {
438- Arc :: new ( TrackConsumersPool :: new (
439- FairSpillPool :: new ( memory_pool_config. pool_size ) ,
440- NonZeroUsize :: new ( NUM_TRACKED_CONSUMERS ) . unwrap ( ) ,
441- ) )
442- } ;
443- PerTaskMemoryPool :: new ( pool)
444- } ) ;
445- per_task_memory_pool. num_plans += 1 ;
446- Arc :: clone ( & per_task_memory_pool. memory_pool )
447- }
448- MemoryPoolType :: Unbounded => Arc :: new ( UnboundedMemoryPool :: default ( ) ) ,
449- }
450- }
451-
452277/// Prepares arrow arrays for output.
453278fn prepare_output (
454279 env : & mut JNIEnv ,
@@ -643,22 +468,11 @@ pub extern "system" fn Java_org_apache_comet_Native_releasePlan(
643468 // Update metrics
644469 update_metrics ( & mut env, execution_context) ?;
645470
646- if execution_context. memory_pool_config . pool_type == MemoryPoolType :: FairSpillTaskShared
647- || execution_context. memory_pool_config . pool_type == MemoryPoolType :: GreedyTaskShared
648- {
649- // Decrement the number of native plans using the per-task shared memory pool, and
650- // remove the memory pool if the released native plan is the last native plan using it.
651- let task_attempt_id = execution_context. task_attempt_id ;
652- let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS . lock ( ) . unwrap ( ) ;
653- if let Some ( per_task_memory_pool) = memory_pool_map. get_mut ( & task_attempt_id) {
654- per_task_memory_pool. num_plans -= 1 ;
655- if per_task_memory_pool. num_plans == 0 {
656- // Drop the memory pool from the per-task memory pool map if there are no
657- // more native plans using it.
658- memory_pool_map. remove ( & task_attempt_id) ;
659- }
660- }
661- }
471+ handle_task_shared_pool_release (
472+ execution_context. memory_pool_config . pool_type ,
473+ execution_context. task_attempt_id ,
474+ ) ;
475+
662476 let _: Box < ExecutionContext > = Box :: from_raw ( execution_context) ;
663477 Ok ( ( ) )
664478 } )
0 commit comments