@@ -870,3 +870,121 @@ def test_engines_share_mlx_executor(self, mock_model, mock_tokenizer):
870870 finally :
871871 engine1 .close ()
872872 engine2 .close ()
873+
@pytest.mark.asyncio
async def test_shared_executor_serializes_concurrent_tasks(self):
    """Work queued on the shared MLX executor runs strictly one at a time (#85).

    Four blocking steps, labelled as if they came from two different
    engines, are handed to the executor returned by ``get_mlx_executor()``
    in one go. Each step bumps a counter of in-flight steps on entry and
    drops it on exit; the highest value the counter ever reaches must be 1.
    """
    import threading
    import time
    from omlx.engine_core import get_mlx_executor

    # Submission order alternates between the two simulated engines.
    step_labels = (
        "engine_a_step1",
        "engine_b_step1",
        "engine_a_step2",
        "engine_b_step2",
    )

    shared_executor = get_mlx_executor()
    event_loop = asyncio.get_running_loop()

    in_flight = 0
    max_concurrent = 0
    guard = threading.Lock()

    def fake_step(label: str, duration: float = 0.05):
        """Stand-in for scheduler.step(): block for ``duration`` seconds."""
        nonlocal in_flight, max_concurrent
        with guard:
            in_flight += 1
            # Record the high-water mark of steps running simultaneously.
            max_concurrent = max(max_concurrent, in_flight)
        time.sleep(duration)
        with guard:
            in_flight -= 1
        return label

    # Queue everything before awaiting anything so the steps compete for
    # the executor at the same time.
    pending = [
        event_loop.run_in_executor(shared_executor, fake_step, label)
        for label in step_labels
    ]
    finished = await asyncio.gather(*pending)

    # Every submitted step ran to completion.
    assert set(finished) == set(step_labels)
    # The key property: no two steps were ever inside fake_step together.
    assert max_concurrent == 1, (
        f"Expected max 1 concurrent task, got { max_concurrent } . "
        f"Shared executor failed to serialize MLX operations."
    )
923+
@pytest.mark.asyncio
async def test_two_engine_loops_serialize_on_shared_executor(
    self, mock_model, mock_tokenizer
):
    """Two running engines must never execute scheduler.step() at once (#85).

    Builds two EngineCore instances, replaces each scheduler's ``step`` with
    a function that tracks how many steps are in flight, forces
    ``has_requests`` to report True, starts both engines and lets them run
    for 0.3 seconds. The test then asserts that at least four steps ran in
    total and that no two of them ever overlapped.

    Cleanup is registered on an AsyncExitStack as soon as each resource
    exists, so every ``stop()`` and ``close()`` is attempted even when an
    earlier cleanup call (or the construction of the second engine) raises.
    """
    import contextlib
    import threading
    import time

    from omlx.scheduler import SchedulerOutput

    active_count = 0
    max_concurrent = 0
    total_steps = 0
    lock = threading.Lock()

    def make_tracked_step():
        """Create a step function that tracks concurrency."""

        def tracked_step():
            nonlocal active_count, max_concurrent, total_steps
            with lock:
                active_count += 1
                total_steps += 1
                if active_count > max_concurrent:
                    max_concurrent = active_count
            time.sleep(0.01)  # Simulate GPU work
            with lock:
                active_count -= 1
            return SchedulerOutput(outputs=[])

        return tracked_step

    with patch("omlx.engine_core.get_registry") as mock_registry:
        mock_registry.return_value.acquire.return_value = True

        async with contextlib.AsyncExitStack() as cleanup:
            engines = []
            for _ in range(2):
                engine = EngineCore(model=mock_model, tokenizer=mock_tokenizer)
                # Register close() immediately: if building the next engine
                # fails, this one is still released.
                cleanup.callback(engine.close)

                # Wire up tracked step functions
                engine.scheduler.step = make_tracked_step()
                engine.scheduler.has_requests = lambda: True
                engines.append(engine)

            # Register every stop() before starting anything, so both
            # engines are stopped even if a start() call raises. Callbacks
            # unwind in reverse order, hence all stop() calls run before
            # any close() call.
            for engine in engines:
                cleanup.push_async_callback(engine.stop)

            for engine in engines:
                await engine.start()

            # Let both engines run for a bit
            await asyncio.sleep(0.3)

    assert total_steps >= 4, (
        f"Expected at least 4 steps from two engines, got { total_steps } "
    )
    assert max_concurrent == 1, (
        f"Expected max 1 concurrent step(), got { max_concurrent } . "
        f"Two engines ran MLX operations in parallel — would cause "
        f"Metal command buffer races in production."
    )
0 commit comments