Skip to content

Commit f3fbd93

Browse files
authored
chore: Use checked operations when growing or shrinking unified memory pool (#2455)
1 parent b4f5b5e commit f3fbd93

File tree

2 files changed

+44
-28
lines changed

2 files changed

+44
-28
lines changed

native/core/src/execution/memory_pools/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ use jni::objects::GlobalRef;
2828
use once_cell::sync::OnceCell;
2929
use std::num::NonZeroUsize;
3030
use std::sync::Arc;
31-
use unified_pool::CometMemoryPool;
31+
use unified_pool::CometUnifiedMemoryPool;
3232

3333
pub(crate) use config::*;
3434
pub(crate) use task_shared::*;
@@ -42,7 +42,7 @@ pub(crate) fn create_memory_pool(
4242
match memory_pool_config.pool_type {
4343
MemoryPoolType::Unified => {
4444
// Set Comet memory pool for native
45-
let memory_pool = CometMemoryPool::new(comet_task_memory_manager);
45+
let memory_pool = CometUnifiedMemoryPool::new(comet_task_memory_manager);
4646
Arc::new(TrackConsumersPool::new(
4747
memory_pool,
4848
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),

native/core/src/execution/memory_pools/unified_pool.rs

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,42 +23,43 @@ use std::{
2323
},
2424
};
2525

26-
use jni::objects::GlobalRef;
27-
28-
use datafusion::{
29-
common::{resources_datafusion_err, DataFusionError},
30-
execution::memory_pool::{MemoryPool, MemoryReservation},
31-
};
32-
3326
use crate::{
3427
errors::CometResult,
3528
jvm_bridge::{jni_call, JVMClasses},
3629
};
30+
use datafusion::{
31+
common::{resources_datafusion_err, DataFusionError},
32+
execution::memory_pool::{MemoryPool, MemoryReservation},
33+
};
34+
use jni::objects::GlobalRef;
35+
use log::warn;
3736

38-
/// A DataFusion `MemoryPool` implementation for Comet. Internally this is
39-
/// implemented via delegating calls to [`crate::jvm_bridge::CometTaskMemoryManager`].
40-
pub struct CometMemoryPool {
37+
/// A DataFusion `MemoryPool` implementation for Comet that delegates to
38+
/// Spark's off-heap executor memory pool via JNI by calling
39+
/// [`crate::jvm_bridge::CometTaskMemoryManager`].
40+
pub struct CometUnifiedMemoryPool {
4141
task_memory_manager_handle: Arc<GlobalRef>,
4242
used: AtomicUsize,
4343
}
4444

45-
impl Debug for CometMemoryPool {
45+
impl Debug for CometUnifiedMemoryPool {
4646
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
47-
f.debug_struct("CometMemoryPool")
47+
f.debug_struct("CometUnifiedMemoryPool")
4848
.field("used", &self.used.load(Relaxed))
4949
.finish()
5050
}
5151
}
5252

53-
impl CometMemoryPool {
54-
pub fn new(task_memory_manager_handle: Arc<GlobalRef>) -> CometMemoryPool {
53+
impl CometUnifiedMemoryPool {
54+
pub fn new(task_memory_manager_handle: Arc<GlobalRef>) -> CometUnifiedMemoryPool {
5555
Self {
5656
task_memory_manager_handle,
5757
used: AtomicUsize::new(0),
5858
}
5959
}
6060

61-
fn acquire(&self, additional: usize) -> CometResult<i64> {
61+
/// Request memory from Spark's off-heap memory pool via JNI
62+
fn acquire_from_spark(&self, additional: usize) -> CometResult<i64> {
6263
let mut env = JVMClasses::get_env()?;
6364
let handle = self.task_memory_manager_handle.as_obj();
6465
unsafe {
@@ -67,7 +68,8 @@ impl CometMemoryPool {
6768
}
6869
}
6970

70-
fn release(&self, size: usize) -> CometResult<()> {
71+
/// Release memory to Spark's off-heap memory pool via JNI
72+
fn release_to_spark(&self, size: usize) -> CometResult<()> {
7173
let mut env = JVMClasses::get_env()?;
7274
let handle = self.task_memory_manager_handle.as_obj();
7375
unsafe {
@@ -76,37 +78,42 @@ impl CometMemoryPool {
7678
}
7779
}
7880

79-
impl Drop for CometMemoryPool {
81+
impl Drop for CometUnifiedMemoryPool {
8082
fn drop(&mut self) {
8183
let used = self.used.load(Relaxed);
8284
if used != 0 {
83-
log::warn!("CometMemoryPool dropped with {used} bytes still reserved");
85+
warn!("CometUnifiedMemoryPool dropped with {used} bytes still reserved");
8486
}
8587
}
8688
}
8789

88-
unsafe impl Send for CometMemoryPool {}
89-
unsafe impl Sync for CometMemoryPool {}
90+
unsafe impl Send for CometUnifiedMemoryPool {}
91+
unsafe impl Sync for CometUnifiedMemoryPool {}
9092

91-
impl MemoryPool for CometMemoryPool {
93+
impl MemoryPool for CometUnifiedMemoryPool {
9294
fn grow(&self, reservation: &MemoryReservation, additional: usize) {
9395
self.try_grow(reservation, additional).unwrap();
9496
}
9597

9698
fn shrink(&self, _: &MemoryReservation, size: usize) {
97-
self.release(size)
99+
self.release_to_spark(size)
98100
.unwrap_or_else(|_| panic!("Failed to release {size} bytes"));
99-
self.used.fetch_sub(size, Relaxed);
101+
if let Err(prev) = self
102+
.used
103+
.fetch_update(Relaxed, Relaxed, |old| old.checked_sub(size))
104+
{
105+
panic!("overflow when releasing {size} of {prev} bytes");
106+
}
100107
}
101108

102109
fn try_grow(&self, _: &MemoryReservation, additional: usize) -> Result<(), DataFusionError> {
103110
if additional > 0 {
104-
let acquired = self.acquire(additional)?;
111+
let acquired = self.acquire_from_spark(additional)?;
105112
// If fewer bytes were acquired than requested, return an error,
106113
// which will hopefully trigger spilling on the caller side.
107114
if acquired < additional as i64 {
108115
// Release the acquired bytes before returning the error
109-
self.release(acquired as usize)?;
116+
self.release_to_spark(acquired as usize)?;
110117

111118
return Err(resources_datafusion_err!(
112119
"Failed to acquire {} bytes, only got {}. Reserved: {}",
@@ -115,7 +122,16 @@ impl MemoryPool for CometMemoryPool {
115122
self.reserved()
116123
));
117124
}
118-
self.used.fetch_add(acquired as usize, Relaxed);
125+
if let Err(prev) = self
126+
.used
127+
.fetch_update(Relaxed, Relaxed, |old| old.checked_add(acquired as usize))
128+
{
129+
return Err(resources_datafusion_err!(
130+
"Failed to acquire {} bytes due to overflow. Reserved: {}",
131+
additional,
132+
prev
133+
));
134+
}
119135
}
120136
Ok(())
121137
}

0 commit comments

Comments
 (0)