Skip to content

Commit 0bbf9c7

Browse files
committed
Ensure metal tensors are send/sync via thread isolated command buffer map
1 parent 402782c commit 0bbf9c7

File tree

4 files changed

+58
-14
lines changed

4 files changed

+58
-14
lines changed

candle-core/src/metal_backend/device.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ impl MetalDevice {
123123
if flushed {
124124
self.drop_unused_buffers()?
125125
}
126-
Ok(command_buffer)
126+
Ok(command_buffer.clone())
127127
}
128128

129129
pub fn wait_until_completed(&self) -> Result<()> {

candle-core/tests/tensor_tests.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1694,6 +1694,27 @@ test_device!(asort, asort_cpu, asort_gpu, asort_metal);
16941694
test_device!(var, var_cpu, var_gpu, var_metal);
16951695
test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal);
16961696

1697+
fn tensor_send_sync(device: &Device) -> Result<()> {
1698+
let tensor = Tensor::new(vec![1.0f32, 2.0, 3.0], device)?;
1699+
1700+
for _ in 0..10 {
1701+
let tensor = tensor.clone();
1702+
std::thread::spawn(move || {
1703+
let new = tensor.add(&tensor).unwrap();
1704+
let result: Vec<f32> = new.to_vec1().unwrap();
1705+
assert_eq!(result, vec![2.0f32, 4.0, 6.0]);
1706+
});
1707+
}
1708+
1709+
Ok(())
1710+
}
1711+
test_device!(
1712+
tensor_send_sync,
1713+
tensor_send_sync_cpu,
1714+
tensor_send_sync_gpu,
1715+
tensor_send_sync_metal
1716+
);
1717+
16971718
// There was originally a bug on the CPU implementation for randn
16981719
// https://github.com/huggingface/candle/issues/381
16991720
#[test]

candle-metal-kernels/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ pub mod binary {
168168

169169
#[derive(thiserror::Error, Debug)]
170170
pub enum MetalKernelError {
171-
#[error("Could not lock kernel map: {0}")]
171+
#[error("Could not lock: {0}")]
172172
LockError(String),
173173
#[error("Error while loading library: {0}")]
174174
LoadLibraryError(String),

candle-metal-kernels/src/metal_utils.rs

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,13 @@ use objc2_metal::{
77
MTLCreateSystemDefaultDevice, MTLDataType, MTLDevice, MTLFunction, MTLFunctionConstantValues,
88
MTLLibrary, MTLResource, MTLResourceUsage, MTLSize,
99
};
10-
use std::{collections::HashMap, ffi::c_void, ptr, sync::Arc};
10+
use std::{
11+
collections::HashMap,
12+
ffi::c_void,
13+
ptr,
14+
sync::{Arc, Mutex},
15+
thread,
16+
};
1117

1218
// Use Retained when appropriate. Gives us a more elegant way of handling memory (peaks) than autoreleasepool.
1319
// https://docs.rs/objc2/latest/objc2/rc/struct.Retained.html
@@ -382,6 +388,7 @@ impl BlitCommandEncoder {
382388
}
383389

384390
pub type BufferMap = HashMap<(usize, MTLResourceOptions), Vec<Arc<Buffer>>>;
391+
type CommandBufferMap = HashMap<thread::ThreadId, CommandBuffer>;
385392
pub struct Commands {
386393
/// Single command queue for the entire device.
387394
command_queue: CommandQueue,
@@ -394,7 +401,7 @@ pub struct Commands {
394401
/// Despite what the documentation says, command buffers are NOT ordered. They are ordered
395402
/// for their START time, but there's no guarantee that command buffer1 will finish before
396403
/// command buffer2 starts (or there are metal bugs there)
397-
command_buffer: CommandBuffer,
404+
command_buffers: Arc<Mutex<CommandBufferMap>>,
398405
/// Keeps track of the current amount of compute command encoders on the current
399406
/// command buffer
400407
/// Arc, RwLock because of the interior mutability.
@@ -422,44 +429,60 @@ impl Commands {
422429
pub fn new(command_queue: CommandQueue) -> Result<Self, MetalKernelError> {
423430
let command_buffer = create_command_buffer(&command_queue)?;
424431
command_buffer.enqueue();
432+
let command_buffers = HashMap::from([(thread::current().id(), command_buffer)]);
433+
let command_buffers = Arc::new(Mutex::new(command_buffers));
434+
425435
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
426436
Ok(val) => val.parse().unwrap_or(50),
427437
_ => 50,
428438
};
429439
Ok(Self {
430440
command_queue,
431-
command_buffer,
441+
command_buffers,
432442
command_buffer_index: 0,
433443
compute_per_buffer,
434444
})
435445
}
436446

437447
pub fn command_buffer(&mut self) -> Result<(bool, CommandBuffer), MetalKernelError> {
438-
let mut command_buffer = self.command_buffer.to_owned();
448+
let mut command_buffers = self.command_buffers.lock()?;
449+
let command_buffer =
450+
command_buffers
451+
.get_mut(&thread::current().id())
452+
.ok_or(MetalKernelError::LockError(
453+
"Command buffer map".to_string(),
454+
))?;
455+
439456
let mut flushed = false;
440457
if self.command_buffer_index > self.compute_per_buffer {
441-
self.command_buffer.commit();
442-
command_buffer = create_command_buffer(&self.command_queue)?;
443-
self.command_buffer = command_buffer.clone();
458+
command_buffer.commit();
459+
*command_buffer = create_command_buffer(&self.command_queue)?;
444460
self.command_buffer_index = 0;
445461
flushed = true;
446462
}
447463
self.command_buffer_index += 1;
448-
Ok((flushed, command_buffer))
464+
Ok((flushed, command_buffer.clone()))
449465
}
450466

451467
pub fn wait_until_completed(&mut self) -> Result<(), MetalKernelError> {
452-
match self.command_buffer.status() {
468+
let mut command_buffers = self.command_buffers.lock()?;
469+
let command_buffer =
470+
command_buffers
471+
.get_mut(&thread::current().id())
472+
.ok_or(MetalKernelError::LockError(
473+
"Command buffer map".to_string(),
474+
))?;
475+
match command_buffer.status() {
453476
MTLCommandBufferStatus::Committed
454477
| MTLCommandBufferStatus::Scheduled
455478
| MTLCommandBufferStatus::Completed => {
456479
panic!("Already committed");
457480
}
458481
_ => {}
459482
}
460-
self.command_buffer.commit();
461-
self.command_buffer.wait_until_completed();
462-
self.command_buffer = create_command_buffer(&self.command_queue)?;
483+
command_buffer.commit();
484+
command_buffer.wait_until_completed();
485+
*command_buffer = create_command_buffer(&self.command_queue)?;
463486

464487
Ok(())
465488
}

0 commit comments

Comments
 (0)