
Commit 0a721da

Support for Continuous Batching (microsoft#1580)
Batching is a key technique for improving throughput in large language models by allowing multiple requests to be processed simultaneously. By grouping requests, models can leverage parallel processing to handle more queries per unit of time. Currently, `onnxruntime-genai` performs inference using **static batching** ([see explanation of static vs. dynamic batching here](https://www.anyscale.com/blog/continuous-batching-llm-inference#llm-batching-explained)). In this mode, the batch state is tightly coupled to the `Generator`, which restricts the ability to process multiple requests dynamically or concurrently. To enable **dynamic batching** in `onnxruntime-genai`, this pull request introduces several key concepts that decouple the batch state from the generation engine, paving the way for more efficient request handling.

### What's New?

#### OgaRequest

The `OgaRequest` encapsulates the state of an application request. It tracks the request's progress, including the input tokens, the generation parameters, and the current state of the generation process. Requests are processed concurrently by the engine, which dynamically batches them for efficient model execution.

#### OgaEngine

The `OgaEngine` manages concurrent requests. It schedules them and batches them dynamically, in addition to running the core generation logic (model execution).

### Example Usage

#### Python

```python
import onnxruntime_genai as og

model = og.Model(config)
engine = og.Engine(model)
tokenizer = og.Tokenizer(model)
streaming_tokenizer = tokenizer.create_stream()

params = og.GeneratorParams(model)
request = og.Request(params)
request.add_tokens(tokenizer.encode(tokenizer.apply_chat_template(...)))

engine.add_request(request)

while ready_request := engine.step():
    while ready_request.has_unseen_tokens():
        print(streaming_tokenizer.decode(ready_request.get_unseen_token()), end="", flush=True)

engine.remove_request(request)
```

#### C++

```cpp
#include "ort_genai.h"

auto model = OgaModel::Create(PHI2_PATH);
auto engine = OgaEngine::Create(*model);
auto tokenizer = OgaTokenizer::Create(*model);
auto streaming_tokenizer = OgaTokenizerStream::Create(*tokenizer);

const std::string prompt = GetPrompt(...);
auto sequence = OgaSequences::Create();
tokenizer->Encode(prompt.c_str(), *sequence);

auto params = OgaGeneratorParams::Create(*model);
auto request = OgaRequest::Create(*params);
request->AddTokens(*sequence);

engine->Add(*request);

while (auto ready_request = engine->Step()) {
  while (ready_request->HasUnseenTokens()) {
    std::cout << streaming_tokenizer->Decode(ready_request->GetUnseenToken());
  }
}

engine->Remove(*request);
```

### What About the OgaGenerator?

This pull request decouples the batch state from the generation engine, addressing one of the main limitations of the `OgaGenerator`: its tight coupling with the batch state. That design constraint restricted the `OgaGenerator` to static batching only. With the introduction of `OgaEngine` and `OgaRequest`, the need for `OgaGenerator` in this role is effectively eliminated.

### OgaEngine vs. OgaGenerator

| Feature | `OgaEngine` | `OgaGenerator` |
|---------|-------------|----------------|
| Continuous Static Batching | ✅ | ❌ |
| Continuous Dynamic Batching | ⏳ | ❌ |
| Language Models | ✅ | ✅ |
| Other Model Types | ⏳ | ✅ |
| Beam Search | ⏳ | ✅ |
| Adapters Support | ⏳ | ✅ |
| DML Support | ❌ | ✅ |
| CUDA, CPU, WebGPU Support | ✅ | ✅ |
| Other EPs | ⏳ | ✅ |
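To make the scheduling model concrete, here is a minimal sketch (not part of this commit) that serves several requests from a single `engine.step()` loop. It uses only the Python API shown above and in the new examples; the model path and prompts are placeholders, and per-request state is carried through the opaque-data hook, much as the benchmark example below carries a pointer to its request pool.

```python
import json

import onnxruntime_genai as og

# Placeholder path: point this at any model folder the engine supports.
config = og.Config("path/to/model")
model = og.Model(config)
tokenizer = og.Tokenizer(model)
engine = og.Engine(model)

# Submit several independent requests to the same engine.
for prompt in ["What is ONNX Runtime?", "Explain continuous batching."]:
    params = og.GeneratorParams(model)
    params.set_search_options(do_sample=False, max_length=256)

    request = og.Request(params)
    messages = json.dumps(
        [{"role": "system", "content": ""}, {"role": "user", "content": prompt}]
    )
    request.add_tokens(
        tokenizer.encode(
            tokenizer.apply_chat_template(messages=messages, add_generation_prompt=True)
        )
    )
    # Attach a per-request streaming detokenizer via the opaque-data hook.
    request.set_opaque_data(tokenizer.create_stream())
    engine.add_request(request)

# Each step() runs one dynamically batched iteration and hands back a request
# that has newly generated ("unseen") tokens; the loop ends once no requests remain.
while ready := engine.step():
    stream = ready.get_opaque_data()
    while ready.has_unseen_tokens():
        print(stream.decode(ready.get_unseen_token()), end="", flush=True)
    if ready.is_done():
        print()
        engine.remove_request(ready)
```

Requests added later, for example from another thread as the benchmark example below does, are folded into the running batch the same way.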
1 parent 87be827 commit 0a721da

31 files changed · +2268 −27 lines

cmake/global_variables.cmake

Lines changed: 5 additions & 0 deletions
```diff
@@ -31,6 +31,7 @@ set(REPO_ROOT ${PROJECT_SOURCE_DIR})
 set(SRC_ROOT ${REPO_ROOT}/src)
 set(GENERATORS_ROOT ${SRC_ROOT})
 set(MODELS_ROOT ${SRC_ROOT}/models)
+set(ENGINE_ROOT ${SRC_ROOT}/engine)
 
 # Define the dependency libraries
 
@@ -79,6 +80,10 @@ file(GLOB generator_srcs CONFIGURE_DEPENDS
   "${GENERATORS_ROOT}/openvino/*.cpp"
   "${MODELS_ROOT}/*.h"
   "${MODELS_ROOT}/*.cpp"
+  "${ENGINE_ROOT}/*.h"
+  "${ENGINE_ROOT}/*.cpp"
+  "${ENGINE_ROOT}/decoders/*.h"
+  "${ENGINE_ROOT}/decoders/*.cpp"
 )
 
 set(ortgenai_embed_libs "")  # shared libs that will be embedded inside the onnxruntime-genai package
```
Lines changed: 186 additions & 0 deletions
```python
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import argparse
import json
import random
import threading
import time

import onnxruntime_genai as og
import tqdm
from datasets import load_dataset


def get_random_prompts(num_questions: int, split="validation") -> list[str]:
    dataset = load_dataset("squad_v2", split=split)
    questions = [item["question"] for item in dataset]
    return random.sample(questions, min(num_questions, len(questions)))


# Pairs one og.Request with its prompt, streaming detokenizer, and accumulated output.
class ClientRequest:
    def __init__(
        self, prompt: str, model: og.Model, tokenizer: og.Tokenizer, opaque_data: any
    ):
        self.prompt = prompt
        self.params = og.GeneratorParams(model)
        self.params.set_search_options(
            do_sample=False,
            max_length=256,
        )

        messages = [
            {"role": "system", "content": ""},
            {"role": "user", "content": f"{prompt}"},
        ]
        messages = json.dumps(messages)

        self.request = og.Request(self.params)
        self.request.add_tokens(
            tokenizer.encode(
                tokenizer.apply_chat_template(
                    messages=messages, add_generation_prompt=True
                )
            )
        )
        # opaque_data (the owning RequestPool) travels with the request through the engine.
        self.request.set_opaque_data(opaque_data)
        self.streaming_tokenizer = tokenizer.create_stream()
        self.token_stream = ""


# Tracks in-flight ClientRequests: a producer thread fills the pool while the engine drains it.
class RequestPool:
    def __init__(
        self,
        model: og.Model,
        tokenizer: og.Tokenizer,
        engine: og.Engine,
        num_requests: int,
        load_factor: float = 0.2,
        debug: bool = False,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.engine = engine
        self.num_requests = num_requests
        self.requests: list[ClientRequest] = []
        self.prompts = get_random_prompts(num_requests)
        self.load_factor = load_factor
        self.lock = threading.Lock()
        self.bar = tqdm.tqdm(total=len(self.prompts))
        self.debug = debug

        # Add load_factor * num_requests requests to the engine
        for prompt in self.prompts[: int(num_requests * load_factor)]:
            request = ClientRequest(prompt, model, tokenizer, self)
            self.requests.append(request)
            self.engine.add_request(request.request)

    def fill(self):
        for i, prompt in enumerate(
            self.prompts[int(len(self.prompts) * self.load_factor) :]
        ):
            request = ClientRequest(prompt, self.model, self.tokenizer, self)
            with self.lock:
                self.requests.append(request)
                self.engine.add_request(request.request)
            time.sleep(1)  # Simulate some delay in request generation

    # Called from the engine loop: stream out newly generated tokens and retire finished requests.
    def drain(self, request: og.Request):
        with self.lock:
            client_request = next(
                (r for r in self.requests if r.request == request), None
            )
            while request.has_unseen_tokens():
                token = request.get_unseen_token()
                client_request.token_stream += (
                    client_request.streaming_tokenizer.decode(token)
                )

            if request.is_done():
                assert (
                    client_request is not None
                ), "Client request not found in the pool"

                if self.debug:
                    print(f"🫵 : {client_request.prompt}")
                    print(f"🤖 : {client_request.token_stream}")
                self.engine.remove_request(request)
                self.requests.remove(client_request)
                self.bar.update(1)


class Engine:
    def __init__(self, model_path: str, execution_provider: str, debug: bool):
        self.config = og.Config(model_path)
        self.config.clear_providers()
        if execution_provider != "cpu":
            self.config.append_provider(execution_provider)
        self.model = og.Model(self.config)
        self.tokenizer = og.Tokenizer(self.model)
        self.engine = og.Engine(self.model)
        self.debug = debug
        self.tokens_decoded = 0

    def run(self):
        # Each step() returns a request that has new tokens; route it back to its pool via opaque data.
        while request := self.engine.step():
            request_pool = request.get_opaque_data()
            request_pool.drain(request)
            self.tokens_decoded += 1


def run(args: argparse.Namespace):
    engine = Engine(args.model_path, args.execution_provider, args.debug)
    request_pool = RequestPool(
        engine.model,
        engine.tokenizer,
        engine.engine,
        args.num_requests,
        debug=args.debug,
    )

    producer_thread = threading.Thread(target=request_pool.fill)
    producer_thread.start()

    start = time.time()
    engine.run()
    end = time.time()

    request_pool.bar.close()
    print(f"⌛Tokens per second: {engine.tokens_decoded / (end - start):.2f}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="End-to-end AI Question/Answer example for gen-ai",
    )
    parser.add_argument(
        "-m",
        "--model_path",
        type=str,
        required=True,
        help="Onnx model folder path (must contain genai_config.json and model.onnx)",
    )
    parser.add_argument(
        "-e",
        "--execution_provider",
        type=str,
        required=True,
        choices=["cpu", "cuda", "dml", "webgpu"],
        help="Execution provider to run ONNX model with",
    )
    parser.add_argument(
        "-d",
        "--debug",
        action="store_true",
        help="Enable debug logging",
    )
    parser.add_argument(
        "-n",
        "--num_requests",
        type=int,
        default=1,
        help="Number of requests to process in the pool",
    )
    args = parser.parse_args()

    run(args)
```

examples/python/engine/model-qa.py

Lines changed: 92 additions & 0 deletions
```python
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import argparse
import json

import onnxruntime_genai as og


def run(args: argparse.Namespace):
    config = og.Config(args.model_path)
    config.clear_providers()
    if args.execution_provider != "cpu":
        config.append_provider(args.execution_provider)

    model = og.Model(config)
    tokenizer = og.Tokenizer(model)
    engine = og.Engine(model)

    while prompt := input("🫵 : "):
        if prompt == "/exit":
            break

        messages = [
            {"role": "system", "content": ""},
            {"role": "user", "content": f"{prompt}"},
        ]
        messages = json.dumps(messages)

        params = og.GeneratorParams(model)
        params.set_search_options(
            do_sample=False,
            max_length=1024,
        )

        request = og.Request(params)
        request.add_tokens(
            tokenizer.encode(
                tokenizer.apply_chat_template(
                    messages=messages, add_generation_prompt=True
                )
            ),
        )
        streaming_tokenizer = tokenizer.create_stream()

        engine.add_request(request)

        print(f"🤖 :", end="", flush=True)

        while ready_request := engine.step():
            while ready_request.has_unseen_tokens():
                print(
                    streaming_tokenizer.decode(ready_request.get_unseen_token()),
                    end="",
                    flush=True,
                )

        print()
        engine.remove_request(request)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="End-to-end AI Question/Answer example for gen-ai",
    )
    parser.add_argument(
        "-m",
        "--model_path",
        type=str,
        required=True,
        help="Onnx model folder path (must contain genai_config.json and model.onnx)",
    )
    parser.add_argument(
        "-e",
        "--execution_provider",
        type=str,
        required=True,
        choices=["cpu", "cuda", "dml", "webgpu"],
        help="Execution provider to run ONNX model with",
    )
    parser.add_argument("--debug", action="store_true", help="Enable debug logging")

    args = parser.parse_args()
    if args.debug:
        og.set_log_options(
            enabled=True,
            model_input_values=True,
            model_output_values=True,
            model_output_shapes=True,
        )

    run(args)
```

src/config.h

Lines changed: 7 additions & 1 deletion
```diff
@@ -23,6 +23,10 @@ struct Config {
   static constexpr std::string_view PresentValueName = "present.%d.value";
   static constexpr std::string_view RnnStatesName = "rnn_states";
   static constexpr std::string_view RnnStatesPrevName = "rnn_states_prev";
+  static constexpr std::string_view CumulativeSequenceLengthsName = "cumulative_sequence_lengths";
+  static constexpr std::string_view SequenceLengthsName = "sequence_lengths";
+  static constexpr std::string_view PastSequenceLengthsName = "past_sequence_lengths";
+  static constexpr std::string_view BlockTableName = "block_table";
 
   // Speech encoder names
   static constexpr std::string_view AudioAttentionMaskName = "audio_attention_mask";
@@ -201,7 +205,6 @@ struct Config {
     std::string past_value_names{Defaults::PastValueName};
     std::string past_names;  // When key/value pairs are combined
     std::string cross_past_key_names, cross_past_value_names;
-
     std::string past_key_values_length{Defaults::PastKeyValuesLengthName};
     std::string past_sequence_length{Defaults::PastSequenceLengthName};
     std::string current_sequence_length{Defaults::CurrentSequenceLengthName};
@@ -210,6 +213,9 @@ struct Config {
     std::string encoder_hidden_states{Defaults::EncoderHiddenStatesName};
     std::string rnn_prev_states{Defaults::RnnStatesPrevName};
     std::string encoder_attention_mask{Defaults::EncoderAttentionMaskName};
+    std::string cumulative_sequence_lengths{Defaults::CumulativeSequenceLengthsName};
+    std::string past_sequence_lengths{Defaults::PastSequenceLengthsName};
+    std::string block_table{Defaults::BlockTableName};
   } inputs;
 
   struct Outputs {
```
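The new `Defaults` names correspond to decoder inputs a continuous-batching model can consume. The diff does not show their shapes or how the engine populates them; the sketch below is an illustration only, under the common variable-length-attention and paged-KV-cache conventions (an assumption, not something stated in this commit), of what such tensors typically contain for three in-flight requests.

```python
# Illustration of plausible contents for the new decoder inputs, assuming
# standard varlen-attention / paged-KV-cache semantics. The actual shapes and
# dtypes expected by the engine are defined by the model, not by this sketch.
import numpy as np

# Tokens fed to the model this step, per request (1 for decode, more for prefill).
new_tokens_per_request = [1, 1, 4]

# cumulative_sequence_lengths: exclusive prefix sums over the packed token batch,
# so request i owns the slice [cu[i], cu[i + 1]) of the flattened input ids.
cumulative_sequence_lengths = np.cumsum([0] + new_tokens_per_request).astype(np.int32)
# -> [0, 1, 2, 6]

# past_sequence_lengths: how many tokens each request already holds in the KV cache.
past_sequence_lengths = np.array([17, 42, 0], dtype=np.int32)

# block_table: for a paged KV cache, the cache-block indices owned by each
# request, padded to a common width (-1 marks an unused slot in this sketch).
block_table = np.array(
    [
        [3, 7, -1],
        [0, 1, 2],
        [5, -1, -1],
    ],
    dtype=np.int32,
)
```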
