
Commit 6fd71bd

finalize.
1 parent 5d5e80a commit 6fd71bd


2 files changed: +32 -4 lines changed


examples/dreambooth/README_hidream.md

Lines changed: 27 additions & 0 deletions
@@ -117,3 +117,30 @@ We provide several options for optimizing memory usage:

* `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.

Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/) of the `HiDreamImagePipeline` to know more about the model.

## Using quantization

You can quantize the base model with [`bitsandbytes`](https://huggingface.co/docs/bitsandbytes/index) to reduce memory usage. To do so, pass a JSON file path to `--bnb_quantization_config_path`. This file should hold the configuration used to initialize `BitsAndBytesConfig`. Below is an example JSON file:

```json
{
  "load_in_4bit": true,
  "bnb_4bit_quant_type": "nf4"
}
```

Below, we provide some memory numbers with and without NF4 quantization during training:

```
(with quantization)
Memory (before device placement): 9.085089683532715 GB.
Memory (after device placement): 34.59585428237915 GB.
Memory (after backward): 36.90267467498779 GB.

(without quantization)
Memory (before device placement): 0.0 GB.
Memory (after device placement): 57.6400408744812 GB.
Memory (after backward): 59.932212829589844 GB.
```

The reason we see some memory usage before device placement in the quantized case is that, by default, bnb-quantized models are placed on the GPU first.
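For reference, here is a minimal sketch of how such a JSON file can be turned into a `BitsAndBytesConfig` object. It assumes the top-level `BitsAndBytesConfig` export from `diffusers` (also available in `transformers`) and a hypothetical `bnb_config.json` path; the training script itself may wire this up differently.

```python
# Minimal sketch, not the training script's exact code: read the JSON file passed
# via --bnb_quantization_config_path and build a BitsAndBytesConfig from it.
import json

from diffusers import BitsAndBytesConfig  # also exported by transformers

config_path = "bnb_config.json"  # hypothetical value of --bnb_quantization_config_path
with open(config_path) as f:
    bnb_kwargs = json.load(f)  # e.g. {"load_in_4bit": true, "bnb_4bit_quant_type": "nf4"}

quantization_config = BitsAndBytesConfig(**bnb_kwargs)
# The resulting config would then be supplied to the transformer's
# from_pretrained(...) call via quantization_config=quantization_config.
```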

examples/dreambooth/train_dreambooth_lora_hidream.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -1713,10 +1713,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         transformer = unwrap_model(transformer)
-        if args.upcast_before_saving:
-            transformer.to(torch.float32)
-        else:
-            transformer = transformer.to(weight_dtype)
+        if args.bnb_quantization_config_path is None:
+            if args.upcast_before_saving:
+                transformer.to(torch.float32)
+            else:
+                transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)

         HiDreamImagePipeline.save_lora_weights(
```
