
Commit 99a7e66

Add flag for torch.compile()
Signed-off-by: ajrasane <[email protected]>
1 parent b5223b1 commit 99a7e66

File tree

1 file changed: +28 -14 lines changed

examples/diffusers/quantization/diffusion_trt.py

Lines changed: 28 additions & 14 deletions
@@ -59,7 +59,9 @@ def generate_image(pipe, prompt, image_name):
     print(f"Image generated saved as {image_name}")
 
 
-def benchmark_model(pipe, prompt, num_warmup=10, num_runs=50, num_inference_steps=20):
+def benchmark_model(
+    pipe, prompt, num_warmup=10, num_runs=50, num_inference_steps=20, model_dtype="Half"
+):
     """Benchmark the backbone model inference time."""
     backbone = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet

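The widened signature keeps "Half" as the default, so existing callers are unchanged. A minimal direct-call sketch of the new keyword (the prompt string is illustrative):

    benchmark_model(pipe, "a photo of an astronaut riding a horse", model_dtype="BFloat16")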
@@ -81,23 +83,25 @@ def forward_hook(_module, _input, _output):
     try:
         print(f"Starting warmup: {num_warmup} runs")
         for _ in tqdm(range(num_warmup), desc="Warmup"):
-            _ = pipe(
-                prompt,
-                output_type="pil",
-                num_inference_steps=num_inference_steps,
-                generator=torch.Generator("cuda").manual_seed(42),
-            )
+            with torch.amp.autocast("cuda", dtype=dtype_map[model_dtype]):
+                _ = pipe(
+                    prompt,
+                    output_type="pil",
+                    num_inference_steps=num_inference_steps,
+                    generator=torch.Generator("cuda").manual_seed(42),
+                )
 
         backbone_times.clear()
 
         print(f"Starting benchmark: {num_runs} runs")
         for _ in tqdm(range(num_runs), desc="Benchmark"):
-            _ = pipe(
-                prompt,
-                output_type="pil",
-                num_inference_steps=num_inference_steps,
-                generator=torch.Generator("cuda").manual_seed(42),
-            )
+            with torch.amp.autocast("cuda", dtype=dtype_map[model_dtype]):
+                _ = pipe(
+                    prompt,
+                    output_type="pil",
+                    num_inference_steps=num_inference_steps,
+                    generator=torch.Generator("cuda").manual_seed(42),
+                )
     finally:
         pre_handle.remove()
         post_handle.remove()
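The autocast wrapper above indexes into a dtype_map defined elsewhere in the script. The mapping below is an assumption inferred from the "Half" default and the "BFloat16"/"Float" names checked later in the diff, not part of this commit:

    import torch

    # Presumed mapping from --model-dtype strings to torch dtypes
    # (assumed shape; the real definition lives outside this hunk):
    dtype_map = {
        "Half": torch.float16,
        "BFloat16": torch.bfloat16,
        "Float": torch.float32,
    }

With that mapping, each warmup and benchmark iteration runs the pipeline under torch.amp.autocast("cuda", dtype=dtype_map[model_dtype]), so the backbone computes in the requested precision while the measurement hooks stay unchanged.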
@@ -156,6 +160,9 @@ def main():
     parser.add_argument(
         "--benchmark", action="store_true", help="Benchmark the model backbone inference time"
     )
+    parser.add_argument(
+        "--torch-compile", action="store_true", help="Use torch.compile() on the backbone model"
+    )
     parser.add_argument("--skip-image", action="store_true", help="Skip image generation")
     args = parser.parse_args()

@@ -181,6 +188,13 @@ def main():
     if args.restore_from:
         mto.restore(backbone, args.restore_from)
 
+    if args.torch_compile:
+        assert args.model_dtype in ["BFloat16", "Float"], (
+            "torch.compile() only supports BFloat16 and Float"
+        )
+        print("Compiling backbone with torch.compile()...")
+        backbone = torch.compile(backbone)
+
     if args.torch:
         if hasattr(pipe, "transformer"):
             pipe.transformer = backbone
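For readers unfamiliar with torch.compile(), a self-contained sketch of what the new branch does; the tiny module is a hypothetical stand-in for the pipeline's transformer/UNet:

    import torch

    class TinyBackbone(torch.nn.Module):
        # Hypothetical stand-in for pipe.transformer / pipe.unet.
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(16, 16)

        def forward(self, x):
            return torch.nn.functional.silu(self.proj(x))

    backbone = TinyBackbone().to("cuda", dtype=torch.bfloat16)
    backbone = torch.compile(backbone)  # returns an optimized wrapper; weights are shared

    # The first forward triggers graph capture and compilation; later calls
    # reuse the compiled graph, which is why the benchmark's warmup runs matter.
    x = torch.randn(4, 16, device="cuda", dtype=torch.bfloat16)
    _ = backbone(x)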
@@ -189,7 +203,7 @@ def main():
         pipe.to("cuda")
 
     if args.benchmark:
-        benchmark_model(pipe, args.prompt)
+        benchmark_model(pipe, args.prompt, model_dtype=args.model_dtype)
 
     if not args.skip_image:
         generate_image(pipe, args.prompt, image_name)
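Taken together, a plausible invocation exercising the new flag (option names besides --torch-compile and --benchmark are inferred from the args.* attributes in this diff; any other required options are elided):

    python examples/diffusers/quantization/diffusion_trt.py --model-dtype BFloat16 --torch-compile --benchmark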
