Skip to content

Commit a9d9007

Browse files
committed
test: add concurrency serialization tests for shared MLX executor
Verify that the shared executor actually prevents concurrent execution. Two tests are added: (1) test_shared_executor_serializes_concurrent_tasks submits 4 tasks simultaneously and asserts max_concurrent == 1; (2) test_two_engine_loops_serialize_on_shared_executor runs the loops of two EngineCore instances with mock schedulers and asserts that their step() calls never overlap.
1 parent 2939153 commit a9d9007

File tree

1 file changed

+118
-0
lines changed

1 file changed

+118
-0
lines changed

tests/test_engine_core.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -870,3 +870,121 @@ def test_engines_share_mlx_executor(self, mock_model, mock_tokenizer):
870870
finally:
871871
engine1.close()
872872
engine2.close()
873+
874+
@pytest.mark.asyncio
async def test_shared_executor_serializes_concurrent_tasks(self):
    """Work handed to the shared executor must run strictly one at a time (#85).

    Four blocking jobs are submitted at once, standing in for two engines
    that each submit two steps. Every job increments a counter on entry and
    decrements it on exit; the highest value the counter reaches is the
    peak concurrency, which must stay at 1.
    """
    import threading
    import time
    from omlx.engine_core import get_mlx_executor

    shared_executor = get_mlx_executor()
    event_loop = asyncio.get_running_loop()

    # Submission order: the two "engines" alternate.
    step_labels = (
        "engine_a_step1",
        "engine_b_step1",
        "engine_a_step2",
        "engine_b_step2",
    )

    running = 0
    peak = 0
    counter_guard = threading.Lock()

    def fake_step(label: str, seconds: float = 0.05):
        """Stand in for a blocking scheduler.step() and record any overlap."""
        nonlocal running, peak
        with counter_guard:
            running += 1
            peak = max(peak, running)
        # Hold the worker long enough that overlapping jobs would be seen.
        time.sleep(seconds)
        with counter_guard:
            running -= 1
        return label

    # Queue every job before awaiting any of them so they compete for the
    # executor at the same time.
    pending = [
        event_loop.run_in_executor(shared_executor, fake_step, label)
        for label in step_labels
    ]
    finished = await asyncio.gather(*pending)

    # Every job ran to completion.
    assert set(finished) == set(step_labels)
    # The key check: at no point were two jobs inside fake_step together.
    assert peak == 1, (
        f"Expected max 1 concurrent task, got {peak}. "
        f"Shared executor failed to serialize MLX operations."
    )
923+
924+
@pytest.mark.asyncio
async def test_two_engine_loops_serialize_on_shared_executor(
    self, mock_model, mock_tokenizer
):
    """Two engines running their loops must serialize step() calls (#85).

    Creates two EngineCore instances, replaces each scheduler's step() with
    a stub that records how many stubs are executing at the same moment,
    starts both engines, lets them run briefly, and then checks that the
    stubs ran at least four times in total and never overlapped.

    Args:
        mock_model: Fixture passed to EngineCore as ``model``.
        mock_tokenizer: Fixture passed to EngineCore as ``tokenizer``.
    """
    import contextlib
    import threading
    import time

    active_count = 0
    max_concurrent = 0
    total_steps = 0
    lock = threading.Lock()

    def make_tracked_step():
        """Create a step function that tracks concurrency."""
        from omlx.scheduler import SchedulerOutput

        def tracked_step():
            nonlocal active_count, max_concurrent, total_steps
            with lock:
                active_count += 1
                total_steps += 1
                max_concurrent = max(max_concurrent, active_count)
            time.sleep(0.01)  # Simulate GPU work
            with lock:
                active_count -= 1
            return SchedulerOutput(outputs=[])

        return tracked_step

    with patch("omlx.engine_core.get_registry") as mock_registry:
        mock_registry.return_value.acquire.return_value = True

        engine1 = EngineCore(model=mock_model, tokenizer=mock_tokenizer)
        engine2 = EngineCore(model=mock_model, tokenizer=mock_tokenizer)

        # Wire up tracked step functions
        engine1.scheduler.step = make_tracked_step()
        engine2.scheduler.step = make_tracked_step()
        # Always report pending work so the stubbed step() keeps being called.
        engine1.scheduler.has_requests = lambda: True
        engine2.scheduler.has_requests = lambda: True

        # Teardown stays inside the patched scope so the engines are created
        # and destroyed under the same mocked registry.
        async with contextlib.AsyncExitStack() as teardown:
            # The exit stack runs every registered callback even when an
            # earlier one raises, so a failing stop() on one engine can no
            # longer leave the other engine running or unclosed. Callbacks
            # run last-in, first-out; registering in reverse keeps the
            # order stop(1), stop(2), close(1), close(2).
            teardown.callback(engine2.close)
            teardown.callback(engine1.close)
            teardown.push_async_callback(engine2.stop)
            teardown.push_async_callback(engine1.stop)

            await engine1.start()
            await engine2.start()

            # Let both engines run for a bit
            await asyncio.sleep(0.3)

    assert total_steps >= 4, (
        f"Expected at least 4 steps from two engines, got {total_steps}"
    )
    assert max_concurrent == 1, (
        f"Expected max 1 concurrent step(), got {max_concurrent}. "
        "Two engines ran MLX operations in parallel — would cause "
        "Metal command buffer races in production."
    )

0 commit comments

Comments
 (0)