
Commit 65055f6

Merge pull request #3079 from huggingface/metal-tensor-fix-send-sync
[Metal] Ensure tensors are send/sync
2 parents: 0950959 + a7fbc63 · commit 65055f6

4 files changed, +58 -478 lines (diffs for three of the changed files are shown below)

candle-core/src/metal_backend/device.rs

Lines changed: 1 addition & 1 deletion
@@ -123,7 +123,7 @@ impl MetalDevice {
         if flushed {
             self.drop_unused_buffers()?
         }
-        Ok(command_buffer)
+        Ok(command_buffer.clone())
     }
 
     pub fn wait_until_completed(&self) -> Result<()> {
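
For context: judging by the CommandQueue type alias in candle-metal-kernels/src/metal/commands.rs below, CommandBuffer is presumably a Retained handle to an Objective-C MTLCommandBuffer, so cloning it is just a reference-count bump and both handles refer to the same underlying Metal object. The sketch below illustrates that pattern with Arc standing in for Retained; the Holder type and names are hypothetical, not candle code:

use std::sync::Arc;

// Arc stands in for objc2's Retained here: both give cheap,
// reference-counted clones that point at the same underlying object.
struct Holder {
    handle: Arc<String>,
}

impl Holder {
    // Hand out a clone instead of trying to move the stored handle out of &self.
    fn command_buffer(&self) -> Arc<String> {
        self.handle.clone()
    }
}

fn main() {
    let holder = Holder {
        handle: Arc::new("command buffer".to_string()),
    };
    let a = holder.command_buffer();
    let b = holder.command_buffer();
    // Both clones refer to the same allocation.
    assert!(Arc::ptr_eq(&a, &b));
}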

candle-core/tests/tensor_tests.rs

Lines changed: 21 additions & 0 deletions
@@ -1694,6 +1694,27 @@ test_device!(asort, asort_cpu, asort_gpu, asort_metal);
 test_device!(var, var_cpu, var_gpu, var_metal);
 test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal);
 
+fn tensor_send_sync(device: &Device) -> Result<()> {
+    let tensor = Tensor::new(vec![1.0f32, 2.0, 3.0], device)?;
+
+    for _ in 0..10 {
+        let tensor = tensor.clone();
+        std::thread::spawn(move || {
+            let new = tensor.add(&tensor).unwrap();
+            let result: Vec<f32> = new.to_vec1().unwrap();
+            assert_eq!(result, vec![2.0f32, 4.0, 6.0]);
+        });
+    }
+
+    Ok(())
+}
+test_device!(
+    tensor_send_sync,
+    tensor_send_sync_cpu,
+    tensor_send_sync_gpu,
+    tensor_send_sync_metal
+);
+
 // There was originally a bug on the CPU implementation for randn
 // https://github.com/huggingface/candle/issues/381
 #[test]
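
The new test exercises Tensor's Send + Sync bounds by cloning a tensor into spawned threads. Note that the handles are never joined, so a panic inside a worker thread would not necessarily fail the test. A standalone variant that does join the threads could look like the sketch below; it is a hedged example assuming the public candle_core API (Device::Cpu, Tensor::new, Tensor::add, Tensor::to_vec1), not part of this commit:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu; // on macOS, Device::new_metal(0)? exercises the Metal path
    let tensor = Tensor::new(vec![1.0f32, 2.0, 3.0], &device)?;

    // Spawn worker threads, keep the JoinHandles, and join them so that a
    // panic inside any thread propagates and fails the program.
    let handles: Vec<_> = (0..10)
        .map(|_| {
            let tensor = tensor.clone();
            std::thread::spawn(move || {
                let new = tensor.add(&tensor).unwrap();
                let result: Vec<f32> = new.to_vec1().unwrap();
                assert_eq!(result, vec![2.0f32, 4.0, 6.0]);
            })
        })
        .collect();

    for handle in handles {
        handle.join().expect("worker thread panicked");
    }
    Ok(())
}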

candle-metal-kernels/src/metal/commands.rs

Lines changed: 36 additions & 11 deletions
@@ -2,12 +2,19 @@ use crate::metal::CommandBuffer;
 use crate::MetalKernelError;
 use objc2::{rc::Retained, runtime::ProtocolObject};
 use objc2_metal::{MTLCommandBufferStatus, MTLCommandQueue, MTLCounterSet};
+use std::{
+    collections::HashMap,
+    sync::{Arc, Mutex},
+    thread,
+};
 
 // Use Retained when appropriate. Gives us a more elegant way of handling memory (peaks) than autoreleasepool.
 // https://docs.rs/objc2/latest/objc2/rc/struct.Retained.html
 pub type CommandQueue = Retained<ProtocolObject<dyn MTLCommandQueue>>;
 pub type CounterSet = Retained<ProtocolObject<dyn MTLCounterSet>>;
 
+type CommandBufferMap = HashMap<thread::ThreadId, CommandBuffer>;
+
 pub struct Commands {
     /// Single command queue for the entire device.
     command_queue: CommandQueue,
@@ -20,13 +27,15 @@ pub struct Commands {
     /// Despite what the documentation says, command buffers are NOT ordered. They are ordered
     /// for their START time, but there's no guarantee that command buffer1 will finish before
     /// command buffer2 starts (or there are metal bugs there)
-    command_buffer: CommandBuffer,
+    command_buffers: Arc<Mutex<CommandBufferMap>>,
     /// Keeps track of the current amount of compute command encoders on the current
     /// command buffer
     /// Arc, RwLock because of the interior mutability.
     command_buffer_index: usize,
     /// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
     compute_per_buffer: usize,
+    //capture: Option<Retained<MTLCaptureManager>>,
+    //timestamp_counter_set: Option<CounterSet>,
 }
 unsafe impl Send for Commands {}
 unsafe impl Sync for Commands {}
@@ -43,44 +52,60 @@ impl Commands {
     pub fn new(command_queue: CommandQueue) -> Result<Self, MetalKernelError> {
         let command_buffer = create_command_buffer(&command_queue)?;
         command_buffer.enqueue();
+        let command_buffers = HashMap::from([(thread::current().id(), command_buffer)]);
+        let command_buffers = Arc::new(Mutex::new(command_buffers));
+
         let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
             Ok(val) => val.parse().unwrap_or(50),
             _ => 50,
         };
         Ok(Self {
             command_queue,
-            command_buffer,
+            command_buffers,
             command_buffer_index: 0,
             compute_per_buffer,
         })
     }
 
     pub fn command_buffer(&mut self) -> Result<(bool, CommandBuffer), MetalKernelError> {
-        let mut command_buffer = self.command_buffer.to_owned();
+        let mut command_buffers = self.command_buffers.lock()?;
+        let command_buffer =
+            command_buffers
+                .get_mut(&thread::current().id())
+                .ok_or(MetalKernelError::LockError(
+                    "Command buffer map".to_string(),
+                ))?;
+
         let mut flushed = false;
         if self.command_buffer_index > self.compute_per_buffer {
-            self.command_buffer.commit();
-            command_buffer = create_command_buffer(&self.command_queue)?;
-            self.command_buffer = command_buffer.clone();
+            command_buffer.commit();
+            *command_buffer = create_command_buffer(&self.command_queue)?;
             self.command_buffer_index = 0;
            flushed = true;
         }
         self.command_buffer_index += 1;
-        Ok((flushed, command_buffer))
+        Ok((flushed, command_buffer.clone()))
     }
 
     pub fn wait_until_completed(&mut self) -> Result<(), MetalKernelError> {
-        match self.command_buffer.status() {
+        let mut command_buffers = self.command_buffers.lock()?;
+        let command_buffer =
+            command_buffers
+                .get_mut(&thread::current().id())
+                .ok_or(MetalKernelError::LockError(
+                    "Command buffer map".to_string(),
+                ))?;
+        match command_buffer.status() {
             MTLCommandBufferStatus::Committed
             | MTLCommandBufferStatus::Scheduled
             | MTLCommandBufferStatus::Completed => {
                 panic!("Already committed");
             }
             _ => {}
         }
-        self.command_buffer.commit();
-        self.command_buffer.wait_until_completed();
-        self.command_buffer = create_command_buffer(&self.command_queue)?;
+        command_buffer.commit();
+        command_buffer.wait_until_completed();
+        *command_buffer = create_command_buffer(&self.command_queue)?;
 
         Ok(())
     }
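
The core of the change is the switch from a single shared command buffer to a thread-keyed map (Arc<Mutex<HashMap<ThreadId, CommandBuffer>>>), so each thread encodes work into its own buffer and no buffer is committed or swapped out from under another thread. Candle errors with MetalKernelError::LockError when the current thread has no entry, and presumably inserts entries for new threads elsewhere (not visible in this hunk). Below is a self-contained sketch of the same thread-keyed-map idea with a placeholder buffer type; FakeBuffer, BufferPool, and buffer_for_current_thread are hypothetical names for illustration only, not candle's actual API:

use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread::{self, ThreadId};

// Placeholder for a Metal command buffer.
#[derive(Clone, Debug)]
struct FakeBuffer {
    id: usize,
}

#[derive(Clone, Default)]
struct BufferPool {
    buffers: Arc<Mutex<HashMap<ThreadId, FakeBuffer>>>,
    next_id: Arc<Mutex<usize>>,
}

impl BufferPool {
    // Get (or lazily create) the calling thread's buffer. Candle instead
    // returns an error when the entry is missing and fills the map elsewhere.
    fn buffer_for_current_thread(&self) -> FakeBuffer {
        let mut buffers = self.buffers.lock().unwrap();
        buffers
            .entry(thread::current().id())
            .or_insert_with(|| {
                let mut next = self.next_id.lock().unwrap();
                *next += 1;
                FakeBuffer { id: *next }
            })
            .clone()
    }
}

fn main() {
    let pool = BufferPool::default();
    let main_buf = pool.buffer_for_current_thread();

    let pool2 = pool.clone();
    let other_buf = thread::spawn(move || pool2.buffer_for_current_thread())
        .join()
        .unwrap();

    // Each thread sees its own buffer.
    assert_ne!(main_buf.id, other_buf.id);
    println!("main thread buffer {:?}, worker buffer {:?}", main_buf, other_buf);
}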
