
Commit 4d518ab

Merge branch 'main' into deprecate-slicing-tiling-pipe
2 parents 96615d8 + 28106fc

File tree

2 files changed: +86, -13 lines

examples/dreambooth/train_dreambooth_lora_flux_kontext.py

Lines changed: 30 additions & 13 deletions
@@ -29,8 +29,9 @@
 import numpy as np
 import torch
 import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from huggingface_hub.utils import insecure_hashlib
@@ -1222,6 +1223,9 @@ def main(args):
         kwargs_handlers=[kwargs],
     )
 
+    if accelerator.distributed_type == DistributedType.DEEPSPEED:
+        AcceleratorState().deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
+
     # Disable AMP for MPS.
     if torch.backends.mps.is_available():
         accelerator.native_amp = False
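
Note on the hunk above: when Accelerate drives DeepSpeed, the engine validates its JSON config during prepare(), and `train_micro_batch_size_per_gpu` is commonly left at an `"auto"` placeholder; copying the script's `--train_batch_size` into the plugin config avoids that validation failure. A minimal sketch of the same pattern in isolation (the batch size of 4 is an illustrative value, not the script's default):

    from accelerate import Accelerator, DistributedType
    from accelerate.state import AcceleratorState

    accelerator = Accelerator()

    # Mirror the CLI batch size into DeepSpeed's config so the engine does
    # not fail on an "auto" placeholder; 4 is an illustrative value.
    if accelerator.distributed_type == DistributedType.DEEPSPEED:
        AcceleratorState().deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = 4
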
@@ -1438,17 +1442,20 @@ def save_model_hook(models, weights, output_dir):
             text_encoder_one_lora_layers_to_save = None
             modules_to_save = {}
             for model in models:
-                if isinstance(model, type(unwrap_model(transformer))):
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    model = unwrap_model(model)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                     modules_to_save["transformer"] = model
-                elif isinstance(model, type(unwrap_model(text_encoder_one))):
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+                    model = unwrap_model(model)
                     text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
                     modules_to_save["text_encoder"] = model
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
                 # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                if weights:
+                    weights.pop()
 
             FluxKontextPipeline.save_lora_weights(
                 output_dir,
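
These `unwrap_model` changes matter because DeepSpeed (and DDP) pass wrapper objects into the hook, so an `isinstance` check against the bare model class fails and the hook would raise `unexpected save model`. The `weights` list can also arrive empty when Accelerate manages the state itself, as with DeepSpeed, hence the `if weights:` guard before `pop()`. A runnable toy illustration of the identity check; `Wrapper` is a stand-in for the real engine class:

    import torch

    class Wrapper(torch.nn.Module):
        """Stand-in for DeepSpeedEngine/DDP, which hide the real module."""
        def __init__(self, module):
            super().__init__()
            self.module = module

    def unwrap(m):
        return m.module if isinstance(m, Wrapper) else m

    net = torch.nn.Linear(4, 4)
    wrapped = Wrapper(net)
    print(isinstance(wrapped, type(net)))          # False: the wrapper hides the class
    print(isinstance(unwrap(wrapped), type(net)))  # True after unwrapping
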
@@ -1461,15 +1468,25 @@ def load_model_hook(models, input_dir):
         transformer_ = None
         text_encoder_one_ = None
 
-        while len(models) > 0:
-            model = models.pop()
+        if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+            while len(models) > 0:
+                model = models.pop()
 
-            if isinstance(model, type(unwrap_model(transformer))):
-                transformer_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder_one))):
-                text_encoder_one_ = model
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    transformer_ = unwrap_model(model)
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+                    text_encoder_one_ = unwrap_model(model)
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")
+
+        else:
+            transformer_ = FluxTransformer2DModel.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="transformer"
+            )
+            transformer_.add_adapter(transformer_lora_config)
+            text_encoder_one_ = text_encoder_cls_one.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="text_encoder"
+            )
 
         lora_state_dict = FluxKontextPipeline.lora_state_dict(input_dir)
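
Under DeepSpeed the `models` argument to the load hook is not a reliable source of the training modules (with ZeRO they are engine-managed), so the script rebuilds fresh base models from `args.pretrained_model_name_or_path` and re-attaches an empty LoRA adapter for the state dict loaded just below to fill. A hedged sketch of that rebuild; the model id and LoRA hyperparameters are illustrative, not the script's defaults:

    from diffusers import FluxTransformer2DModel
    from peft import LoraConfig

    # Illustrative stand-in for the script's transformer_lora_config.
    lora_config = LoraConfig(
        r=16,
        lora_alpha=16,
        init_lora_weights="gaussian",
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    )
    transformer_ = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-Kontext-dev", subfolder="transformer"  # illustrative id
    )
    transformer_.add_adapter(lora_config)  # empty adapter; the LoRA state dict
                                           # loaded next supplies the weights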

@@ -2069,7 +2086,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 progress_bar.update(1)
                 global_step += 1
 
-                if accelerator.is_main_process:
+                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
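
The broadened guard is needed because `accelerator.save_state()` is a collective operation under DeepSpeed: every rank owns optimizer and parameter shards, so gating on `is_main_process` alone would drop shards or hang the other ranks. The pattern in isolation (names follow the script; `save_path` would be the checkpoint directory):

    should_save = accelerator.is_main_process or (
        accelerator.distributed_type == DistributedType.DEEPSPEED
    )
    if should_save and global_step % args.checkpointing_steps == 0:
        # save_state is collective under DeepSpeed: every rank writes its own
        # shard, and ranks that skip the call would leave the others waiting.
        accelerator.save_state(save_path)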

src/diffusers/quantizers/gguf/utils.py

Lines changed: 56 additions & 0 deletions
@@ -429,8 +429,64 @@ def dequantize_blocks_BF16(blocks, block_size, type_size, dtype=None):
     return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32)
 
 
+# this part is from calcuis (gguf.org)
+# more info: https://github.com/calcuis/gguf-connector/blob/main/src/gguf_connector/quant2c.py
+
+
+def dequantize_blocks_IQ4_NL(blocks, block_size, type_size, dtype=None):
+    kvalues = torch.tensor(
+        [-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113],
+        dtype=torch.float32,
+        device=blocks.device,
+    )
+    n_blocks = blocks.shape[0]
+    d, qs = split_block_dims(blocks, 2)
+    d = d.view(torch.float16).to(dtype)
+    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
+        [0, 4], device=blocks.device, dtype=torch.uint8
+    ).reshape((1, 1, 2, 1))
+    qs = (qs & 15).reshape((n_blocks, -1)).to(torch.int64)
+    kvalues = kvalues.view(1, 1, 16)
+    qs = qs.unsqueeze(-1)
+    qs = torch.gather(kvalues.expand(qs.shape[0], qs.shape[1], 16), 2, qs)
+    qs = qs.squeeze(-1).to(dtype)
+    return d * qs
+
+
+def dequantize_blocks_IQ4_XS(blocks, block_size, type_size, dtype=None):
+    kvalues = torch.tensor(
+        [-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113],
+        dtype=torch.float32,
+        device=blocks.device,
+    )
+    n_blocks = blocks.shape[0]
+    d, scales_h, scales_l, qs = split_block_dims(blocks, 2, 2, QK_K // 64)
+    d = d.view(torch.float16).to(dtype)
+    scales_h = scales_h.view(torch.int16)
+    scales_l = scales_l.reshape((n_blocks, -1, 1)) >> torch.tensor(
+        [0, 4], device=blocks.device, dtype=torch.uint8
+    ).reshape((1, 1, 2))
+    scales_h = scales_h.reshape((n_blocks, 1, -1)) >> torch.tensor(
+        [2 * i for i in range(QK_K // 32)], device=blocks.device, dtype=torch.uint8
+    ).reshape((1, -1, 1))
+    scales_l = scales_l.reshape((n_blocks, -1)) & 0x0F
+    scales_h = scales_h.reshape((n_blocks, -1)) & 0x03
+    scales = (scales_l | (scales_h << 4)) - 32
+    dl = (d * scales.to(dtype)).reshape((n_blocks, -1, 1))
+    shifts_q = torch.tensor([0, 4], device=blocks.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
+    qs = qs.reshape((n_blocks, -1, 1, 16)) >> shifts_q
+    qs = (qs & 15).reshape((n_blocks, -1, 32)).to(torch.int64)
+    kvalues = kvalues.view(1, 1, 1, 16)
+    qs = qs.unsqueeze(-1)
+    qs = torch.gather(kvalues.expand(qs.shape[0], qs.shape[1], qs.shape[2], 16), 3, qs)
+    qs = qs.squeeze(-1).to(dtype)
+    return (dl * qs).reshape(n_blocks, -1)
+
+
 GGML_QUANT_SIZES = gguf.GGML_QUANT_SIZES
 dequantize_functions = {
+    gguf.GGMLQuantizationType.IQ4_NL: dequantize_blocks_IQ4_NL,
+    gguf.GGMLQuantizationType.IQ4_XS: dequantize_blocks_IQ4_XS,
     gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
     gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
     gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1,
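
Both new dequantizers follow the same recipe: the packed bytes hold 4-bit indices into the fixed 16-entry non-linear codebook `kvalues`, which `torch.gather` expands before a scale multiply. IQ4_NL carries one float16 scale per 32-value block; IQ4_XS works on 256-value super-blocks whose per-32-value scales are 6-bit numbers split across `scales_l` (4 low bits) and `scales_h` (2 high bits) and rebiased by 32, so for example low nibble 0x0B with high bits 0x02 yields (0x0B | 0x02 << 4) - 32 = 11, which then multiplies the super-block scale `d`. A runnable toy sketch of the core unpack-and-lookup step, independent of the GGUF block layout:

    import torch

    # The 16-entry non-linear codebook shared by IQ4_NL and IQ4_XS.
    kvalues = torch.tensor(
        [-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113],
        dtype=torch.float32,
    )

    packed = torch.tensor([0x1F, 0xA3], dtype=torch.uint8)  # 2 bytes = 4 indices
    lo = (packed & 15).to(torch.int64)   # low nibbles first, as in the kernels
    hi = (packed >> 4).to(torch.int64)   # then the high nibbles
    scale = 0.5                          # illustrative per-block scale
    print(scale * kvalues[torch.cat([lo, hi])])
    # tensor([ 56.5000, -32.5000, -52.0000,  12.5000])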

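With the two entries registered in `dequantize_functions`, GGUF checkpoints quantized as IQ4_NL or IQ4_XS load through diffusers' existing GGUF path. A hedged usage sketch; the checkpoint filename is a placeholder:

    import torch
    from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

    # "model-IQ4_NL.gguf" is a placeholder; any file whose tensors use a quant
    # type present in dequantize_functions can be dequantized on the fly.
    transformer = FluxTransformer2DModel.from_single_file(
        "model-IQ4_NL.gguf",
        quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    )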