Skip to content

Commit 7adf848

Browse files
Update examples for 0.19.0 release
1 parent 1574baa commit 7adf848

File tree

4 files changed

+32
-7
lines changed

4 files changed

+32
-7
lines changed

.gitignore

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ __pycache__/
33
*.py[cod]
44
*$py.class
55

6-
# C extensions
6+
# C, CPP extensions
77
*.so
8+
.rendered.*.cpp
9+
.rendered.*.o
810
# Exclude the plugin file
911
!libfp8convkernel.so
1012

@@ -177,6 +179,6 @@ cython_debug/
177179
**.pb
178180
**.onnx
179181
**.ckpt
180-
181-
# Ignore temporary files created by tox
182-
pyproject.toml.bak
182+
**.safetensors
183+
**.bin
184+
**.pkl

diffusers/quantization/config.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,21 @@ def create_dynamic_shapes(dynamic_shapes):
203203
}
204204

205205

206+
def remove_nesting(compilation_args):
207+
compilation_args["dynamic_shapes"]["minShapes"]["text_embeds"] = compilation_args[
208+
"dynamic_shapes"
209+
]["minShapes"].pop("added_cond_kwargs.text_embeds")
210+
compilation_args["dynamic_shapes"]["minShapes"]["time_ids"] = compilation_args[
211+
"dynamic_shapes"
212+
]["minShapes"].pop("added_cond_kwargs.time_ids")
213+
compilation_args["dynamic_shapes"]["optShapes"]["text_embeds"] = compilation_args[
214+
"dynamic_shapes"
215+
]["optShapes"].pop("added_cond_kwargs.text_embeds")
216+
compilation_args["dynamic_shapes"]["optShapes"]["time_ids"] = compilation_args[
217+
"dynamic_shapes"
218+
]["optShapes"].pop("added_cond_kwargs.time_ids")
219+
220+
206221
DYNAMIC_SHAPES = {
207222
"sdxl-1.0": create_dynamic_shapes(SDXL_DYNAMIC_SHAPES),
208223
"sdxl-turbo": create_dynamic_shapes(SDXL_DYNAMIC_SHAPES),
@@ -236,5 +251,8 @@ def get_io_shapes(model, onnx_load_path):
236251
else:
237252
output_name = "out.0"
238253
io_shapes = IO_SHAPES[model]
239-
io_shapes[output_name] = io_shapes.pop("out.0")
254+
# For models that are loaded from the output name will not be "out.0"
255+
# so we need to update the dictionary key to match the output name
256+
if "out.0" in io_shapes.keys():
257+
io_shapes[output_name] = io_shapes.pop("out.0")
240258
return io_shapes

diffusers/quantization/diffusion_trt.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import argparse
2323

2424
import torch
25-
from config import DYNAMIC_SHAPES, get_io_shapes, update_dynamic_axes
25+
from config import DYNAMIC_SHAPES, get_io_shapes, remove_nesting, update_dynamic_axes
2626
from diffusers import (
2727
DiffusionPipeline,
2828
FluxPipeline,
@@ -111,6 +111,11 @@ def main():
111111

112112
compilation_args = DYNAMIC_SHAPES[args.model]
113113

114+
# We only need to remove the nesting for SDXL models as they contain the nested input added_cond_kwargs
115+
# which are renamed by the DeviceModel
116+
if args.onnx_load_path != "" and args.model in ["sdxl-1.0", "sdxl-turbo"]:
117+
remove_nesting(compilation_args)
118+
114119
# Define deployment configuration
115120
deployment = {
116121
"runtime": "TRT",

llm_ptq/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ Falcon 7B | Yes | Yes | No | No
9494
MPT 7B, 30B | Yes | Yes | Yes | Yes
9595
Baichuan 1, 2 | Yes | Yes | Yes | Yes
9696
ChatGLM2, 3 6B | No | No | Yes | No
97-
Bloom | bloom | Yes | Yes | Yes
97+
Bloom | Yes | Yes | Yes | Yes
9898
Phi-1,2,3 | Yes | Yes | Yes | Yes<sup>4</sup>
9999
Nemotron 8B | Yes | No | Yes | No
100100
Gemma 2B, 7B | Yes | No | Yes | Yes

0 commit comments

Comments
 (0)