Commit 742595b

Speedup Llama2 cpu throughput in bench by 1.69x with iobinding (#19853)
### Description

Always set `use_io_binding=True` when using optimum.onnxruntime unless there is a special case.

### Motivation and Context

By default, `ORTModel` under optimum.onnxruntime chooses the appropriate `use_io_binding` value based on the provider and use case:

> use_io_binding (`Optional[bool]`, defaults to `None`):
> Whether to use IOBinding during inference to avoid memory copy between the host and device, or between numpy/torch tensors and ONNX Runtime ORTValue. Defaults to `True` if the execution provider is CUDAExecutionProvider. For [`~onnxruntime.ORTModelForCausalLM`], defaults to `True` on CPUExecutionProvider, in all other cases defaults to `False`.

For the Llama token benchmark, using IOBinding yields almost a 2x speedup, even on CPU. This is because this particular model produces a large number of outputs (>60). Without IOBinding, a copy is performed for each output from ORTValue to a numpy array, which adds significant overhead to the overall run time.

```
Evaluating Llama2 `model(inputs)` step with past_key_values

Before, w/o iobinding on cpu
Batch Size: 1
Sequence Length: 512
Latency: 0.4518657898902893 s
Throughput: 2.2130464894073856 tps

After, w/ iobinding on cpu
Batch Size: 1
Sequence Length: 512
Latency: 0.2662619352340698 s
Throughput: 3.7557001871893703 tps
```
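The benchmark figures above are self-consistent with the 1.69x claim in the commit title. A quick sketch, assuming throughput is computed as batch size divided by per-step latency (the helper name `throughput` is ours, not from the benchmark script):

```python
# Benchmark figures quoted above (Llama2, batch size 1, sequence length 512, CPU).
LATENCY_WITHOUT_IOBINDING = 0.4518657898902893  # seconds per step
LATENCY_WITH_IOBINDING = 0.2662619352340698     # seconds per step
BATCH_SIZE = 1


def throughput(batch_size: int, latency_s: float) -> float:
    """Tokens per second for one decoding step: each step emits batch_size tokens."""
    return batch_size / latency_s


before_tps = throughput(BATCH_SIZE, LATENCY_WITHOUT_IOBINDING)
after_tps = throughput(BATCH_SIZE, LATENCY_WITH_IOBINDING)
speedup = after_tps / before_tps  # ~1.697; the title truncates this to 1.69x

print(f"before: {before_tps:.2f} tps, after: {after_tps:.2f} tps, speedup: {speedup:.3f}x")
```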
1 parent d4fa4f0 commit 742595b

File tree

3 files changed: +6 −6 lines


onnxruntime/python/tools/transformers/models/llama/benchmark.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -243,7 +243,7 @@ def get_model(args: argparse.Namespace):
         decoder_file_name=decoder_file_name,
         decoder_with_past_file_name=decoder_with_past_file_name,
         use_auth_token=args.auth,
-        use_io_binding=(args.device != "cpu"),
+        use_io_binding=True,  # Large perf gain even for cpu due to avoiding output copy.
         use_merged=(True if decoder_file_name == "model.onnx" else None),
         provider=provider,
         provider_options=provider_options,
```
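The win here comes from avoiding one host-side copy per model output. A toy sketch of that effect, using plain Python buffers as a stand-in for the ORTValue-to-numpy copies (this is illustrative only, not onnxruntime code; the layer count of 32 assumes a Llama2-7B-style decoder, where logits plus a key and value tensor per layer gives the >60 outputs mentioned in the description):

```python
from copy import deepcopy

NUM_LAYERS = 32  # assumed Llama2-7B decoder depth
# One logits tensor plus a key and value cache tensor per layer -> 65 outputs per step.
outputs = [bytearray(1024) for _ in range(1 + 2 * NUM_LAYERS)]


def fetch_without_iobinding(outs):
    # Stand-in for the default path: every output is copied into a fresh buffer.
    return [deepcopy(o) for o in outs]


def fetch_with_iobinding(outs):
    # Stand-in for the bound path: the same buffers are handed back, no copy.
    return list(outs)


copied = fetch_without_iobinding(outputs)
bound = fetch_with_iobinding(outputs)
print(len(outputs))             # 65 tensors to move on every decoding step
print(copied[0] is outputs[0])  # False: a new buffer was allocated and filled
print(bound[0] is outputs[0])   # True: same buffer, zero-copy
```

Because this copy cost is paid once per output per decoding step, models with large KV caches feel it far more than single-output models, which is why the gain shows up even on CPU.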

onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py

Lines changed: 4 additions & 4 deletions

```diff
@@ -315,29 +315,29 @@ def get_optimum_ort_pipeline(
             directory,
             provider=provider,
             session_options=None,
-            use_io_binding=False,
+            use_io_binding=False,  # Not supported by Optimum version 1.17.1 at the time of verification.
         )
     else:
         pipeline = ORTStableDiffusionPipeline.from_pretrained(
             directory,
             provider=provider,
-            use_io_binding=False,
+            use_io_binding=False,  # Not supported by Optimum version 1.17.1 at the time of verification.
         )
 elif "xl" in model_name:
     pipeline = ORTStableDiffusionXLPipeline.from_pretrained(
         model_name,
         export=True,
         provider=provider,
         session_options=None,
-        use_io_binding=False,
+        use_io_binding=False,  # Not supported by Optimum version 1.17.1 at the time of verification.
     )
     pipeline.save_pretrained(directory)
 else:
     pipeline = ORTStableDiffusionPipeline.from_pretrained(
         model_name,
         export=True,
         provider=provider,
-        use_io_binding=False,
+        use_io_binding=False,  # Not supported by Optimum version 1.17.1 at the time of verification.
     )
     pipeline.save_pretrained(directory)
```
onnxruntime/python/tools/transformers/models/whisper/benchmark.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -145,10 +145,10 @@ def get_model(args: argparse.Namespace):
     start_time = time.time()
     model = ORTModelForSpeechSeq2Seq.from_pretrained(
         args.hf_ort_dir_path,
-        use_io_binding=(args.device != "cpu"),
         provider=provider,
         provider_options=provider_options,
         session_options=sess_options,
+        use_io_binding=True,  # Avoid memory copy overhead
     )
     end_time = time.time()
```