Skip to content

Commit b6f7521

Browse files
authored
Do not require mut in memory reservation methods (#19759)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Closes #. ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> Prerequisite for the following PRs: - #19760 - #19761 Even if the api on the `MemoryPool` does not require `&mut self` for growing/shrinking the reserved size, the api in `MemoryReservation` does, making simple implementations irrepresentable without synchronization primitives. For example, the following would require a `Mutex` for concurrent access to the `MemoryReservation` in different threads, even though the `MemoryPool` doesn't: ```rust let mut stream: SendableRecordBatchStream = SendableRecordBatchStream::new(); let mem: Arc<MemoryReservation> = Arc::new(MemoryReservation::new_empty()); let mut builder = ReceiverStreamBuilder::new(10); let tx = builder.tx(); { let mem = mem.clone(); builder.spawn(async move { while let Some(msg) = stream.next().await { mem.try_grow(msg.unwrap().get_array_memory_size()); // ❌ `mem` is not mutable tx.send(msg).unwrap(); } }); } builder .build() .inspect_ok(|msg| mem.shrink(msg.get_array_memory_size())); // ❌ `mem` is not mutable ``` ## What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> Make the methods in `MemoryReservation` require `&self` instead of `&mut self` for allowing concurrent shrink/grows from different tasks for the same reservation. ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> yes, by current tests ## Are there any user-facing changes? Users can now safely call methods of `MemoryReservation` from different tasks without synchronization primitives. This is a backwards compatible API change, as it will work out of the box for current users, however, depending on their clippy configuration, they might see some new warnings about "unused muts" in their codebase. <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
1 parent 4d63f8c commit b6f7521

File tree

12 files changed

+87
-79
lines changed

12 files changed

+87
-79
lines changed

datafusion-cli/src/exec.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ impl StatementExecutor {
269269
let options = task_ctx.session_config().options();
270270

271271
// Track memory usage for the query result if it's bounded
272-
let mut reservation =
272+
let reservation =
273273
MemoryConsumer::new("DataFusion-Cli").register(task_ctx.memory_pool());
274274

275275
if physical_plan.boundedness().is_unbounded() {

datafusion/core/src/execution/context/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2186,7 +2186,7 @@ mod tests {
21862186
// configure with same memory / disk manager
21872187
let memory_pool = ctx1.runtime_env().memory_pool.clone();
21882188

2189-
let mut reservation = MemoryConsumer::new("test").register(&memory_pool);
2189+
let reservation = MemoryConsumer::new("test").register(&memory_pool);
21902190
reservation.grow(100);
21912191

21922192
let disk_manager = ctx1.runtime_env().disk_manager.clone();

datafusion/datasource-parquet/src/file_format.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1360,7 +1360,7 @@ impl FileSink for ParquetSink {
13601360
parquet_props.clone(),
13611361
)
13621362
.await?;
1363-
let mut reservation = MemoryConsumer::new(format!("ParquetSink[{path}]"))
1363+
let reservation = MemoryConsumer::new(format!("ParquetSink[{path}]"))
13641364
.register(context.memory_pool());
13651365
file_write_tasks.spawn(async move {
13661366
while let Some(batch) = rx.recv().await {
@@ -1465,7 +1465,7 @@ impl DataSink for ParquetSink {
14651465
async fn column_serializer_task(
14661466
mut rx: Receiver<ArrowLeafColumn>,
14671467
mut writer: ArrowColumnWriter,
1468-
mut reservation: MemoryReservation,
1468+
reservation: MemoryReservation,
14691469
) -> Result<(ArrowColumnWriter, MemoryReservation)> {
14701470
while let Some(col) = rx.recv().await {
14711471
writer.write(&col)?;
@@ -1550,7 +1550,7 @@ fn spawn_rg_join_and_finalize_task(
15501550
rg_rows: usize,
15511551
pool: &Arc<dyn MemoryPool>,
15521552
) -> SpawnedTask<RBStreamSerializeResult> {
1553-
let mut rg_reservation =
1553+
let rg_reservation =
15541554
MemoryConsumer::new("ParquetSink(SerializedRowGroupWriter)").register(pool);
15551555

15561556
SpawnedTask::spawn(async move {
@@ -1682,12 +1682,12 @@ async fn concatenate_parallel_row_groups(
16821682
mut object_store_writer: Box<dyn AsyncWrite + Send + Unpin>,
16831683
pool: Arc<dyn MemoryPool>,
16841684
) -> Result<ParquetMetaData> {
1685-
let mut file_reservation =
1685+
let file_reservation =
16861686
MemoryConsumer::new("ParquetSink(SerializedFileWriter)").register(&pool);
16871687

16881688
while let Some(task) = serialize_rx.recv().await {
16891689
let result = task.join_unwind().await;
1690-
let (serialized_columns, mut rg_reservation, _cnt) =
1690+
let (serialized_columns, rg_reservation, _cnt) =
16911691
result.map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??;
16921692

16931693
let mut rg_out = parquet_writer.next_row_group()?;

datafusion/execution/src/memory_pool/mod.rs

Lines changed: 50 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
//! [`MemoryPool`] for memory management during query execution, [`proxy`] for
1919
//! help with allocation accounting.
2020
21-
use datafusion_common::{Result, internal_err};
21+
use datafusion_common::{Result, internal_datafusion_err};
2222
use std::hash::{Hash, Hasher};
2323
use std::{cmp::Ordering, sync::Arc, sync::atomic};
2424

@@ -322,7 +322,7 @@ impl MemoryConsumer {
322322
pool: Arc::clone(pool),
323323
consumer: self,
324324
}),
325-
size: 0,
325+
size: atomic::AtomicUsize::new(0),
326326
}
327327
}
328328
}
@@ -351,13 +351,13 @@ impl Drop for SharedRegistration {
351351
#[derive(Debug)]
352352
pub struct MemoryReservation {
353353
registration: Arc<SharedRegistration>,
354-
size: usize,
354+
size: atomic::AtomicUsize,
355355
}
356356

357357
impl MemoryReservation {
358358
/// Returns the size of this reservation in bytes
359359
pub fn size(&self) -> usize {
360-
self.size
360+
self.size.load(atomic::Ordering::Relaxed)
361361
}
362362

363363
/// Returns [MemoryConsumer] for this [MemoryReservation]
@@ -367,8 +367,8 @@ impl MemoryReservation {
367367

368368
/// Frees all bytes from this reservation back to the underlying
369369
/// pool, returning the number of bytes freed.
370-
pub fn free(&mut self) -> usize {
371-
let size = self.size;
370+
pub fn free(&self) -> usize {
371+
let size = self.size.load(atomic::Ordering::Relaxed);
372372
if size != 0 {
373373
self.shrink(size)
374374
}
@@ -380,60 +380,62 @@ impl MemoryReservation {
380380
/// # Panics
381381
///
382382
/// Panics if `capacity` exceeds [`Self::size`]
383-
pub fn shrink(&mut self, capacity: usize) {
384-
let new_size = self.size.checked_sub(capacity).unwrap();
383+
pub fn shrink(&self, capacity: usize) {
384+
self.size.fetch_sub(capacity, atomic::Ordering::Relaxed);
385385
self.registration.pool.shrink(self, capacity);
386-
self.size = new_size
387386
}
388387

389388
/// Tries to free `capacity` bytes from this reservation
390389
/// if `capacity` does not exceed [`Self::size`]
391390
/// Returns new reservation size
392391
/// or error if shrinking capacity is more than allocated size
393-
pub fn try_shrink(&mut self, capacity: usize) -> Result<usize> {
394-
if let Some(new_size) = self.size.checked_sub(capacity) {
395-
self.registration.pool.shrink(self, capacity);
396-
self.size = new_size;
397-
Ok(new_size)
398-
} else {
399-
internal_err!(
400-
"Cannot free the capacity {capacity} out of allocated size {}",
401-
self.size
392+
pub fn try_shrink(&self, capacity: usize) -> Result<usize> {
393+
let updated = self.size.fetch_update(
394+
atomic::Ordering::Relaxed,
395+
atomic::Ordering::Relaxed,
396+
|prev| prev.checked_sub(capacity),
397+
);
398+
updated.map_err(|_| {
399+
let prev = self.size.load(atomic::Ordering::Relaxed);
400+
internal_datafusion_err!(
401+
"Cannot free the capacity {capacity} out of allocated size {prev}"
402402
)
403-
}
403+
})
404404
}
405405

406406
/// Sets the size of this reservation to `capacity`
407-
pub fn resize(&mut self, capacity: usize) {
408-
match capacity.cmp(&self.size) {
409-
Ordering::Greater => self.grow(capacity - self.size),
410-
Ordering::Less => self.shrink(self.size - capacity),
407+
pub fn resize(&self, capacity: usize) {
408+
let size = self.size.load(atomic::Ordering::Relaxed);
409+
match capacity.cmp(&size) {
410+
Ordering::Greater => self.grow(capacity - size),
411+
Ordering::Less => self.shrink(size - capacity),
411412
_ => {}
412413
}
413414
}
414415

415416
/// Try to set the size of this reservation to `capacity`
416-
pub fn try_resize(&mut self, capacity: usize) -> Result<()> {
417-
match capacity.cmp(&self.size) {
418-
Ordering::Greater => self.try_grow(capacity - self.size)?,
419-
Ordering::Less => self.shrink(self.size - capacity),
417+
pub fn try_resize(&self, capacity: usize) -> Result<()> {
418+
let size = self.size.load(atomic::Ordering::Relaxed);
419+
match capacity.cmp(&size) {
420+
Ordering::Greater => self.try_grow(capacity - size)?,
421+
Ordering::Less => self.shrink(size - capacity),
420422
_ => {}
421423
};
422424
Ok(())
423425
}
424426

425427
/// Increase the size of this reservation by `capacity` bytes
426-
pub fn grow(&mut self, capacity: usize) {
428+
pub fn grow(&self, capacity: usize) {
427429
self.registration.pool.grow(self, capacity);
428-
self.size += capacity;
430+
self.size.fetch_add(capacity, atomic::Ordering::Relaxed);
429431
}
430432

431433
/// Try to increase the size of this reservation by `capacity`
432434
/// bytes, returning error if there is insufficient capacity left
433435
/// in the pool.
434-
pub fn try_grow(&mut self, capacity: usize) -> Result<()> {
436+
pub fn try_grow(&self, capacity: usize) -> Result<()> {
435437
self.registration.pool.try_grow(self, capacity)?;
436-
self.size += capacity;
438+
self.size.fetch_add(capacity, atomic::Ordering::Relaxed);
437439
Ok(())
438440
}
439441

@@ -447,26 +449,32 @@ impl MemoryReservation {
447449
/// # Panics
448450
///
449451
/// Panics if `capacity` exceeds [`Self::size`]
450-
pub fn split(&mut self, capacity: usize) -> MemoryReservation {
451-
self.size = self.size.checked_sub(capacity).unwrap();
452+
pub fn split(&self, capacity: usize) -> MemoryReservation {
453+
self.size
454+
.fetch_update(
455+
atomic::Ordering::Relaxed,
456+
atomic::Ordering::Relaxed,
457+
|prev| prev.checked_sub(capacity),
458+
)
459+
.unwrap();
452460
Self {
453-
size: capacity,
461+
size: atomic::AtomicUsize::new(capacity),
454462
registration: Arc::clone(&self.registration),
455463
}
456464
}
457465

458466
/// Returns a new empty [`MemoryReservation`] with the same [`MemoryConsumer`]
459467
pub fn new_empty(&self) -> Self {
460468
Self {
461-
size: 0,
469+
size: atomic::AtomicUsize::new(0),
462470
registration: Arc::clone(&self.registration),
463471
}
464472
}
465473

466474
/// Splits off all the bytes from this [`MemoryReservation`] into
467475
/// a new [`MemoryReservation`] with the same [`MemoryConsumer`]
468476
pub fn take(&mut self) -> MemoryReservation {
469-
self.split(self.size)
477+
self.split(self.size.load(atomic::Ordering::Relaxed))
470478
}
471479
}
472480

@@ -492,7 +500,7 @@ mod tests {
492500
#[test]
493501
fn test_memory_pool_underflow() {
494502
let pool = Arc::new(GreedyMemoryPool::new(50)) as _;
495-
let mut a1 = MemoryConsumer::new("a1").register(&pool);
503+
let a1 = MemoryConsumer::new("a1").register(&pool);
496504
assert_eq!(pool.reserved(), 0);
497505

498506
a1.grow(100);
@@ -507,7 +515,7 @@ mod tests {
507515
a1.try_grow(30).unwrap();
508516
assert_eq!(pool.reserved(), 30);
509517

510-
let mut a2 = MemoryConsumer::new("a2").register(&pool);
518+
let a2 = MemoryConsumer::new("a2").register(&pool);
511519
a2.try_grow(25).unwrap_err();
512520
assert_eq!(pool.reserved(), 30);
513521

@@ -521,7 +529,7 @@ mod tests {
521529
#[test]
522530
fn test_split() {
523531
let pool = Arc::new(GreedyMemoryPool::new(50)) as _;
524-
let mut r1 = MemoryConsumer::new("r1").register(&pool);
532+
let r1 = MemoryConsumer::new("r1").register(&pool);
525533

526534
r1.try_grow(20).unwrap();
527535
assert_eq!(r1.size(), 20);
@@ -542,10 +550,10 @@ mod tests {
542550
#[test]
543551
fn test_new_empty() {
544552
let pool = Arc::new(GreedyMemoryPool::new(50)) as _;
545-
let mut r1 = MemoryConsumer::new("r1").register(&pool);
553+
let r1 = MemoryConsumer::new("r1").register(&pool);
546554

547555
r1.try_grow(20).unwrap();
548-
let mut r2 = r1.new_empty();
556+
let r2 = r1.new_empty();
549557
r2.try_grow(5).unwrap();
550558

551559
assert_eq!(r1.size(), 20);
@@ -559,7 +567,7 @@ mod tests {
559567
let mut r1 = MemoryConsumer::new("r1").register(&pool);
560568

561569
r1.try_grow(20).unwrap();
562-
let mut r2 = r1.take();
570+
let r2 = r1.take();
563571
r2.try_grow(5).unwrap();
564572

565573
assert_eq!(r1.size(), 0);

0 commit comments

Comments
 (0)