
Commit 57d388e

revert diffusion_trt.py
Signed-off-by: Zhiyu Cheng <[email protected]>
1 parent: 4561de9

File tree

1 file changed: +6 -85 lines changed


examples/diffusers/quantization/diffusion_trt.py

Lines changed: 6 additions & 85 deletions
@@ -23,7 +23,6 @@
     update_dynamic_axes,
 )
 from quantize import ModelType, PipelineManager
-from tqdm import tqdm

 import modelopt.torch.opt as mto
 from modelopt.torch._deploy._runtime import RuntimeRegistry
@@ -59,59 +58,6 @@ def generate_image(pipe, prompt, image_name):
     print(f"Image generated saved as {image_name}")


-def benchmark_model(
-    pipe, prompt, num_warmup=10, num_runs=50, num_inference_steps=20, model_dtype="Half"
-):
-    """Benchmark the backbone model inference time."""
-    backbone = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
-
-    backbone_times = []
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
-
-    def forward_pre_hook(_module, _input):
-        start_event.record()
-
-    def forward_hook(_module, _input, _output):
-        end_event.record()
-        torch.cuda.synchronize()
-        backbone_times.append(start_event.elapsed_time(end_event))
-
-    pre_handle = backbone.register_forward_pre_hook(forward_pre_hook)
-    post_handle = backbone.register_forward_hook(forward_hook)
-
-    try:
-        print(f"Starting warmup: {num_warmup} runs")
-        for _ in tqdm(range(num_warmup), desc="Warmup"):
-            with torch.amp.autocast("cuda", dtype=dtype_map[model_dtype]):
-                _ = pipe(
-                    prompt,
-                    output_type="pil",
-                    num_inference_steps=num_inference_steps,
-                    generator=torch.Generator("cuda").manual_seed(42),
-                )
-
-        backbone_times.clear()
-
-        print(f"Starting benchmark: {num_runs} runs")
-        for _ in tqdm(range(num_runs), desc="Benchmark"):
-            with torch.amp.autocast("cuda", dtype=dtype_map[model_dtype]):
-                _ = pipe(
-                    prompt,
-                    output_type="pil",
-                    num_inference_steps=num_inference_steps,
-                    generator=torch.Generator("cuda").manual_seed(42),
-                )
-    finally:
-        pre_handle.remove()
-        post_handle.remove()
-
-    total_backbone_time = sum(backbone_times)
-    avg_latency = total_backbone_time / (num_runs * num_inference_steps)
-    print(f"Inference latency of the torch backbone: {avg_latency:.2f} ms")
-    return avg_latency
-
-
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
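
Note: the benchmarking pattern deleted in this hunk (CUDA events recorded from forward hooks) is reproducible standalone. A minimal sketch, not the reverted script itself; dummy_backbone and the tensor shapes are hypothetical stand-ins for pipe.transformer / pipe.unet:

# Minimal sketch of the removed benchmark pattern; assumes a CUDA device.
import torch

dummy_backbone = torch.nn.Linear(1024, 1024).cuda()  # hypothetical stand-in module
times_ms = []
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

def pre_hook(_module, _input):
    start.record()  # stamp the CUDA stream when the forward pass begins

def post_hook(_module, _input, _output):
    end.record()
    torch.cuda.synchronize()  # elapsed_time() requires both events to have completed
    times_ms.append(start.elapsed_time(end))  # milliseconds

h1 = dummy_backbone.register_forward_pre_hook(pre_hook)
h2 = dummy_backbone.register_forward_hook(post_hook)
try:
    x = torch.randn(8, 1024, device="cuda")
    for _ in range(10):
        _ = dummy_backbone(x)
finally:
    h1.remove()  # always detach hooks, as the removed code did
    h2.remove()
print(f"avg forward latency: {sum(times_ms) / len(times_ms):.2f} ms")
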
@@ -146,24 +92,15 @@ def main():
         "--onnx-load-path", type=str, default="", help="Path to load the ONNX model"
     )
     parser.add_argument(
-        "--trt-engine-load-path", type=str, default=None, help="Path to load the TensorRT engine"
+        "--trt-engine-load-path", type=str, default=None, help="Path to load the TRT engine"
     )
     parser.add_argument(
         "--dq-only", action="store_true", help="Converts the ONNX model to a dq_only model"
     )
     parser.add_argument(
-        "--torch",
-        action="store_true",
-        help="Use the torch pipeline for image generation or benchmarking",
+        "--torch", action="store_true", help="Generate an image using the torch pipeline"
     )
     parser.add_argument("--save-image-as", type=str, default=None, help="Name of the image to save")
-    parser.add_argument(
-        "--benchmark", action="store_true", help="Benchmark the model backbone inference time"
-    )
-    parser.add_argument(
-        "--torch-compile", action="store_true", help="Use torch.compile() on the backbone model"
-    )
-    parser.add_argument("--skip-image", action="store_true", help="Skip image generation")
     args = parser.parse_args()

     image_name = args.save_image_as if args.save_image_as else f"{args.model}.png"
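
The net effect of this hunk on the CLI is easiest to see in isolation: --benchmark, --torch-compile, and --skip-image are gone, and --torch is image-generation only. A runnable sketch of the post-revert flag surface, using only the flags shown above (the parse_args argument list is a hypothetical invocation):

# Post-revert CLI surface for the flags touched in this hunk.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--trt-engine-load-path", type=str, default=None, help="Path to load the TRT engine")
parser.add_argument("--dq-only", action="store_true", help="Converts the ONNX model to a dq_only model")
parser.add_argument("--torch", action="store_true", help="Generate an image using the torch pipeline")
parser.add_argument("--save-image-as", type=str, default=None, help="Name of the image to save")

args = parser.parse_args(["--torch", "--save-image-as", "out.png"])  # hypothetical invocation
print(args.torch, args.save_image_as)  # True out.png
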
@@ -188,25 +125,13 @@ def main():
     if args.restore_from:
         mto.restore(backbone, args.restore_from)

-    if args.torch_compile:
-        assert args.model_dtype in ["BFloat16", "Float", "Half"], (
-            "torch.compile() only supports BFloat16 and Float"
-        )
-        print("Compiling backbone with torch.compile()...")
-        backbone = torch.compile(backbone, mode="max-autotune")
-
     if args.torch:
         if hasattr(pipe, "transformer"):
             pipe.transformer = backbone
         elif hasattr(pipe, "unet"):
             pipe.unet = backbone
         pipe.to("cuda")
-
-        if args.benchmark:
-            benchmark_model(pipe, args.prompt, model_dtype=args.model_dtype)
-
-        if not args.skip_image:
-            generate_image(pipe, args.prompt, image_name)
+        generate_image(pipe, args.prompt, image_name)
         return

     backbone.to("cuda")
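
The torch.compile path deleted above followed the stock PyTorch 2.x API; nothing in it was specific to this script. A minimal sketch of the same pattern, with a toy module as a hypothetical stand-in for the diffusion backbone:

# Sketch of the removed torch.compile usage (stock PyTorch 2.x API).
import torch

backbone = torch.nn.Linear(64, 64).cuda()  # hypothetical stand-in for the diffusion backbone
backbone = torch.compile(backbone, mode="max-autotune")  # same mode the removed code used
with torch.no_grad():
    _ = backbone(torch.randn(1, 64, device="cuda"))  # first call triggers compilation/autotuning
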
@@ -286,14 +211,10 @@ def main():
         raise ValueError("Pipeline does not have a transformer or unet backbone")
     pipe.to("cuda")

-    if not args.skip_image:
-        generate_image(pipe, args.prompt, image_name)
-        print(f"Image generated using {args.model} model saved as {image_name}")
+    generate_image(pipe, args.prompt, image_name)
+    print(f"Image generated using {args.model} model saved as {image_name}")

-    if args.benchmark:
-        print(
-            f"Inference latency of the TensorRT optimized backbone: {device_model.get_latency()} ms"
-        )
+    print(f"Inference latency of the backbone of the pipeline is {device_model.get_latency()} ms")


 if __name__ == "__main__":
