Skip to content

Commit 3b88b54

Browse files
committed
python SD sample
1 parent 2ca537d commit 3b88b54

File tree

4 files changed

+414
-0
lines changed

4 files changed

+414
-0
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Stable Diffusion 3 Medium ONNX Export Guide
2+
3+
This guide provides the steps to convert the `stabilityai/stable-diffusion-3-medium` model to the ONNX format for use with the CUDA execution provider. It also includes a step to address an issue with mixed-precision nodes that may occur during the conversion process.
4+
5+
## 1. Prerequisites and Installation
6+
7+
Install the required Python packages using the following `requirements.txt` content:
8+
9+
```
10+
numpy
11+
torch --index-url https://download.pytorch.org/whl/cu121
12+
optimum[onnxruntime]
13+
onnxruntime-gpu
14+
diffusers
15+
sentencepiece
16+
transformers
17+
```
18+
19+
You can save this to a `requirements.txt` file and install it with:
20+
```bash
21+
pip install -r requirements.txt
22+
```
23+
This will install `onnxruntime-gpu` with the CUDA execution provider, which is necessary for model conversion.
24+
25+
## 2. Model Conversion
26+
27+
Run the following command to export the model to ONNX format. This command uses `optimum-cli` to convert the model to half-precision (`fp16`) on a CUDA device.
28+
29+
```bash
30+
optimum-cli export onnx --model stabilityai/stable-diffusion-3-medium --dtype fp16 --device cuda fp16_optimum
31+
```
32+
33+
This will download the model and convert it into multiple ONNX files in the `fp16_optimum` directory.
34+
35+
## 3. Correcting FP64 Nodes
36+
37+
The PyTorch model may contain some `fp64` nodes, which are exported as-is during the conversion. If you encounter issues with these nodes, you can use the provided `Replace_fp64.py` script to replace them with `fp32` nodes. This script will process all `.onnx` files in the input directory and save the corrected files to the output directory.
38+
39+
```bash
40+
python replace_fp64.py fp16_optimum corrected_model
41+
```
42+
This will create a `corrected_model` directory with the FP64 nodes converted to FP32.
43+
44+
## 4. Using a Custom ONNX Runtime
45+
46+
If you have a locally built ONNX Runtime wheel with specific optimizations (e.g., for NvTensorRTRTXExecutionProvider), ensure that you install it in your environment before running inference. Additionally, be sure to uninstall the default `onnxruntime` package installed via `requirements.txt` to avoid any conflicts.
47+
48+
## 5. Running Inference
49+
50+
To run inference with the converted ONNX model, use the provided `RunSd.py` script. This script loads the ONNX model and generates an image based on a prompt.
51+
52+
Here is an example command to run the script:
53+
```bash
54+
python run_sd.py --model_path corrected_model --prompt "A beautiful landscape painting of a waterfall in a lush forest" --output_dir generated_images
55+
```
56+
57+
### Command-line Arguments
58+
59+
The `RunSd.py` script accepts several arguments to customize the image generation process:
60+
61+
* `--model_path`: Path to the directory containing the ONNX models (e.g., `corrected_model`). (Required)
62+
* `--prompt`: The text prompt to generate the image from.
63+
* `--negative_prompt`: The prompt not to guide the image generation.
64+
* `--height`: The height of the generated image (default: 512).
65+
* `--width`: The width of the generated image (default: 512).
66+
* `--steps`: The number of inference steps (default: 50).
67+
* `--guidance_scale`: Guidance scale for the prompt (default: 7.5).
68+
* `--seed`: A seed for reproducibility.
69+
* `--output_dir`: The directory to save the generated images (default: `generated_images`).
70+
* `--execution_provider`: The ONNX Runtime execution provider to use (default: `NvTensorRTRTXExecutionProvider`).
71+
72+
For a full list of arguments, you can run:
73+
```bash
74+
python run_sd.py --help
75+
```
76+
77+
The generated image will be saved in the specified output directory.
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import onnx
2+
from onnx import numpy_helper
3+
import numpy as np
4+
import argparse
5+
import os
6+
import shutil
7+
import logging
8+
9+
# Setup logging
10+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
11+
12+
def convert_fp64_to_fp32(model_path: str, output_path: str):
13+
"""
14+
Loads an ONNX model, converts all float64 tensors and casts to float32,
15+
and saves the modified model.
16+
"""
17+
logging.info(f"Loading model from: {model_path}")
18+
model = onnx.load(model_path)
19+
20+
# 1. Convert all initializers from float64 to float32
21+
converted_initializers = 0
22+
new_initializers = []
23+
for initializer in model.graph.initializer:
24+
if initializer.data_type == onnx.TensorProto.DOUBLE:
25+
initializer_np = numpy_helper.to_array(initializer)
26+
initializer_fp32 = initializer_np.astype(np.float32)
27+
new_initializer = numpy_helper.from_array(initializer_fp32, name=initializer.name)
28+
new_initializers.append(new_initializer)
29+
converted_initializers += 1
30+
else:
31+
new_initializers.append(initializer)
32+
33+
model.graph.ClearField("initializer")
34+
model.graph.initializer.extend(new_initializers)
35+
36+
if converted_initializers > 0:
37+
logging.info(f"Converted {converted_initializers} initializers from FP64 to FP32.")
38+
39+
# 2. Convert nodes
40+
converted_casts = 0
41+
converted_constants = 0
42+
for node in model.graph.node:
43+
if node.op_type == 'Constant':
44+
for attr in node.attribute:
45+
if attr.name == 'value' and attr.t.data_type == onnx.TensorProto.DOUBLE:
46+
attr.t.data_type = onnx.TensorProto.FLOAT
47+
fp64_array = np.frombuffer(attr.t.raw_data, dtype=np.float64)
48+
fp32_array = fp64_array.astype(np.float32)
49+
attr.t.raw_data = fp32_array.tobytes()
50+
converted_constants += 1
51+
elif node.op_type == 'Cast':
52+
for attr in node.attribute:
53+
if attr.name == 'to' and attr.i == onnx.TensorProto.DOUBLE:
54+
attr.i = onnx.TensorProto.FLOAT
55+
converted_casts += 1
56+
57+
if converted_casts > 0:
58+
logging.info(f"Modified {converted_casts} Cast operators from FP64 to FP32.")
59+
if converted_constants > 0:
60+
logging.info(f"Modified {converted_constants} Constant operators from FP64 to FP32.")
61+
62+
# 3. Convert all graph inputs, outputs, and value_info from float64 to float32
63+
converted_tensors = 0
64+
for tensor in list(model.graph.value_info) + list(model.graph.input) + list(model.graph.output):
65+
if tensor.type.tensor_type.elem_type == onnx.TensorProto.DOUBLE:
66+
tensor.type.tensor_type.elem_type = onnx.TensorProto.FLOAT
67+
converted_tensors += 1
68+
69+
if converted_tensors > 0:
70+
logging.info(f"Converted {converted_tensors} tensor definitions from FP64 to FP32.")
71+
72+
# 4. Save the modified model
73+
logging.info(f"Saving modified model to: {output_path}")
74+
onnx.save(model, output_path, save_as_external_data=True)
75+
logging.info("Conversion complete.")
76+
77+
78+
if __name__ == '__main__':
79+
parser = argparse.ArgumentParser(
80+
description="Convert ONNX models in a directory from float64 to float32 precision."
81+
)
82+
parser.add_argument("input_dir", type=str, help="Directory containing the input ONNX models.")
83+
parser.add_argument("output_dir", type=str, help="Directory where the converted models will be saved.")
84+
args = parser.parse_args()
85+
86+
input_dir = args.input_dir
87+
output_dir = args.output_dir
88+
89+
if not os.path.exists(output_dir):
90+
os.makedirs(output_dir)
91+
logging.info(f"Created output directory: {output_dir}")
92+
93+
for root, _, files in os.walk(input_dir):
94+
# Replicate directory structure in the output directory
95+
relative_path = os.path.relpath(root, input_dir)
96+
output_subdir = os.path.join(output_dir, relative_path)
97+
if not os.path.exists(output_subdir):
98+
os.makedirs(output_subdir)
99+
100+
for filename in files:
101+
input_path = os.path.join(root, filename)
102+
output_path = os.path.join(output_subdir, filename)
103+
104+
if filename.endswith(".onnx"):
105+
logging.info("-" * 50)
106+
logging.info(f"Processing ONNX file: {input_path}")
107+
try:
108+
convert_fp64_to_fp32(input_path, output_path)
109+
except Exception as e:
110+
logging.error(f"Failed to convert {input_path}: {e}")
111+
logging.info("-" * 50)
112+
elif filename.endswith(".onnx_data"):
113+
# Skip copying .onnx_data files as new ones will be created on save
114+
continue
115+
else:
116+
logging.info(f"Copying file: {input_path} to {output_path}")
117+
shutil.copy2(input_path, output_path)
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
numpy
2+
torch
3+
--index-url https://download.pytorch.org/whl/cu129
4+
optimum[onnxruntime]
5+
onnxruntime-gpu
6+
diffusers
7+
sentencepiece

0 commit comments

Comments
 (0)