
Commit 8a999e2

nvidia-modelopt 0.15.1 examples release
1 parent 2a3f7cf commit 8a999e2

File tree: 4 files changed, +135 -15 lines
Lines changed: 103 additions & 0 deletions

@@ -0,0 +1,103 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person obtaining a
+# copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+import argparse
+import time
+from pathlib import Path
+
+import torch
+from cache_diffusion import cachify
+from cache_diffusion.utils import SDXL_DEFAULT_CONFIG
+from diffusers import DiffusionPipeline
+from pipeline.deploy import compile, teardown
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--batch-size", type=int, default=2)
+    parser.add_argument("--num-inference-steps", type=int, default=30)
+    parser.add_argument("--num-iter", type=int, default=8)
+    args = parser.parse_args()
+    for key, value in vars(args).items():
+        if value is not None:
+            print("Parsed args -- {}: {}".format(key, value))
+    return args
+
+
+def main(args):
+    pipe = DiffusionPipeline.from_pretrained(
+        "stabilityai/stable-diffusion-xl-base-1.0",
+        torch_dtype=torch.float16,
+        variant="fp16",
+        use_safetensors=True,
+    )
+    pipe = pipe.to("cuda")
+
+    prompt = "A random person with a head that is made of flowers, photo by James C. Leyendecker, \
+        Afrofuturism, studio portrait, dynamic pose, national geographic photo, retrofuturism, \
+        biomorphicy"
+
+    compile(
+        pipe.unet,
+        onnx_path=Path("./onnx"),
+        engine_path=Path("./engine"),
+        batch_size=args.batch_size,
+    )
+
+    cachify.prepare(pipe, args.num_inference_steps, SDXL_DEFAULT_CONFIG)
+
+    generator = torch.Generator(device="cuda").manual_seed(2946901)
+    total_time = 0
+    cachify.disable(pipe)
+    for _ in range(args.num_iter):
+        start_time = time.time()
+        _ = pipe(
+            prompt=[prompt] * args.batch_size,
+            num_inference_steps=args.num_inference_steps,
+            generator=generator,
+        )
+        end_time = time.time()
+        total_time += end_time - start_time
+    total_time = total_time / args.num_iter
+    latency = total_time / args.batch_size
+    print(f"TRT Disabled Cache: {latency}")
+
+    generator = torch.Generator(device="cuda").manual_seed(2946901)
+    total_time = 0
+    cachify.enable(pipe)
+    for _ in range(args.num_iter):
+        start_time = time.time()
+        _ = pipe(
+            prompt=[prompt] * args.batch_size,
+            num_inference_steps=args.num_inference_steps,
+            generator=generator,
+        )
+        end_time = time.time()
+        total_time += end_time - start_time
+    total_time = total_time / args.num_iter
+    latency = total_time / args.batch_size
+    print(f"TRT Enabled Cache: {latency}")
+    teardown(pipe.unet)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
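The new script benchmarks SDXL end to end under TensorRT, running the same seeded pipeline once with block caching disabled and once enabled, so both passes denoise identical latents and the printed per-image latencies are directly comparable. A minimal sketch of the timing pattern the script uses twice, factored into a helper (not part of the commit; it reuses the script's torch, time, pipe, prompt, and args, and the torch.cuda.synchronize() calls are an added precaution against measuring queued-but-unfinished GPU work):

def benchmark(pipe, prompt, args):
    # Re-seed so cached and uncached runs follow the same noise trajectory.
    generator = torch.Generator(device="cuda").manual_seed(2946901)
    total_time = 0.0
    for _ in range(args.num_iter):
        torch.cuda.synchronize()
        start_time = time.time()
        _ = pipe(
            prompt=[prompt] * args.batch_size,
            num_inference_steps=args.num_inference_steps,
            generator=generator,
        )
        torch.cuda.synchronize()
        total_time += time.time() - start_time
    # Mean wall-clock seconds per iteration, then per image in the batch.
    return total_time / args.num_iter / args.batch_size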

diffusers/cache_diffusion/cache_diffusion/cachify.py

Lines changed: 10 additions & 0 deletions

@@ -87,13 +87,23 @@ def cachify(model, num_inference_steps, config_list, modules):
 
 def disable(pipe):
     model = get_model(pipe)
+    if hasattr(model, "use_trt_infer") and model.use_trt_infer:
+        for _, module in model.engines.items():
+            if isinstance(module, CachedModule):
+                module.disable_cache()
+        return
     for _, module in model.named_modules():
         if isinstance(module, CachedModule):
             module.disable_cache()
 
 
 def enable(pipe):
     model = get_model(pipe)
+    if hasattr(model, "use_trt_infer") and model.use_trt_infer:
+        for _, module in model.engines.items():
+            if isinstance(module, CachedModule):
+                module.enable_cache()
+        return
     for _, module in model.named_modules():
         if isinstance(module, CachedModule):
             module.enable_cache()
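Both toggles gained the same early-exit branch: once compile() (in deploy.py below) has replaced the UNet blocks with TensorRT engines and set use_trt_infer, the cached blocks live in the model.engines dict rather than in the nn.Module graph, so named_modules() would silently miss them. A hypothetical refactor sketch that makes the shared pattern explicit (_iter_cached_modules is an invented name, not in the commit; it assumes cachify.py's existing get_model and CachedModule):

def _iter_cached_modules(model):
    # Yield candidate blocks whether the UNet runs through TensorRT
    # engines or as plain PyTorch modules.
    if getattr(model, "use_trt_infer", False):
        yield from model.engines.values()
    else:
        for _, module in model.named_modules():
            yield module


def disable(pipe):
    for module in _iter_cached_modules(get_model(pipe)):
        if isinstance(module, CachedModule):
            module.disable_cache()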

diffusers/cache_diffusion/pipeline/deploy.py

Lines changed: 20 additions & 14 deletions

@@ -54,14 +54,15 @@ def replace_new_forward(unet):
         upsample_block.forward = types.MethodType(cacheupblock2d_forward, upsample_block)
 
 
-def get_input_info(dummy_dict, info=None):
+def get_input_info(dummy_dict, info: str = None, batch_size: int = 1):
     return_val = [] if info == "profile_shapes" or info == "input_names" else {}
 
     def collect_leaf_keys(d):
         for key, value in d.items():
             if isinstance(value, dict):
                 collect_leaf_keys(value)
             else:
+                value = (value[0] * batch_size,) + value[1:]
                 if info == "profile_shapes":
                     return_val.append((key, value))  # type: ignore
                 elif info == "profile_shapes_dict":

@@ -75,7 +76,7 @@ def collect_leaf_keys(d):
     return return_val
 
 
-def complie2trt(onnx_path: Path, engine_path: Path):
+def complie2trt(onnx_path: Path, engine_path: Path, batch_size: int = 1):
     subdirs = [f for f in onnx_path.iterdir() if f.is_dir()]
     for subdir in subdirs:
         if subdir.name not in SDXL_ONNX_CONFIG.keys():

@@ -86,15 +87,17 @@ def complie2trt(onnx_path: Path, engine_path: Path):
             print(f"Building {str(model_path)}")
             build_profile = Profile()
             profile_shapes = get_input_info(
-                SDXL_ONNX_CONFIG[subdir.name]["dummy_input"], "profile_shapes"
+                SDXL_ONNX_CONFIG[subdir.name]["dummy_input"], "profile_shapes", batch_size
             )
             for input_name, input_shape in profile_shapes:
-                build_profile.add(input_name, input_shape, input_shape, input_shape)
+                min_input_shape = (2,) + input_shape[1:]
+                build_profile.add(input_name, min_input_shape, input_shape, input_shape)
             block_network = network_from_onnx_path(
-                str(model_path), flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]
+                str(model_path), flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM], strongly_typed=True
             )
             build_config = CreateConfig(
-                fp16=True,
+                builder_optimization_level=4,
+                tf32=True,
                 profiles=[build_profile],
             )
             engine = engine_from_network(

@@ -113,7 +116,7 @@ def get_total_device_memory(unet):
     return max_device_memory
 
 
-def load_engines(unet, engine_path: Path):
+def load_engines(unet, engine_path: Path, batch_size: int = 1):
    unet.engines = {}
    for f in engine_path.iterdir():
        if f.is_file():

@@ -127,9 +130,10 @@ def load_engines(unet, engine_path: Path):
     for block_name in unet.engines.keys():
         unet.engines[block_name].allocate_buffers(
             shape_dict=get_input_info(
-                SDXL_ONNX_CONFIG[block_name]["dummy_input"], "profile_shapes_dict"
+                SDXL_ONNX_CONFIG[block_name]["dummy_input"], "profile_shapes_dict", batch_size
             ),
             device=unet.device,
+            batch_size=batch_size,
         )
     # TODO: Free and clean up the origin pytorch cuda memory
 

@@ -216,10 +220,12 @@ def export_onnx(unet, onnx_path: Path):
             print(f"{str(_onnx_file)} alread exists!")
 
 
-def warm_up(unet):
+def warm_up(unet, batch_size: int = 1):
     print("Warming-up TensorRT engines...")
     for name, engine in unet.engines.items():
-        dummy_input = get_input_info(SDXL_ONNX_CONFIG[name]["dummy_input"], "dummy_input")
+        dummy_input = get_input_info(
+            SDXL_ONNX_CONFIG[name]["dummy_input"], "dummy_input", batch_size
+        )
         _ = engine(dummy_input, unet.cuda_stream)
 
 

@@ -231,13 +237,13 @@ def teardown(unet):
     del unet.cuda_stream
 
 
-def compile(unet, onnx_path: Path, engine_path: Path):
+def compile(unet, onnx_path: Path, engine_path: Path, batch_size: int = 1):
     onnx_path.mkdir(parents=True, exist_ok=True)
     engine_path.mkdir(parents=True, exist_ok=True)
 
     replace_new_forward(unet)
     export_onnx(unet, onnx_path)
-    complie2trt(onnx_path, engine_path)
-    load_engines(unet, engine_path)
-    warm_up(unet)
+    complie2trt(onnx_path, engine_path, batch_size)
+    load_engines(unet, engine_path, batch_size)
+    warm_up(unet, batch_size)
     unet.use_trt_infer = True
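The thread running through these changes is plumbing batch_size from compile() down to every place a tensor shape is computed. The dummy-input shapes in SDXL_ONNX_CONFIG evidently carry a leading dimension of 2 (one image plus its classifier-free-guidance copy, consistent with the batch_size * 2 factor in utils.py below), so get_input_info scales that dimension by batch_size for the opt/max profile shapes while complie2trt pins the profile minimum at 2; separately, the builder config moves from fp16=True to a strongly typed network built with tf32 and builder optimization level 4. A worked sketch of the resulting optimization profile, using an illustrative latent shape rather than a real SDXL_ONNX_CONFIG entry:

# Illustrative only: how the profile bounds come out for --batch-size 2.
batch_size = 2
dummy_shape = (2, 4, 128, 128)  # leading dim 2 = one image + its CFG copy

max_shape = (dummy_shape[0] * batch_size,) + dummy_shape[1:]  # (4, 4, 128, 128)
min_shape = (2,) + max_shape[1:]                              # (2, 4, 128, 128)

# complie2trt then registers min/opt/max as:
#   build_profile.add(name, min_shape, max_shape, max_shape)
# i.e. the engine accepts anything from a single CFG-doubled image up to
# the full requested batch, and is tuned for the full batch.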

diffusers/cache_diffusion/pipeline/utils.py

Lines changed: 2 additions & 1 deletion

@@ -71,13 +71,14 @@ def activate(self, reuse_device_memory=None):
         else:
             self.context = self.engine.create_execution_context()  # type: ignore
 
-    def allocate_buffers(self, shape_dict=None, device="cuda"):
+    def allocate_buffers(self, shape_dict=None, device="cuda", batch_size=1):
         for binding in range(self.engine.num_io_tensors):  # type: ignore
             name = self.engine.get_tensor_name(binding)  # type: ignore
             if shape_dict and name in shape_dict:
                 shape = shape_dict[name]
             else:
                 shape = self.engine.get_tensor_shape(name)  # type: ignore
+                shape = (batch_size * 2,) + shape[1:]
             dtype = trt.nptype(self.engine.get_tensor_dtype(name))  # type: ignore
             if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:  # type: ignore
                 self.context.set_input_shape(name, shape)  # type: ignore
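This fallback fires for tensors absent from shape_dict, where engine.get_tensor_shape() reports the engine's stored shape with -1 in any dynamic dimension; the new line pins the leading dimension to batch_size * 2 (again the CFG-doubled batch) so a concrete buffer can be sized. A hedged sketch of the effect (the shape and dtype are illustrative, not a real engine binding):

import torch

batch_size = 2
shape = (-1, 4, 128, 128)              # dynamic batch as reported by the engine
shape = (batch_size * 2,) + shape[1:]  # (4, 4, 128, 128): batch doubled for CFG
buffer = torch.empty(shape, dtype=torch.float16, device="cuda")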
