Commit e5dcdec

update
1 parent 1dcd24e commit e5dcdec

1 file changed: +26 -2 lines changed

docs/source/en/quantization/torchao.md

Lines changed: 26 additions & 2 deletions
@@ -113,13 +113,37 @@ import torch
 from diffusers import FluxPipeline, FluxTransformer2DModel
 
 transformer = FluxTransformer2DModel.from_pretrained("/path/to/flux_int8wo", torch_dtype=torch.bfloat16, use_safetensors=False)
-pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
+pipe = FluxPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", transformer=transformer, torch_dtype=torch.bfloat16)
 pipe.to("cuda")
 
 prompt = "A cat holding a sign that says hello world"
 image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
 image.save("output.png")
-```
+```
+
+Some quantization methods, such as `uint4wo`, can be saved as expected but cannot be loaded directly; attempting to load them may raise an `UnpicklingError`. To work around this, load the state dict manually into the model. Note, however, that this requires passing `weights_only=False` to `torch.load`, so do it only if the weights come from a trusted source.
+
+```python
+import torch
+from accelerate import init_empty_weights
+from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
+
+# Serialize the model
+transformer = FluxTransformer2DModel.from_pretrained(
+    "black-forest-labs/Flux.1-Dev",
+    subfolder="transformer",
+    quantization_config=TorchAoConfig("uint4wo"),
+    torch_dtype=torch.bfloat16,
+)
+transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, max_shard_size="50GB")
+# ...
+
+# Load the model
+state_dict = torch.load("/path/to/flux_uint4wo/diffusion_pytorch_model.bin", weights_only=False, map_location="cpu")
+with init_empty_weights():
+    transformer = FluxTransformer2DModel.from_config("/path/to/flux_uint4wo/config.json")
+transformer.load_state_dict(state_dict, strict=True, assign=True)
+```
 
 ## Resources

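Note on the `weights_only=False` caveat in the added paragraph: in that mode `torch.load` deserializes with Python's `pickle`, and unpickling can invoke callables named inside the file. The following is a minimal standard-library sketch of that mechanism, not code from the commit. The `Payload` class is hypothetical, and the harmless built-in `sorted` stands in for whatever an untrusted file might choose to call.

```python
import pickle


class Payload:
    """Hypothetical object whose pickled form names a callable to run on load."""

    def __reduce__(self):
        # pickle stores (callable, args); the unpickler calls callable(*args).
        # `sorted` is a harmless stand-in for attacker-chosen code.
        return (sorted, ([3, 1, 2],))


# Serialize the object; only the callable and its arguments are written.
blob = pickle.dumps(Payload())

# Deserializing runs the stored callable and returns its result.
result = pickle.loads(blob)
print(type(result).__name__, result)
```

Here `result` is the list produced by `sorted`, not a `Payload` instance, which shows that loading executed the callable recorded in the byte stream. This is the reason the manual state-dict workaround is best reserved for checkpoints you serialized yourself or obtained from a trusted source.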