
[BUG] Llama 3.1 8b Server and SingleStream Scenario fail #2341

@catan2001

Description


I observed that the Server and SingleStream scenarios fail in the MLPerf reference implementation, while the Offline scenario runs fine. This appears to be related to the change made in PR #2314, which switched the SUT used for the SingleStream scenario from SUT (Offline) to SUTServer:

-    sut_map = {"offline": SUT, "server": SUTServer, "singlestream": SUT}
+    sut_map = {"offline": SUT, "server": SUTServer, "singlestream": SUTServer}

As far as I can see, an issue about this was already opened (#2267), where it was discussed that the Server scenario will not be used for the v5.1 inference submission; however, that issue covers only the Server scenario, not SingleStream.

After that change, SingleStream also began failing with an EngineDeadError, the same error that affects the Server scenario. Prior to the change, SingleStream used the plain SUT implementation and did not exhibit this crash.

Error details:

INFO:Llama-8B-SUT:Loaded model 
INFO:Llama-8B-MAIN:Starting Benchmark run 
Exception in thread Thread-1 (process_queries):
Traceback (most recent call last):
  File "/usr/lib/python3.12/threading.py", line 1075, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.12/threading.py", line 1012, in run
    self._target(*self._args, **self._kwargs)
  File "/root/inference/language/llama3.1-8b/SUT_VLLM.py", line 275, in process_queries
    asyncio.run(self.stream_output(qitem, results_generator))
  File "/usr/lib/python3.12/asyncio/runners.py", line 195, in run
    return runner.run(main)
  File "/usr/lib/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
  File "/usr/lib/python3.12/asyncio/base_events.py", line 691, in run_until_complete
    return future.result()
  File "/root/inference/language/llama3.1-8b/SUT_VLLM.py", line 231, in stream_output
    async for request_output in results_generator:
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/async_llm.py", line 321, in generate
    q = await self.add_request(
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/async_llm.py", line 243, in add_request
    raise EngineDeadError()
vllm.v1.engine.exceptions.EngineDeadError: EngineCore encountered an issue.

Root cause

This failure is triggered by incorrect usage of the asyncio library. asyncio is designed to manage asynchronous tasks through a single event loop, but here multiple threads create and run their own loops without coordination.

The sequence looks like this:

  1. process_queries (Thread-1)
    This function is spawned in a new thread by the SUTServer class used for the Server and SingleStream scenarios. It takes queries from LoadGen and calls:
asyncio.run(self.stream_output(qitem, results_generator))

asyncio.run() creates a fresh event loop for this thread.

  2. stream_output
    This is an async function that streams partial outputs from the model back to LoadGen. Inside, it iterates:
async for request_output in results_generator:

which invokes the vLLM engine generator.

  3. vLLM AsyncLLMEngine.generate
    This calls await self.add_request(...) to enqueue the inference request.

  4. add_request
    This is where the problem shows up. The engine already has its own loop and lifecycle, but because process_queries was launched in a separate thread with its own asyncio.run(), the request is scheduled on a loop that does not belong to the engine. The engine core detects this mismatch and raises EngineDeadError.

The root cause is that every thread created by SUTServer runs asyncio.run() independently. Instead of a single event loop shared across the engine and request-handling tasks, there are multiple isolated loops. Since the vLLM engine is not thread-safe with respect to event loops, requests cannot be submitted reliably, causing the engine to fail.
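
To illustrate the underlying asyncio constraint in isolation, here is a standalone sketch (illustrative names only, no vLLM involved) of what happens when an awaitable created under one event loop is awaited from a second loop running in another thread. Plain asyncio reports this as an "attached to a different loop" RuntimeError; per the analysis above, the same kind of mismatch is what surfaces as the EngineDeadError in the traceback.

import asyncio
import threading

async def create_future():
    # The future returned here is bound to the event loop of this thread.
    return asyncio.get_running_loop().create_future()

async def await_future(fut):
    # Awaiting it from a task running on a *different* loop fails.
    await fut

def main():
    fut = asyncio.run(create_future())        # loop 1, main thread

    def worker():
        try:
            asyncio.run(await_future(fut))    # loop 2, worker thread
        except RuntimeError as exc:
            # "Task ... got Future ... attached to a different loop"
            print(f"worker thread failed: {exc}")

    t = threading.Thread(target=worker)
    t.start()
    t.join()

if __name__ == "__main__":
    main()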

Potential Solution

A potential solution is to avoid spawning a new asyncio loop per thread and per query. Instead, a single global asyncio event loop can be created and kept alive for the lifetime of the SUT, and worker threads submit coroutines to it with asyncio.run_coroutine_threadsafe() instead of calling asyncio.run() directly. This keeps all async work on one controlled loop and prevents conflicts inside the engine. A second, much simpler option is to create one asyncio loop per worker thread, but outside of the while loop that waits for queries, so the loop is not recreated for every query (see the sketch below).
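
A minimal, untested sketch of that second option: a replacement process_queries method for SUTServer, reusing the attribute names from the reference SUT_VLLM.py (self.query_queue, self.data_object, self.model, self.sampling_params), that creates the per-thread loop once, before the query loop, instead of calling asyncio.run() for every query. The full persistent-loop implementation of the first option follows in the next section.

def process_queries(self):
    """Worker thread: one event loop per thread, created once."""
    # Create the loop before the query loop so it is reused for every
    # query handled by this thread, instead of being rebuilt by
    # asyncio.run() on each call.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        while True:
            qitem = self.query_queue.get()
            if qitem is None:
                break

            input_ids_tensor = TokensPrompt(
                prompt_token_ids=self.data_object.input_ids[qitem.index]
            )
            results_generator = self.model.generate(
                prompt=input_ids_tensor,
                sampling_params=self.sampling_params,
                request_id=str(self.request_id),
            )
            self.request_id += 1

            # Drive the streaming coroutine to completion on this
            # thread's own, long-lived loop.
            loop.run_until_complete(
                self.stream_output(qitem, results_generator)
            )
    finally:
        loop.close()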

Potential Implementation

class SUTServer(SUT):
    def __init__(
        self,
        model_path=None,
        dtype="bfloat16",
        total_sample_count=13368,
        dataset_path=None,
        batch_size=1,
        workers=1,
        tensor_parallel_size=8,
    ):
        super().__init__(
            model_path=model_path,
            dtype=dtype,
            total_sample_count=total_sample_count,
            dataset_path=dataset_path,
            workers=workers,
            tensor_parallel_size=tensor_parallel_size,
        )
        self.request_id = 0
        self.first_token_queue = queue.Queue()
        # Keep the batch size used by process_queries when collecting queries.
        self.batch_size = batch_size

        # Create a Persistent async loop for whole duration of benchmark
        self.loop = asyncio.new_event_loop()
        self.loop_thread = threading.Thread(
            target=self._run_event_loop, daemon=True
        )
        self.loop_thread.start()

    def _run_event_loop(self):
        asyncio.set_event_loop(self.loop)
        self.loop.run_forever()

    def start(self):
        # Create worker threads
        for j in range(self.num_workers):
            worker = threading.Thread(target=self.process_queries)
            worker.start()
            self.worker_threads[j] = worker

    async def stream_output(self, qitem, results_generator):
        """Async streaming of model outputs"""
        first = True
        async for request_output in results_generator:
            output_response = request_output
            if first:
                first_tokens = list(output_response.outputs[0].token_ids)
                response_data = array.array(
                    "B", np.array(first_tokens, np.int32).tobytes()
                )
                bi = response_data.buffer_info()
                response = [lg.QuerySampleResponse(qitem.id, bi[0], bi[1])]
                lg.FirstTokenComplete(response)
                first = False

        # Final tokens
        outputs = output_response
        pred_output_tokens = list(outputs.outputs[0].token_ids)
        n_tokens = len(pred_output_tokens)
        response_array = array.array(
            "B", np.array(pred_output_tokens, np.int32).tobytes()
        )
        bi = response_array.buffer_info()
        response = [
            lg.QuerySampleResponse(qitem.id, bi[0], bi[1], n_tokens)
        ]
        lg.QuerySamplesComplete(response)

    def process_queries(self):
        """Process queued queries concurrently"""
        while True:
            # Collect up to batch_size queries
            batch = []
            try:
                qitem = self.query_queue.get()
                if qitem is None:
                    break
                batch.append(qitem)

                while len(batch) < self.batch_size:
                    try:
                        qitem = self.query_queue.get_nowait()
                        if qitem is None:
                            break
                        batch.append(qitem)
                    except queue.Empty:
                        break  # queue empty, process what we have
            except Exception:
                continue

            # For each query in the batch, submit async request to vLLM
            futures = []
            for qitem in batch:
                input_ids_tensor = TokensPrompt(
                    prompt_token_ids=self.data_object.input_ids[qitem.index]
                )
                results_generator = self.model.generate(
                    prompt=input_ids_tensor,
                    sampling_params=self.sampling_params,
                    request_id=str(self.request_id),
                )
                self.request_id += 1

                # Schedule streaming of results in persistent loop
                fut = asyncio.run_coroutine_threadsafe(
                    self.stream_output(qitem, results_generator), self.loop
                )
                futures.append(fut)

    def issue_queries(self, query_samples):
        # For Server scenario, LoadGen may give batches.
        # You could extend this to enqueue multiple items at once.
        for q in query_samples:
            self.query_queue.put(q)

    def stop(self):
        for _ in range(self.num_workers):
            self.query_queue.put(None)

        for worker in self.worker_threads:
            worker.join()

        self.loop.call_soon_threadsafe(self.loop.stop)
        self.loop_thread.join()

        self.first_token_queue.put(None)
        # self.ft_response_thread.join() 

    def load_model(self):
        log.info("Loading model")
        self.engine_args = AsyncEngineArgs(
            self.model_path,
            dtype=self.dtype,
            tensor_parallel_size=self.tensor_parallel_size,
        )
        self.model = AsyncLLMEngine.from_engine_args(self.engine_args)
        log.info("Loaded model")

Results

================================================
MLPerf Results Summary
================================================
SUT name : PySUT
Scenario : Server
Mode     : PerformanceOnly
Completed samples per second    : 11.07
Completed tokens per second: 1264.46
Result is : VALID
  Performance constraints satisfied : Yes
  Min duration satisfied : Yes
  Min queries satisfied : Yes
  Early stopping satisfied: Yes
TTFT Early Stopping Result:
 * Run successful.
TPOT Early Stopping Result:
 * Run successful.

================================================
Additional Stats
================================================
Scheduled samples per second : 11.08
Min latency (ns)                : 318953861
Max latency (ns)                : 11995097448
Mean latency (ns)               : 4335286743
50.00 percentile latency (ns)   : 4004045558
90.00 percentile latency (ns)   : 6909235804
95.00 percentile latency (ns)   : 8016516656
97.00 percentile latency (ns)   : 8854007961
99.00 percentile latency (ns)   : 10112658342
99.90 percentile latency (ns)   : 11650218811

Completed tokens per second                 : 1264.46
Min First Token latency (ns)                : 20091116
Max First Token latency (ns)                : 1316118108
Mean First Token latency (ns)               : 169476022
50.00 percentile first token latency (ns)   : 126097403
90.00 percentile first token latency (ns)   : 341909500
95.00 percentile first token latency (ns)   : 426958676
97.00 percentile first token latency (ns)   : 493381087
99.00 percentile first token latency (ns)   : 634962570
99.90 percentile first token latency (ns)   : 981248686

Min Time to Output Token (ns)                : 11724943
Max Time to Output Token (ns)                : 103256013
Mean Time to Output Token (ns)               : 36769110
50.00 percentile time to output token (ns)   : 33408832
90.00 percentile time to output token (ns)   : 56929232
95.00 percentile time to output token (ns)   : 66122369
97.00 percentile time to output token (ns)   : 71239265
99.00 percentile time to output token (ns)   : 80095302
99.90 percentile time to output token (ns)   : 88555530

================================================
Test Parameters Used
================================================
samples_per_query : 1
target_qps : 11
ttft_latency (ns): 2000000000
tpot_latency (ns): 100000000
max_async_queries : 0
min_duration (ms): 600000
max_duration (ms): 0
min_query_count : 100
max_query_count : 0
qsl_rng_seed : 1780908523862526354
sample_index_rng_seed : 14771362308971278857
schedule_rng_seed : 18209322760996052031
accuracy_log_rng_seed : 0
accuracy_log_probability : 0
accuracy_log_sampling_target : 0
print_timestamps : 0
performance_issue_unique : 0
performance_issue_same : 0
performance_issue_same_index : 0
performance_sample_count : 13368
WARNING: sample_concatenate_permutation was set to true. 
Generated samples per query might be different as the one in the setting.
Check the generated_samples_per_query line in the detailed log for the real
samples_per_query value

No warnings encountered during test.

No errors encountered during test.
