Skip to content

Commit cfc4cbb

Browse files
authored
chore: Refactor Memory Pools (#1662)
* chore: Refactor Memory Pools * merge mod lines * add licenses to new rust files * fix typo
1 parent 88525e3 commit cfc4cbb

File tree

7 files changed

+289
-202
lines changed

7 files changed

+289
-202
lines changed

native/core/src/execution/jni_api.rs

Lines changed: 12 additions & 198 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,10 @@
1717

1818
//! Define JNI APIs which can be called from Java/Scala.
1919
20-
use super::{serde, utils::SparkArrowConvert, CometMemoryPool};
20+
use super::{serde, utils::SparkArrowConvert};
2121
use arrow::array::RecordBatch;
2222
use arrow::datatypes::DataType as ArrowDataType;
23-
use datafusion::execution::memory_pool::{
24-
FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool, UnboundedMemoryPool,
25-
};
23+
use datafusion::execution::memory_pool::MemoryPool;
2624
use datafusion::{
2725
execution::{disk_manager::DiskManagerConfig, runtime_env::RuntimeEnv},
2826
physical_plan::{display::DisplayableExecutionPlan, SendableRecordBatchStream},
@@ -40,7 +38,7 @@ use jni::{
4038
};
4139
use std::path::PathBuf;
4240
use std::time::{Duration, Instant};
43-
use std::{collections::HashMap, sync::Arc, task::Poll};
41+
use std::{sync::Arc, task::Poll};
4442

4543
use crate::{
4644
errors::{try_unwrap_or_throw, CometError, CometResult},
@@ -60,16 +58,16 @@ use jni::{
6058
objects::GlobalRef,
6159
sys::{jboolean, jdouble, jintArray, jobjectArray, jstring},
6260
};
63-
use std::num::NonZeroUsize;
64-
use std::sync::Mutex;
6561
use tokio::runtime::Runtime;
6662

67-
use crate::execution::fair_memory_pool::CometFairMemoryPool;
63+
use crate::execution::memory_pools::{
64+
create_memory_pool, handle_task_shared_pool_release, parse_memory_pool_config, MemoryPoolConfig,
65+
};
6866
use crate::execution::operators::ScanExec;
6967
use crate::execution::shuffle::{read_ipc_compressed, CompressionCodec};
7068
use crate::execution::spark_plan::SparkPlan;
7169
use log::info;
72-
use once_cell::sync::{Lazy, OnceCell};
70+
use once_cell::sync::Lazy;
7371

7472
static TOKIO_RUNTIME: Lazy<Runtime> = Lazy::new(|| {
7573
let mut builder = tokio::runtime::Builder::new_multi_thread();
@@ -130,51 +128,6 @@ struct ExecutionContext {
130128
pub memory_pool_config: MemoryPoolConfig,
131129
}
132130

133-
#[derive(PartialEq, Eq)]
134-
enum MemoryPoolType {
135-
Unified,
136-
FairUnified,
137-
Greedy,
138-
FairSpill,
139-
GreedyTaskShared,
140-
FairSpillTaskShared,
141-
GreedyGlobal,
142-
FairSpillGlobal,
143-
Unbounded,
144-
}
145-
146-
struct MemoryPoolConfig {
147-
pool_type: MemoryPoolType,
148-
pool_size: usize,
149-
}
150-
151-
impl MemoryPoolConfig {
152-
fn new(pool_type: MemoryPoolType, pool_size: usize) -> Self {
153-
Self {
154-
pool_type,
155-
pool_size,
156-
}
157-
}
158-
}
159-
160-
/// The per-task memory pools keyed by task attempt id.
161-
static TASK_SHARED_MEMORY_POOLS: Lazy<Mutex<HashMap<i64, PerTaskMemoryPool>>> =
162-
Lazy::new(|| Mutex::new(HashMap::new()));
163-
164-
struct PerTaskMemoryPool {
165-
memory_pool: Arc<dyn MemoryPool>,
166-
num_plans: usize,
167-
}
168-
169-
impl PerTaskMemoryPool {
170-
fn new(memory_pool: Arc<dyn MemoryPool>) -> Self {
171-
Self {
172-
memory_pool,
173-
num_plans: 0,
174-
}
175-
}
176-
}
177-
178131
/// Accept serialized query plan and return the address of the native query plan.
179132
/// # Safety
180133
/// This function is inherently unsafe since it deals with raw pointers passed from JNI.
@@ -321,134 +274,6 @@ fn prepare_datafusion_session_context(
321274
Ok(session_ctx)
322275
}
323276

324-
fn parse_memory_pool_config(
325-
off_heap_mode: bool,
326-
memory_pool_type: String,
327-
memory_limit: i64,
328-
memory_limit_per_task: i64,
329-
) -> CometResult<MemoryPoolConfig> {
330-
let pool_size = memory_limit as usize;
331-
let memory_pool_config = if off_heap_mode {
332-
match memory_pool_type.as_str() {
333-
"fair_unified" => MemoryPoolConfig::new(MemoryPoolType::FairUnified, pool_size),
334-
"default" | "unified" => {
335-
// the `unified` memory pool interacts with Spark's memory pool to allocate
336-
// memory therefore does not need a size to be explicitly set. The pool size
337-
// shared with Spark is set by `spark.memory.offHeap.size`.
338-
MemoryPoolConfig::new(MemoryPoolType::Unified, 0)
339-
}
340-
_ => {
341-
return Err(CometError::Config(format!(
342-
"Unsupported memory pool type for off-heap mode: {}",
343-
memory_pool_type
344-
)))
345-
}
346-
}
347-
} else {
348-
// Use the memory pool from DF
349-
let pool_size_per_task = memory_limit_per_task as usize;
350-
match memory_pool_type.as_str() {
351-
"fair_spill_task_shared" => {
352-
MemoryPoolConfig::new(MemoryPoolType::FairSpillTaskShared, pool_size_per_task)
353-
}
354-
"default" | "greedy_task_shared" => {
355-
MemoryPoolConfig::new(MemoryPoolType::GreedyTaskShared, pool_size_per_task)
356-
}
357-
"fair_spill_global" => {
358-
MemoryPoolConfig::new(MemoryPoolType::FairSpillGlobal, pool_size)
359-
}
360-
"greedy_global" => MemoryPoolConfig::new(MemoryPoolType::GreedyGlobal, pool_size),
361-
"fair_spill" => MemoryPoolConfig::new(MemoryPoolType::FairSpill, pool_size_per_task),
362-
"greedy" => MemoryPoolConfig::new(MemoryPoolType::Greedy, pool_size_per_task),
363-
"unbounded" => MemoryPoolConfig::new(MemoryPoolType::Unbounded, 0),
364-
_ => {
365-
return Err(CometError::Config(format!(
366-
"Unsupported memory pool type for on-heap mode: {}",
367-
memory_pool_type
368-
)))
369-
}
370-
}
371-
};
372-
Ok(memory_pool_config)
373-
}
374-
375-
fn create_memory_pool(
376-
memory_pool_config: &MemoryPoolConfig,
377-
comet_task_memory_manager: Arc<GlobalRef>,
378-
task_attempt_id: i64,
379-
) -> Arc<dyn MemoryPool> {
380-
const NUM_TRACKED_CONSUMERS: usize = 10;
381-
match memory_pool_config.pool_type {
382-
MemoryPoolType::Unified => {
383-
// Set Comet memory pool for native
384-
let memory_pool = CometMemoryPool::new(comet_task_memory_manager);
385-
Arc::new(TrackConsumersPool::new(
386-
memory_pool,
387-
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
388-
))
389-
}
390-
MemoryPoolType::FairUnified => {
391-
// Set Comet fair memory pool for native
392-
let memory_pool =
393-
CometFairMemoryPool::new(comet_task_memory_manager, memory_pool_config.pool_size);
394-
Arc::new(TrackConsumersPool::new(
395-
memory_pool,
396-
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
397-
))
398-
}
399-
MemoryPoolType::Greedy => Arc::new(TrackConsumersPool::new(
400-
GreedyMemoryPool::new(memory_pool_config.pool_size),
401-
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
402-
)),
403-
MemoryPoolType::FairSpill => Arc::new(TrackConsumersPool::new(
404-
FairSpillPool::new(memory_pool_config.pool_size),
405-
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
406-
)),
407-
MemoryPoolType::GreedyGlobal => {
408-
static GLOBAL_MEMORY_POOL_GREEDY: OnceCell<Arc<dyn MemoryPool>> = OnceCell::new();
409-
let memory_pool = GLOBAL_MEMORY_POOL_GREEDY.get_or_init(|| {
410-
Arc::new(TrackConsumersPool::new(
411-
GreedyMemoryPool::new(memory_pool_config.pool_size),
412-
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
413-
))
414-
});
415-
Arc::clone(memory_pool)
416-
}
417-
MemoryPoolType::FairSpillGlobal => {
418-
static GLOBAL_MEMORY_POOL_FAIR: OnceCell<Arc<dyn MemoryPool>> = OnceCell::new();
419-
let memory_pool = GLOBAL_MEMORY_POOL_FAIR.get_or_init(|| {
420-
Arc::new(TrackConsumersPool::new(
421-
FairSpillPool::new(memory_pool_config.pool_size),
422-
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
423-
))
424-
});
425-
Arc::clone(memory_pool)
426-
}
427-
MemoryPoolType::GreedyTaskShared | MemoryPoolType::FairSpillTaskShared => {
428-
let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS.lock().unwrap();
429-
let per_task_memory_pool =
430-
memory_pool_map.entry(task_attempt_id).or_insert_with(|| {
431-
let pool: Arc<dyn MemoryPool> =
432-
if memory_pool_config.pool_type == MemoryPoolType::GreedyTaskShared {
433-
Arc::new(TrackConsumersPool::new(
434-
GreedyMemoryPool::new(memory_pool_config.pool_size),
435-
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
436-
))
437-
} else {
438-
Arc::new(TrackConsumersPool::new(
439-
FairSpillPool::new(memory_pool_config.pool_size),
440-
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
441-
))
442-
};
443-
PerTaskMemoryPool::new(pool)
444-
});
445-
per_task_memory_pool.num_plans += 1;
446-
Arc::clone(&per_task_memory_pool.memory_pool)
447-
}
448-
MemoryPoolType::Unbounded => Arc::new(UnboundedMemoryPool::default()),
449-
}
450-
}
451-
452277
/// Prepares arrow arrays for output.
453278
fn prepare_output(
454279
env: &mut JNIEnv,
@@ -643,22 +468,11 @@ pub extern "system" fn Java_org_apache_comet_Native_releasePlan(
643468
// Update metrics
644469
update_metrics(&mut env, execution_context)?;
645470

646-
if execution_context.memory_pool_config.pool_type == MemoryPoolType::FairSpillTaskShared
647-
|| execution_context.memory_pool_config.pool_type == MemoryPoolType::GreedyTaskShared
648-
{
649-
// Decrement the number of native plans using the per-task shared memory pool, and
650-
// remove the memory pool if the released native plan is the last native plan using it.
651-
let task_attempt_id = execution_context.task_attempt_id;
652-
let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS.lock().unwrap();
653-
if let Some(per_task_memory_pool) = memory_pool_map.get_mut(&task_attempt_id) {
654-
per_task_memory_pool.num_plans -= 1;
655-
if per_task_memory_pool.num_plans == 0 {
656-
// Drop the memory pool from the per-task memory pool map if there are no
657-
// more native plans using it.
658-
memory_pool_map.remove(&task_attempt_id);
659-
}
660-
}
661-
}
471+
handle_task_shared_pool_release(
472+
execution_context.memory_pool_config.pool_type,
473+
execution_context.task_attempt_id,
474+
);
475+
662476
let _: Box<ExecutionContext> = Box::from_raw(execution_context);
663477
Ok(())
664478
})
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use crate::errors::{CometError, CometResult};
19+
20+
#[derive(Copy, Clone, PartialEq, Eq)]
21+
pub(crate) enum MemoryPoolType {
22+
Unified,
23+
FairUnified,
24+
Greedy,
25+
FairSpill,
26+
GreedyTaskShared,
27+
FairSpillTaskShared,
28+
GreedyGlobal,
29+
FairSpillGlobal,
30+
Unbounded,
31+
}
32+
33+
impl MemoryPoolType {
34+
pub(crate) fn is_task_shared(&self) -> bool {
35+
matches!(
36+
self,
37+
MemoryPoolType::GreedyTaskShared | MemoryPoolType::FairSpillTaskShared
38+
)
39+
}
40+
}
41+
42+
pub(crate) struct MemoryPoolConfig {
43+
pub(crate) pool_type: MemoryPoolType,
44+
pub(crate) pool_size: usize,
45+
}
46+
47+
impl MemoryPoolConfig {
48+
pub(crate) fn new(pool_type: MemoryPoolType, pool_size: usize) -> Self {
49+
Self {
50+
pool_type,
51+
pool_size,
52+
}
53+
}
54+
}
55+
56+
pub(crate) fn parse_memory_pool_config(
57+
off_heap_mode: bool,
58+
memory_pool_type: String,
59+
memory_limit: i64,
60+
memory_limit_per_task: i64,
61+
) -> CometResult<MemoryPoolConfig> {
62+
let pool_size = memory_limit as usize;
63+
let memory_pool_config = if off_heap_mode {
64+
match memory_pool_type.as_str() {
65+
"fair_unified" => MemoryPoolConfig::new(MemoryPoolType::FairUnified, pool_size),
66+
"default" | "unified" => {
67+
// the `unified` memory pool interacts with Spark's memory pool to allocate
68+
// memory therefore does not need a size to be explicitly set. The pool size
69+
// shared with Spark is set by `spark.memory.offHeap.size`.
70+
MemoryPoolConfig::new(MemoryPoolType::Unified, 0)
71+
}
72+
_ => {
73+
return Err(CometError::Config(format!(
74+
"Unsupported memory pool type for off-heap mode: {}",
75+
memory_pool_type
76+
)))
77+
}
78+
}
79+
} else {
80+
// Use the memory pool from DF
81+
let pool_size_per_task = memory_limit_per_task as usize;
82+
match memory_pool_type.as_str() {
83+
"fair_spill_task_shared" => {
84+
MemoryPoolConfig::new(MemoryPoolType::FairSpillTaskShared, pool_size_per_task)
85+
}
86+
"default" | "greedy_task_shared" => {
87+
MemoryPoolConfig::new(MemoryPoolType::GreedyTaskShared, pool_size_per_task)
88+
}
89+
"fair_spill_global" => {
90+
MemoryPoolConfig::new(MemoryPoolType::FairSpillGlobal, pool_size)
91+
}
92+
"greedy_global" => MemoryPoolConfig::new(MemoryPoolType::GreedyGlobal, pool_size),
93+
"fair_spill" => MemoryPoolConfig::new(MemoryPoolType::FairSpill, pool_size_per_task),
94+
"greedy" => MemoryPoolConfig::new(MemoryPoolType::Greedy, pool_size_per_task),
95+
"unbounded" => MemoryPoolConfig::new(MemoryPoolType::Unbounded, 0),
96+
_ => {
97+
return Err(CometError::Config(format!(
98+
"Unsupported memory pool type for on-heap mode: {}",
99+
memory_pool_type
100+
)))
101+
}
102+
}
103+
};
104+
Ok(memory_pool_config)
105+
}
File renamed without changes.

0 commit comments

Comments
 (0)