Skip to content

Commit fc88592

Browse files
authored
Merge branch 'main' into Add-Notebooks
2 parents c0eba1d + f63d322 commit fc88592

File tree

3 files changed

+22
-16
lines changed

3 files changed

+22
-16
lines changed

docs/source/en/using-diffusers/write_own_pipeline.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ Let's try it out!
106106

107107
## Deconstruct the Stable Diffusion pipeline
108108

109-
Stable Diffusion is a text-to-image *latent diffusion* model. It is called a latent diffusion model because it works with a lower-dimensional representation of the image instead of the actual pixel space, which makes it more memory efficient. The encoder compresses the image into a smaller representation, and a decoder to convert the compressed representation back into an image. For text-to-image models, you'll need a tokenizer and an encoder to generate text embeddings. From the previous example, you already know you need a UNet model and a scheduler.
109+
Stable Diffusion is a text-to-image *latent diffusion* model. It is called a latent diffusion model because it works with a lower-dimensional representation of the image instead of the actual pixel space, which makes it more memory efficient. The encoder compresses the image into a smaller representation, and a decoder converts the compressed representation back into an image. For text-to-image models, you'll need a tokenizer and an encoder to generate text embeddings. From the previous example, you already know you need a UNet model and a scheduler.
110110

111111
As you can see, this is already more complex than the DDPM pipeline which only contains a UNet model. The Stable Diffusion model has three separate pretrained models.
112112

examples/text_to_image/train_text_to_image.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,8 +365,8 @@ def parse_args():
365365
"--dream_training",
366366
action="store_true",
367367
help=(
368-
"Use the DREAM training method, which makes training more efficient and accurate at the ",
369-
"expense of doing an extra forward pass. See: https://arxiv.org/abs/2312.00210",
368+
"Use the DREAM training method, which makes training more efficient and accurate at the "
369+
"expense of doing an extra forward pass. See: https://arxiv.org/abs/2312.00210"
370370
),
371371
)
372372
parser.add_argument(

src/diffusers/quantizers/bitsandbytes/utils.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,8 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
153153
return model
154154

155155

156-
# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
157-
def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
156+
# Adapted from PEFT: https://github.com/huggingface/peft/blob/6d458b300fc2ed82e19f796b53af4c97d03ea604/src/peft/utils/integrations.py#L81
157+
def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None, dtype: "torch.dtype" = None):
158158
"""
159159
Helper function to dequantize 4bit or 8bit bnb weights.
160160
@@ -177,13 +177,16 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
177177
if state.SCB is None:
178178
state.SCB = weight.SCB
179179

180-
im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
181-
im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
182-
im, Sim = bnb.functional.transform(im, "col32")
183-
if state.CxB is None:
184-
state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
185-
out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
186-
return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
180+
if hasattr(bnb.functional, "int8_vectorwise_dequant"):
181+
# Use bitsandbytes API if available (requires v0.45.0+)
182+
dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, state.SCB)
183+
else:
184+
# Multiply by (scale/127) to dequantize.
185+
dequantized = weight.data * state.SCB.view(-1, 1) * 7.874015718698502e-3
186+
187+
if dtype:
188+
dequantized = dequantized.to(dtype)
189+
return dequantized
187190

188191

189192
def _create_accelerate_new_hook(old_hook):
@@ -205,6 +208,7 @@ def _create_accelerate_new_hook(old_hook):
205208

206209
def _dequantize_and_replace(
207210
model,
211+
dtype,
208212
modules_to_not_convert=None,
209213
current_key_name=None,
210214
quantization_config=None,
@@ -244,7 +248,7 @@ def _dequantize_and_replace(
244248
else:
245249
state = None
246250

247-
new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
251+
new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state, dtype))
248252

249253
if bias is not None:
250254
new_module.bias = bias
@@ -263,9 +267,10 @@ def _dequantize_and_replace(
263267
if len(list(module.children())) > 0:
264268
_, has_been_replaced = _dequantize_and_replace(
265269
module,
266-
modules_to_not_convert,
267-
current_key_name,
268-
quantization_config,
270+
dtype=dtype,
271+
modules_to_not_convert=modules_to_not_convert,
272+
current_key_name=current_key_name,
273+
quantization_config=quantization_config,
269274
has_been_replaced=has_been_replaced,
270275
)
271276
# Remove the last key for recursion
@@ -280,6 +285,7 @@ def dequantize_and_replace(
280285
):
281286
model, has_been_replaced = _dequantize_and_replace(
282287
model,
288+
dtype=model.dtype,
283289
modules_to_not_convert=modules_to_not_convert,
284290
quantization_config=quantization_config,
285291
)

0 commit comments

Comments
 (0)