
Commit f329b19

ajrasane authored and kevalmorabia97 committed
[NVBUG: 5619158] Optimize memory usage for diffusion_trt.py (#547)
## What does this PR do?

**Type of change:** Minor code change

**Overview:**
- Delete backbone after Device Model creation
- Add assertion for torch compile
- Update dummy input generation function

## Testing

```
python diffusion_trt.py --model flux-dev --benchmark --skip-image
python diffusion_trt.py --model flux-dev --benchmark --skip-image --restore-from ./flux_dev_fp8_autodeploy_fake.pt
python diffusion_trt.py --model flux-dev --benchmark --skip-image --restore-from ./flux_dev_fp4_autodeploy_fake.pt
python diffusion_trt.py --model flux-dev --benchmark --skip-image --torch
python diffusion_trt.py --model flux-dev --benchmark --skip-image --restore-from ./flux_dev_fp8_autodeploy_fake.pt --torch
python diffusion_trt.py --model flux-dev --benchmark --skip-image --restore-from ./flux_dev_fp4_autodeploy_fake.pt --torch
python diffusion_trt.py --model flux-dev --benchmark --skip-image --torch --torch-compile
```

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: No
- **Did you add or update any necessary documentation?**: No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: No

---------

Signed-off-by: ajrasane <[email protected]>
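The new assertion encodes a flag dependency: `--torch-compile` only makes sense together with `--torch`, since the TensorRT path replaces the backbone with a compiled engine rather than a PyTorch module. A minimal sketch of that check, using a simplified stand-in for the real CLI parser in diffusion_trt.py:

```python
# Simplified stand-in for the diffusion_trt.py CLI: only the two flags
# involved in the new assertion are modeled here.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--torch", action="store_true")
parser.add_argument("--torch-compile", action="store_true")
args = parser.parse_args(["--torch", "--torch-compile"])

# Fail fast on an invalid flag combination before any model is loaded.
if args.torch_compile:
    assert args.torch, "Torch mode must be enabled when torch_compile is used"
```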

3 files changed: +62 additions, −10 deletions

examples/diffusers/quantization/diffusion_trt.py

12 additions & 6 deletions

```diff
@@ -172,6 +172,8 @@ def main():
         override_model_path=args.override_model_path,
     )
 
+    if args.torch_compile:
+        assert args.torch, "Torch mode must be enabled when torch_compile is used"
     # Save the backbone of the pipeline and move it to the GPU
     add_embedding = None
     backbone = None
@@ -186,11 +188,10 @@ def main():
     if args.restore_from:
         mto.restore(backbone, args.restore_from)
 
-    if args.torch_compile:
-        print("Compiling backbone with torch.compile()...")
-        backbone = torch.compile(backbone, mode="max-autotune")
-
     if args.torch:
+        if args.torch_compile:
+            print("Compiling backbone with torch.compile()...")
+            backbone = torch.compile(backbone, mode="max-autotune")
         if hasattr(pipe, "transformer"):
             pipe.transformer = backbone
         elif hasattr(pipe, "unet"):
@@ -250,9 +251,15 @@ def main():
         dq_only=args.dq_only,
     )
 
+    # Delete the original backbone and empty the cache
+    del backbone
+    torch.cuda.empty_cache()
+
     if not args.trt_engine_load_path:
         # Compile the TRT engine from the exported ONNX model
         compiled_model = client.ir_to_compiled(onnx_bytes, compilation_args)
+        # Clear onnx_bytes to free memory
+        del onnx_bytes
         # Save TRT engine for future use
         with open(f"{args.model}.plan", "wb") as f:
             # Remove the SHA-256 hash from the compiled model, used to maintain state in the trt_client
@@ -276,8 +283,7 @@ def main():
     if hasattr(pipe, "unet") and add_embedding:
         setattr(device_model, "add_embedding", add_embedding)
 
-    # Move the backbone back to the CPU and set the backbone to the compiled device model
-    backbone.to("cpu")
+    # Set the backbone to the device model
     if hasattr(pipe, "unet"):
         pipe.unet = device_model
     elif hasattr(pipe, "transformer"):
```
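The memory saving in this file comes from dropping the last Python reference to the PyTorch backbone once the TensorRT device model owns inference, then asking the CUDA caching allocator to return its cached blocks to the driver. A minimal sketch of the pattern; `build_device_model` and `swap_backbone_for_device_model` are hypothetical names standing in for the ONNX-export and TRT-compile path above:

```python
import torch

def swap_backbone_for_device_model(pipe, backbone, build_device_model):
    # Hypothetical stand-in for the export + TRT-compile path.
    device_model = build_device_model(backbone)
    # Drop the local reference; the module's memory is only actually
    # freed once no other Python reference to it remains.
    del backbone
    if torch.cuda.is_available():
        # Release cached blocks back to the driver; without this the
        # caching allocator keeps the freed memory reserved for PyTorch.
        torch.cuda.empty_cache()
    # Point the pipeline at the compiled device model.
    if hasattr(pipe, "transformer"):
        pipe.transformer = device_model
    elif hasattr(pipe, "unet"):
        pipe.unet = device_model
    return pipe
```

This also explains the removal of `backbone.to("cpu")` in the last hunk: there is no PyTorch backbone left to move once it has been deleted.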

examples/diffusers/quantization/onnx_utils/export.py

12 additions & 4 deletions

```diff
@@ -128,7 +128,9 @@ def generate_fp8_scales(backbone):
 
 
 def _gen_dummy_inp_and_dyn_shapes_sdxl(backbone, min_bs=1, opt_bs=1):
-    assert isinstance(backbone, UNet2DConditionModel)
+    assert isinstance(backbone, UNet2DConditionModel) or isinstance(
+        backbone._orig_mod, UNet2DConditionModel
+    )
     cfg = backbone.config
     assert cfg.addition_embed_type == "text_time"
 
@@ -173,7 +175,9 @@ def _gen_dummy_inp_and_dyn_shapes_sdxl(backbone, min_bs=1, opt_bs=1):
 
 
 def _gen_dummy_inp_and_dyn_shapes_sd3(backbone, min_bs=1, opt_bs=1):
-    assert isinstance(backbone, SD3Transformer2DModel)
+    assert isinstance(backbone, SD3Transformer2DModel) or isinstance(
+        backbone._orig_mod, SD3Transformer2DModel
+    )
     cfg = backbone.config
 
     dynamic_shapes = {
@@ -205,7 +209,9 @@ def _gen_dummy_inp_and_dyn_shapes_sd3(backbone, min_bs=1, opt_bs=1):
 
 
 def _gen_dummy_inp_and_dyn_shapes_flux(backbone, min_bs=1, opt_bs=1):
-    assert isinstance(backbone, FluxTransformer2DModel)
+    assert isinstance(backbone, FluxTransformer2DModel) or isinstance(
+        backbone._orig_mod, FluxTransformer2DModel
+    )
     cfg = backbone.config
     text_maxlen = 512
     img_dim = 4096
@@ -251,7 +257,9 @@ def _gen_dummy_inp_and_dyn_shapes_flux(backbone, min_bs=1, opt_bs=1):
 
 
 def _gen_dummy_inp_and_dyn_shapes_ltx(backbone, min_bs=2, opt_bs=2):
-    assert isinstance(backbone, LTXVideoTransformer3DModel)
+    assert isinstance(backbone, LTXVideoTransformer3DModel) or isinstance(
+        backbone._orig_mod, LTXVideoTransformer3DModel
+    )
     cfg = backbone.config
     dtype = backbone.dtype
     video_dim = 2240
```
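The `._orig_mod` fallback in these assertions is needed because `torch.compile` does not return the module itself: it wraps it in a dynamo `OptimizedModule`, so a plain `isinstance` check against the original class fails. The wrapper keeps the wrapped module at `._orig_mod`. A minimal sketch demonstrating the behavior (assumes PyTorch 2.x):

```python
import torch
import torch.nn as nn

class TinyBackbone(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

# torch.compile returns an OptimizedModule wrapper, not a TinyBackbone.
backbone = torch.compile(TinyBackbone())
print(isinstance(backbone, TinyBackbone))            # False: it's the wrapper
print(isinstance(backbone._orig_mod, TinyBackbone))  # True: the wrapped module
```

Note that `or` short-circuits, so `backbone._orig_mod` is only touched when the first check fails, i.e. when the backbone is expected to be a compiled wrapper.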

tests/examples/diffusers/test_diffusers.py

38 additions & 0 deletions

```diff
@@ -150,3 +150,41 @@ def test_diffusers_quantization(
     model.quantize(tmp_path)
     model.restore(tmp_path)
     model.inference(tmp_path)
+
+
+@pytest.mark.parametrize(
+    ("model_name", "model_path", "torch_compile"),
+    [
+        ("flux-schnell", FLUX_SCHNELL_PATH, False),
+        ("flux-schnell", FLUX_SCHNELL_PATH, True),
+        ("sd3-medium", SD3_PATH, False),
+        ("sd3-medium", SD3_PATH, True),
+        ("sdxl-1.0", SDXL_1_0_PATH, False),
+        ("sdxl-1.0", SDXL_1_0_PATH, True),
+    ],
+    ids=[
+        "flux_schnell_torch",
+        "flux_schnell_torch_compile",
+        "sd3_medium_torch",
+        "sd3_medium_torch_compile",
+        "sdxl_1.0_torch",
+        "sdxl_1.0_torch_compile",
+    ],
+)
+def test_diffusion_trt_torch(
+    model_name: str,
+    model_path: str,
+    torch_compile: bool,
+) -> None:
+    cmd_args = [
+        "python",
+        "diffusion_trt.py",
+        "--model",
+        model_name,
+        "--override-model-path",
+        model_path,
+        "--torch",
+    ]
+    if torch_compile:
+        cmd_args.append("--torch-compile")
+    run_example_command(cmd_args, "diffusers/quantization")
```
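`run_example_command` is a shared helper from the repo's test utilities; its implementation isn't shown in this diff. A plausible minimal version, with `EXAMPLES_ROOT` as an assumed path (the real helper may differ), would run the command from the example's directory and fail the test on a non-zero exit:

```python
import subprocess
from pathlib import Path

EXAMPLES_ROOT = Path("examples")  # assumed location of the example scripts

def run_example_command(cmd_args: list[str], example_dir: str) -> None:
    # Run the example from its own directory so relative paths resolve;
    # check=True raises CalledProcessError (failing the test) on non-zero exit.
    subprocess.run(cmd_args, cwd=EXAMPLES_ROOT / example_dir, check=True)
```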
