docs/source/en/quantization/bitsandbytes.md (30 additions, 28 deletions)
@@ -40,16 +40,20 @@ Quantizing a model in 8-bit halves the memory-usage:
bitsandbytes is supported in both Transformers and Diffusers, so you can quantize both the
[`FluxTransformer2DModel`] and [`~transformers.T5EncoderModel`].
> [!NOTE]
> Set your `torch_dtype` according to your GPU. Ada and later GPUs support `torch.bfloat16`, and we suggest using it when applicable.

> [!NOTE]
> We do not quantize `CLIPTextModel` and `AutoencoderKL` because of their small size, and because `AutoencoderKL` has very few `torch.nn.Linear` layers.
```py
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
```
By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter.
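To make this concrete, here is a minimal sketch of how the two configs are typically passed to `from_pretrained` together with `torch_dtype`. The `black-forest-labs/FLUX.1-dev` checkpoint and subfolder names are assumptions based on the FLUX models this guide targets, not part of this diff.

```py
import torch
from diffusers import FluxTransformer2DModel
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import T5EncoderModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig

# 8-bit config for the text encoder (Transformers side).
text_encoder_8bit = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed checkpoint
    subfolder="text_encoder_2",
    quantization_config=TransformersBitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.bfloat16,  # non-quantized modules (e.g. LayerNorm) keep this dtype
)

# 8-bit config for the transformer (Diffusers side).
transformer_8bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed checkpoint
    subfolder="transformer",
    quantization_config=DiffusersBitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.bfloat16,
)
```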
> When memory permits, you can move the pipeline (`pipe` here) directly to the GPU with the `.to("cuda")` API.
> You can also use `enable_model_cpu_offload()` to optimize GPU VRAM usage.
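As a hedged sketch of the two options above, assuming the quantized `transformer_8bit` and `text_encoder_8bit` from the earlier sketch and the same assumed FLUX.1-dev checkpoint:

```py
import torch
from diffusers import FluxPipeline

# Assemble the pipeline around the quantized components.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed checkpoint
    transformer=transformer_8bit,
    text_encoder_2=text_encoder_8bit,
    torch_dtype=torch.bfloat16,
)

# Option 1: if memory permits, place the whole pipeline on the GPU.
pipe.to("cuda")

# Option 2: otherwise, move submodules to the GPU only when they are needed.
# pipe.enable_model_cpu_offload()
```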
Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights. You can also save the serialized 8-bit models locally with [`~ModelMixin.save_pretrained`].
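For example, assuming the `transformer_8bit` from the sketch above (the repository id and local path are hypothetical placeholders):

```py
# Push the quantized transformer to the Hub (requires being logged in via `huggingface-cli login`).
transformer_8bit.push_to_hub("your-username/flux.1-dev-transformer-8bit")  # hypothetical repo id

# Or serialize it locally.
transformer_8bit.save_pretrained("./flux.1-dev-transformer-8bit")  # hypothetical path
```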
</hfoption>
@@ -126,16 +127,20 @@ Quantizing a model in 4-bit reduces your memory-usage by 4x:
bitsandbytes is supported in both Transformers and Diffusers, so you can quantize both the
[`FluxTransformer2DModel`] and [`~transformers.T5EncoderModel`].
> [!NOTE]
> Set your `torch_dtype` according to your GPU. Ada and later GPUs support `torch.bfloat16`, and we suggest using it when applicable.

> [!NOTE]
> We do not quantize `CLIPTextModel` and `AutoencoderKL` because of their small size, and because `AutoencoderKL` has very few `torch.nn.Linear` layers.
```py
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
```
By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter.
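A minimal sketch of a 4-bit setup, assuming the same FLUX.1-dev checkpoint used elsewhere in this guide; the NF4 quantization type and `bfloat16` compute dtype shown here are common choices rather than requirements:

```py
import torch
from diffusers import FluxTransformer2DModel
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig

# 4-bit NF4 quantization with bfloat16 compute for the transformer.
quant_config = DiffusersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

transformer_4bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed checkpoint
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
```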
> When memory permits, you can move the pipeline (`pipe` here) directly to the GPU with the `.to("cuda")` API.
> You can also use `enable_model_cpu_offload()` to optimize GPU VRAM usage.
Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights. You can also save the serialized 4-bit models locally with [`~ModelMixin.save_pretrained`].