@@ -7,7 +7,13 @@ use objc2_metal::{
77 MTLCreateSystemDefaultDevice , MTLDataType , MTLDevice , MTLFunction , MTLFunctionConstantValues ,
88 MTLLibrary , MTLResource , MTLResourceUsage , MTLSize ,
99} ;
10- use std:: { collections:: HashMap , ffi:: c_void, ptr, sync:: Arc } ;
10+ use std:: {
11+ collections:: HashMap ,
12+ ffi:: c_void,
13+ ptr,
14+ sync:: { Arc , Mutex } ,
15+ thread,
16+ } ;
1117
1218// Use Retained when appropriate. Gives us a more elegant way of handling memory (peaks) than autoreleasepool.
1319// https://docs.rs/objc2/latest/objc2/rc/struct.Retained.html
@@ -382,6 +388,7 @@ impl BlitCommandEncoder {
382388}
383389
384390pub type BufferMap = HashMap < ( usize , MTLResourceOptions ) , Vec < Arc < Buffer > > > ;
391+ type CommandBufferMap = HashMap < thread:: ThreadId , CommandBuffer > ;
385392pub struct Commands {
386393 /// Single command queue for the entire device.
387394 command_queue : CommandQueue ,
@@ -394,7 +401,7 @@ pub struct Commands {
394401 /// Despite what the documentation says, command buffers are NOT ordered. They are ordered
395402 /// for their START time, but there's no guarantee that command buffer1 will finish before
396403 /// command buffer2 starts (or there are metal bugs there)
397- command_buffer : CommandBuffer ,
404+ command_buffers : Arc < Mutex < CommandBufferMap > > ,
398405 /// Keeps track of the current amount of compute command encoders on the current
399406 /// command buffer
400407 /// Arc, RwLock because of the interior mutability.
@@ -422,44 +429,60 @@ impl Commands {
422429 pub fn new ( command_queue : CommandQueue ) -> Result < Self , MetalKernelError > {
423430 let command_buffer = create_command_buffer ( & command_queue) ?;
424431 command_buffer. enqueue ( ) ;
432+ let command_buffers = HashMap :: from ( [ ( thread:: current ( ) . id ( ) , command_buffer) ] ) ;
433+ let command_buffers = Arc :: new ( Mutex :: new ( command_buffers) ) ;
434+
425435 let compute_per_buffer = match std:: env:: var ( "CANDLE_METAL_COMPUTE_PER_BUFFER" ) {
426436 Ok ( val) => val. parse ( ) . unwrap_or ( 50 ) ,
427437 _ => 50 ,
428438 } ;
429439 Ok ( Self {
430440 command_queue,
431- command_buffer ,
441+ command_buffers ,
432442 command_buffer_index : 0 ,
433443 compute_per_buffer,
434444 } )
435445 }
436446
437447 pub fn command_buffer ( & mut self ) -> Result < ( bool , CommandBuffer ) , MetalKernelError > {
438- let mut command_buffer = self . command_buffer . to_owned ( ) ;
448+ let mut command_buffers = self . command_buffers . lock ( ) ?;
449+ let command_buffer =
450+ command_buffers
451+ . get_mut ( & thread:: current ( ) . id ( ) )
452+ . ok_or ( MetalKernelError :: LockError (
453+ "Command buffer map" . to_string ( ) ,
454+ ) ) ?;
455+
439456 let mut flushed = false ;
440457 if self . command_buffer_index > self . compute_per_buffer {
441- self . command_buffer . commit ( ) ;
442- command_buffer = create_command_buffer ( & self . command_queue ) ?;
443- self . command_buffer = command_buffer. clone ( ) ;
458+ command_buffer. commit ( ) ;
459+ * command_buffer = create_command_buffer ( & self . command_queue ) ?;
444460 self . command_buffer_index = 0 ;
445461 flushed = true ;
446462 }
447463 self . command_buffer_index += 1 ;
448- Ok ( ( flushed, command_buffer) )
464+ Ok ( ( flushed, command_buffer. clone ( ) ) )
449465 }
450466
451467 pub fn wait_until_completed ( & mut self ) -> Result < ( ) , MetalKernelError > {
452- match self . command_buffer . status ( ) {
468+ let mut command_buffers = self . command_buffers . lock ( ) ?;
469+ let command_buffer =
470+ command_buffers
471+ . get_mut ( & thread:: current ( ) . id ( ) )
472+ . ok_or ( MetalKernelError :: LockError (
473+ "Command buffer map" . to_string ( ) ,
474+ ) ) ?;
475+ match command_buffer. status ( ) {
453476 MTLCommandBufferStatus :: Committed
454477 | MTLCommandBufferStatus :: Scheduled
455478 | MTLCommandBufferStatus :: Completed => {
456479 panic ! ( "Already committed" ) ;
457480 }
458481 _ => { }
459482 }
460- self . command_buffer . commit ( ) ;
461- self . command_buffer . wait_until_completed ( ) ;
462- self . command_buffer = create_command_buffer ( & self . command_queue ) ?;
483+ command_buffer. commit ( ) ;
484+ command_buffer. wait_until_completed ( ) ;
485+ * command_buffer = create_command_buffer ( & self . command_queue ) ?;
463486
464487 Ok ( ( ) )
465488 }
0 commit comments