@@ -23,42 +23,43 @@ use std::{
2323 } ,
2424} ;
2525
26- use jni:: objects:: GlobalRef ;
27-
28- use datafusion:: {
29- common:: { resources_datafusion_err, DataFusionError } ,
30- execution:: memory_pool:: { MemoryPool , MemoryReservation } ,
31- } ;
32-
3326use crate :: {
3427 errors:: CometResult ,
3528 jvm_bridge:: { jni_call, JVMClasses } ,
3629} ;
30+ use datafusion:: {
31+ common:: { resources_datafusion_err, DataFusionError } ,
32+ execution:: memory_pool:: { MemoryPool , MemoryReservation } ,
33+ } ;
34+ use jni:: objects:: GlobalRef ;
35+ use log:: warn;
3736
38- /// A DataFusion `MemoryPool` implementation for Comet. Internally this is
39- /// implemented via delegating calls to [`crate::jvm_bridge::CometTaskMemoryManager`].
40- pub struct CometMemoryPool {
37+ /// A DataFusion `MemoryPool` implementation for Comet that delegates to
38+ /// Spark's off-heap executor memory pool via JNI by calling
39+ /// [`crate::jvm_bridge::CometTaskMemoryManager`].
40+ pub struct CometUnifiedMemoryPool {
4141 task_memory_manager_handle : Arc < GlobalRef > ,
4242 used : AtomicUsize ,
4343}
4444
45- impl Debug for CometMemoryPool {
45+ impl Debug for CometUnifiedMemoryPool {
4646 fn fmt ( & self , f : & mut Formatter < ' _ > ) -> FmtResult {
47- f. debug_struct ( "CometMemoryPool " )
47+ f. debug_struct ( "CometUnifiedMemoryPool " )
4848 . field ( "used" , & self . used . load ( Relaxed ) )
4949 . finish ( )
5050 }
5151}
5252
53- impl CometMemoryPool {
54- pub fn new ( task_memory_manager_handle : Arc < GlobalRef > ) -> CometMemoryPool {
53+ impl CometUnifiedMemoryPool {
54+ pub fn new ( task_memory_manager_handle : Arc < GlobalRef > ) -> CometUnifiedMemoryPool {
5555 Self {
5656 task_memory_manager_handle,
5757 used : AtomicUsize :: new ( 0 ) ,
5858 }
5959 }
6060
61- fn acquire ( & self , additional : usize ) -> CometResult < i64 > {
61+ /// Request memory from Spark's off-heap memory pool via JNI
62+ fn acquire_from_spark ( & self , additional : usize ) -> CometResult < i64 > {
6263 let mut env = JVMClasses :: get_env ( ) ?;
6364 let handle = self . task_memory_manager_handle . as_obj ( ) ;
6465 unsafe {
@@ -67,7 +68,8 @@ impl CometMemoryPool {
6768 }
6869 }
6970
70- fn release ( & self , size : usize ) -> CometResult < ( ) > {
71+ /// Release memory to Spark's off-heap memory pool via JNI
72+ fn release_to_spark ( & self , size : usize ) -> CometResult < ( ) > {
7173 let mut env = JVMClasses :: get_env ( ) ?;
7274 let handle = self . task_memory_manager_handle . as_obj ( ) ;
7375 unsafe {
@@ -76,37 +78,42 @@ impl CometMemoryPool {
7678 }
7779}
7880
79- impl Drop for CometMemoryPool {
81+ impl Drop for CometUnifiedMemoryPool {
8082 fn drop ( & mut self ) {
8183 let used = self . used . load ( Relaxed ) ;
8284 if used != 0 {
83- log :: warn!( "CometMemoryPool dropped with {used} bytes still reserved" ) ;
85+ warn ! ( "CometUnifiedMemoryPool dropped with {used} bytes still reserved" ) ;
8486 }
8587 }
8688}
8789
// Manually assert that the pool may be moved to and shared between threads.
// NOTE(review): the fields are `Arc<GlobalRef>` and `AtomicUsize`; if the
// `jni` version in use already makes `GlobalRef` `Send + Sync`, these impls
// are redundant and could be dropped in favour of the auto traits — confirm.
unsafe impl Send for CometUnifiedMemoryPool {}
unsafe impl Sync for CometUnifiedMemoryPool {}
9092
91- impl MemoryPool for CometMemoryPool {
93+ impl MemoryPool for CometUnifiedMemoryPool {
9294 fn grow ( & self , reservation : & MemoryReservation , additional : usize ) {
9395 self . try_grow ( reservation, additional) . unwrap ( ) ;
9496 }
9597
9698 fn shrink ( & self , _: & MemoryReservation , size : usize ) {
97- self . release ( size)
99+ self . release_to_spark ( size)
98100 . unwrap_or_else ( |_| panic ! ( "Failed to release {size} bytes" ) ) ;
99- self . used . fetch_sub ( size, Relaxed ) ;
101+ if let Err ( prev) = self
102+ . used
103+ . fetch_update ( Relaxed , Relaxed , |old| old. checked_sub ( size) )
104+ {
105+ panic ! ( "overflow when releasing {size} of {prev} bytes" ) ;
106+ }
100107 }
101108
102109 fn try_grow ( & self , _: & MemoryReservation , additional : usize ) -> Result < ( ) , DataFusionError > {
103110 if additional > 0 {
104- let acquired = self . acquire ( additional) ?;
111+ let acquired = self . acquire_from_spark ( additional) ?;
105112 // If the number of bytes we acquired is less than the requested, return an error,
106113 // and hopefully will trigger spilling from the caller side.
107114 if acquired < additional as i64 {
108115 // Release the acquired bytes before throwing error
109- self . release ( acquired as usize ) ?;
116+ self . release_to_spark ( acquired as usize ) ?;
110117
111118 return Err ( resources_datafusion_err ! (
112119 "Failed to acquire {} bytes, only got {}. Reserved: {}" ,
@@ -115,7 +122,16 @@ impl MemoryPool for CometMemoryPool {
115122 self . reserved( )
116123 ) ) ;
117124 }
118- self . used . fetch_add ( acquired as usize , Relaxed ) ;
125+ if let Err ( prev) = self
126+ . used
127+ . fetch_update ( Relaxed , Relaxed , |old| old. checked_add ( acquired as usize ) )
128+ {
129+ return Err ( resources_datafusion_err ! (
130+ "Failed to acquire {} bytes due to overflow. Reserved: {}" ,
131+ additional,
132+ prev
133+ ) ) ;
134+ }
119135 }
120136 Ok ( ( ) )
121137 }
0 commit comments