Skip to content

Commit 5f7d9b8

Browse files
authored
Merge pull request #1025 from d8ahazard/dev
dev to main [release 0.12.0]
2 parents 19d27b6 + 6a0b22b commit 5f7d9b8

File tree

17 files changed

+309
-165
lines changed

17 files changed

+309
-165
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -347,6 +347,8 @@ they'll help me help you faster.
347347

348348
[Feature Request](https://github.com/d8ahazard/sd_dreambooth_extension/issues/new?assignees=&labels=&template=feature_request.md&title=)
349349

350+
[Discord](https://discord.gg/q8dtpfRD5w)
351+
350352
# Credits
351353

352354
[Huggingface.co](https://huggingface.co) - All the things

dreambooth/dataclasses/db_config.py

Lines changed: 4 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -5,14 +5,15 @@
55

66
from pydantic import BaseModel
77

8-
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_scheduler_names
9-
108
try:
119
from extensions.sd_dreambooth_extension.dreambooth import shared
1210
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_concept import Concept
11+
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_scheduler_names
12+
1313
except:
1414
from dreambooth.dreambooth import shared # noqa
1515
from dreambooth.dreambooth.dataclasses.db_concept import Concept # noqa
16+
from dreambooth.dreambooth.utils.image_utils import get_scheduler_names # noqa
1617

1718
# Keys to save, replacing our dumb __init__ method
1819
save_keys = []
@@ -44,7 +45,6 @@ class DreamboothConfig(BaseModel):
4445
gradient_checkpointing: bool = True
4546
gradient_set_to_none: bool = True
4647
graph_smoothing: int = 50
47-
half_lora: bool = False
4848
half_model: bool = False
4949
train_unfrozen: bool = True
5050
has_ema: bool = False
@@ -95,6 +95,7 @@ class DreamboothConfig(BaseModel):
9595
save_lora_after: bool = True
9696
save_lora_cancel: bool = False
9797
save_lora_during: bool = True
98+
save_lora_for_extra_net: bool = True
9899
save_preview_every: int = 5
99100
save_safetensors: bool = True
100101
save_state_after: bool = False
@@ -146,14 +147,6 @@ def __init__(self, model_name: str = "", v2: bool = False, src: str = "",
146147
self.scheduler = "ddim"
147148
self.v2 = v2
148149

149-
# Naive fixes for bad types
150-
if not isinstance(self.lora_model_name, str):
151-
print("Bad lora_model_name found, setting to ''")
152-
self.lora_model_name = ''
153-
if not isinstance(self.stop_text_encoder, float):
154-
print("Bad stop_text_encoder found, setting to 0.0")
155-
self.stop_text_encoder = 0.0
156-
157150
# Actually save as a file
158151
def save(self, backup=False):
159152
"""

dreambooth/dataset/class_dataset.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -32,14 +32,12 @@ def __init__(self, concepts: [Concept], model_dir: str, max_width: int, shuffle:
3232
# Data for new prompts to generate
3333
self.new_prompts = {}
3434
self.required_prompts = 0
35-
# Calculate minimum width
36-
min_width = (int(max_width * 0.28125) // 64) * 64
3735

3836
# Thingy to build prompts
3937
text_getter = FilenameTextGetter(shuffle)
4038

4139
# Create available resolutions
42-
bucket_resos = make_bucket_resolutions(max_width, min_width)
40+
bucket_resos = make_bucket_resolutions(max_width)
4341
c_idx = 0
4442
c_images = {}
4543
i_images = {}

dreambooth/dataset/db_dataset.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -153,15 +153,15 @@ def cache_caption(self, image_path, caption):
153153
self.caption_cache[image_path] = input_ids
154154
return caption, input_ids
155155

156-
def make_buckets_with_caching(self, vae, min_size):
156+
def make_buckets_with_caching(self, vae):
157157
self.vae = vae
158158
self.cache_latents = vae is not None
159159
state = f"Preparing Dataset ({'With Caching' if self.cache_latents else 'Without Caching'})"
160160
print(state)
161161
status.textinfo = state
162162

163163
# Create a list of resolutions
164-
bucket_resos = make_bucket_resolutions(self.resolution, min_size)
164+
bucket_resos = make_bucket_resolutions(self.resolution)
165165
self.train_dict = {}
166166

167167
def sort_images(img_data: List[PromptData], resos, target_dict, is_class_img):

dreambooth/diff_to_sd.py

Lines changed: 21 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616

1717
try:
1818
from extensions.sd_dreambooth_extension.dreambooth import shared as shared
19-
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import from_file
19+
from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import from_file, DreamboothConfig
2020
from extensions.sd_dreambooth_extension.dreambooth.shared import status
2121
from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import unload_system_models, \
2222
reload_system_models, \
@@ -26,7 +26,7 @@
2626
from extensions.sd_dreambooth_extension.lora_diffusion.lora import merge_lora_to_model
2727
except:
2828
from dreambooth.dreambooth import shared as shared # noqa
29-
from dreambooth.dreambooth.dataclasses.db_config import from_file # noqa
29+
from dreambooth.dreambooth.dataclasses.db_config import from_file, DreamboothConfig # noqa
3030
from dreambooth.dreambooth.shared import status # noqa
3131
from dreambooth.dreambooth.utils.model_utils import unload_system_models, reload_system_models, \
3232
disable_safe_unpickle, enable_safe_unpickle, import_model_class_from_model_name_or_path # noqa
@@ -338,13 +338,13 @@ def get_model_path(working_dir: str, model_name: str = "", file_extra: str = "")
338338
return None
339339

340340

341-
def compile_checkpoint(model_name: str, lora_path: str = None, reload_models: bool = True, log: bool = True,
341+
def compile_checkpoint(model_name: str, lora_file_name: str = None, reload_models: bool = True, log: bool = True,
342342
snap_rev: str = ""):
343343
"""
344344
345345
@param model_name: The model name to compile
346346
@param reload_models: Whether to reload the system list of checkpoints.
347-
@param lora_path: The path to a lora pt file to merge with the unet. Auto set during training.
347+
@param lora_file_name: The path to a lora pt file to merge with the unet. Auto set during training.
348348
@param log: Whether to print messages to console/UI.
349349
@param snap_rev: The revision of snapshot to load from
350350
@return: status: What happened, path: Checkpoint path
@@ -355,8 +355,8 @@ def compile_checkpoint(model_name: str, lora_path: str = None, reload_models: bo
355355
status.job_count = 7
356356

357357
config = from_file(model_name)
358-
if lora_path is None and config.lora_model_name:
359-
lora_path = config.lora_model_name
358+
if lora_file_name is None and config.lora_model_name:
359+
lora_file_name = config.lora_model_name
360360
save_model_name = model_name if config.custom_model_name == "" else config.custom_model_name
361361
if config.custom_model_name == "":
362362
printi(f"Compiling checkpoint for {model_name}...", log=log)
@@ -418,10 +418,9 @@ def compile_checkpoint(model_name: str, lora_path: str = None, reload_models: bo
418418
pass
419419

420420
# Apply LoRA to the unet
421-
if lora_path is not None and lora_path != "":
421+
if lora_file_name is not None and lora_file_name != "":
422422
unet_model = UNet2DConditionModel().from_pretrained(os.path.dirname(unet_path))
423-
lora_rev = apply_lora(unet_model, lora_path, config.lora_unet_rank, config.lora_weight, "cpu", False,
424-
config.use_lora_extended)
423+
lora_rev = apply_lora(config, unet_model, lora_file_name, "cpu", False)
425424
unet_state_dict = copy.deepcopy(unet_model.state_dict())
426425
del unet_model
427426
if lora_rev is not None:
@@ -448,9 +447,9 @@ def compile_checkpoint(model_name: str, lora_path: str = None, reload_models: bo
448447
printi("Converting text encoder...", log=log)
449448

450449
# Apply lora weights to the tenc
451-
if lora_path is not None and lora_path != "":
452-
lora_paths = lora_path.split(".")
453-
lora_txt_path = f"{lora_paths[0]}_txt.{lora_paths[1]}"
450+
if lora_file_name is not None and lora_file_name != "":
451+
lora_paths = lora_file_name.split(".")
452+
lora_txt_file_name = f"{lora_paths[0]}_txt.{lora_paths[1]}"
454453
text_encoder_cls = import_model_class_from_model_name_or_path(config.pretrained_model_name_or_path,
455454
config.revision)
456455

@@ -461,8 +460,7 @@ def compile_checkpoint(model_name: str, lora_path: str = None, reload_models: bo
461460
torch_dtype=torch.float32
462461
)
463462

464-
apply_lora(text_encoder, lora_txt_path, config.lora_txt_rank, config.lora_txt_weight, "cpu", True,
465-
config.use_lora_extended)
463+
apply_lora(config, text_encoder, lora_txt_file_name, "cpu", True)
466464
text_enc_dict = copy.deepcopy(text_encoder.state_dict())
467465
del text_encoder
468466
else:
@@ -551,20 +549,15 @@ def load_model(model_path: str, map_location: str):
551549
return loaded
552550

553551

554-
def apply_lora(model: nn.Module, loras: str, rank: int, weight: float, device: str, is_tenc: bool, use_extended: bool):
552+
def apply_lora(config: DreamboothConfig, model: nn.Module, lora_file_name: str, device: str, is_tenc: bool):
555553
lora_rev = None
556-
if loras is not None and loras != "":
557-
if not os.path.exists(loras):
558-
try:
559-
cmd_lora_models_path = shared.lora_models_path
560-
except:
561-
cmd_lora_models_path = None
562-
model_dir = os.path.dirname(cmd_lora_models_path) if cmd_lora_models_path else shared.models_path
563-
loras = os.path.join(model_dir, "lora", loras)
564-
565-
if os.path.exists(loras):
566-
lora_rev = loras.split("_")[-1].replace(".pt", "")
567-
printi(f"Loading lora from {loras}", log=True)
568-
merge_lora_to_model(model, load_model(loras, device), is_tenc, use_extended, rank, weight)
554+
if lora_file_name is not None and lora_file_name != "":
555+
if not os.path.exists(lora_file_name):
556+
lora_file_name = os.path.join(config.model_dir, "loras", lora_file_name)
557+
if os.path.exists(lora_file_name):
558+
lora_rev = lora_file_name.split("_")[-1].replace(".pt", "")
559+
printi(f"Loading lora from {lora_file_name}", log=True)
560+
merge_lora_to_model(model, load_model(lora_file_name, device), is_tenc, config.use_lora_extended,
561+
config.lora_unet_rank, config.lora_weight)
569562

570563
return lora_rev

dreambooth/sd_to_diff.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,6 @@
2828
from huggingface_hub import HfApi, hf_hub_download
2929
from omegaconf import OmegaConf
3030

31-
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_scheduler_class
3231

3332
try:
3433
from extensions.sd_dreambooth_extension.dreambooth import shared
@@ -37,12 +36,15 @@
3736
enable_safe_unpickle
3837
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import printi
3938
from extensions.sd_dreambooth_extension.helpers.mytqdm import mytqdm
39+
from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_scheduler_class
40+
4041
except:
4142
from dreambooth.dreambooth import shared # noqa
4243
from dreambooth.dreambooth.dataclasses.db_config import DreamboothConfig # noqa
4344
from dreambooth.dreambooth.utils.model_utils import get_db_models, disable_safe_unpickle, enable_safe_unpickle # noqa
4445
from dreambooth.dreambooth.utils.utils import printi # noqa
4546
from dreambooth.helpers.mytqdm import mytqdm # noqa
47+
from dreambooth.dreambooth.utils.image_utils import get_scheduler_class # noqa
4648

4749
from diffusers import (
4850
AutoencoderKL,

dreambooth/shared.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616
def load_auto_settings():
1717
global models_path, script_path, ckpt_dir, device_id, disable_safe_unpickle, dataset_filename_word_regex, \
1818
dataset_filename_join_string, show_progress_every_n_steps, parallel_processing_allowed, state, ckptfix, medvram, \
19-
lowvram, dreambooth_models_path, lora_models_path, CLIP_stop_at_last_layers, profile_db, debug, config, device, \
19+
lowvram, dreambooth_models_path, ui_lora_models_path, CLIP_stop_at_last_layers, profile_db, debug, config, device, \
2020
force_cpu, embeddings_dir, sd_model
2121
try:
2222
import modules.script_callbacks
@@ -51,7 +51,7 @@ def set_model(new_model):
5151

5252
try:
5353
dreambooth_models_path = ws.cmd_opts.dreambooth_models_path or dreambooth_models_path
54-
lora_models_path = ws.cmd_opts.lora_models_path or lora_models_path
54+
ui_lora_models_path = ws.cmd_opts.lora_models_path or ui_lora_models_path
5555
embeddings_dir = ws.cmd_opts.embeddings_dir or embeddings_dir
5656
except:
5757
pass
@@ -293,7 +293,7 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
293293
embeddings_dir = os.path.join(script_path, "embeddings")
294294
dreambooth_models_path = os.path.join(models_path, "dreambooth")
295295
ckpt_dir = os.path.join(models_path, "Stable-diffusion")
296-
lora_models_path = os.path.join(models_path, "lora")
296+
ui_lora_models_path = os.path.join(models_path, "lora")
297297
db_model_config = None
298298
data_path = os.path.join(script_path, ".cache")
299299
show_progress_every_n_steps = 10

dreambooth/train_dreambooth.py

Lines changed: 37 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@
4545
from extensions.sd_dreambooth_extension.dreambooth.xattention import optim_to
4646
from extensions.sd_dreambooth_extension.helpers.ema_model import EMAModel
4747
from extensions.sd_dreambooth_extension.helpers.mytqdm import mytqdm
48+
from extensions.sd_dreambooth_extension.lora_diffusion.extra_networks import save_extra_networks
4849
from extensions.sd_dreambooth_extension.lora_diffusion.lora import save_lora_weight, \
4950
TEXT_ENCODER_DEFAULT_TARGET_REPLACE, get_target_module
5051
from extensions.sd_dreambooth_extension.dreambooth.deis_velocity import get_velocity
@@ -67,6 +68,7 @@
6768
from dreambooth.dreambooth.xattention import optim_to # noqa
6869
from dreambooth.helpers.ema_model import EMAModel # noqa
6970
from dreambooth.helpers.mytqdm import mytqdm # noqa
71+
from dreambooth.lora_diffusion.extra_networks import save_extra_networks # noqa
7072
from dreambooth.lora_diffusion.lora import save_lora_weight, TEXT_ENCODER_DEFAULT_TARGET_REPLACE, get_target_module # noqa
7173

7274
logger = logging.getLogger(__name__)
@@ -241,21 +243,12 @@ def create_vae():
241243
vae = create_vae()
242244
printm("Created vae")
243245

244-
try:
245-
unet = UNet2DConditionModel.from_pretrained(
246-
args.pretrained_model_name_or_path,
247-
subfolder="unet",
248-
revision=args.revision,
249-
torch_dtype=torch.float32
250-
)
251-
except:
252-
unet = UNet2DConditionModel.from_pretrained(
253-
args.pretrained_model_name_or_path,
254-
subfolder="unet",
255-
revision=args.revision,
256-
torch_dtype=torch.float16
257-
)
258-
unet = unet.to(dtype=torch.float32)
246+
unet = UNet2DConditionModel.from_pretrained(
247+
args.pretrained_model_name_or_path,
248+
subfolder="unet",
249+
revision=args.revision,
250+
torch_dtype=torch.float32
251+
)
259252
unet = torch2ify(unet)
260253

261254
# Check that all trainable models are in full precision
@@ -329,7 +322,7 @@ def create_vae():
329322
if args.use_lora:
330323
unet.requires_grad_(False)
331324
if args.lora_model_name:
332-
lora_path = os.path.join(shared.models_path, "lora", args.lora_model_name)
325+
lora_path = os.path.join(args.model_dir, "loras", args.lora_model_name)
333326
lora_txt = lora_path.replace(".pt", "_txt.pt")
334327

335328
if not os.path.exists(lora_path) or not os.path.isfile(lora_path):
@@ -716,8 +709,8 @@ def save_weights(save_image, save_model, save_snapshot, save_checkpoint, save_lo
716709
requires_safety_checker=None
717710
)
718711
scheduler_class = get_scheduler_class(args.scheduler)
719-
s_pipeline.unet = torch2ify(s_pipeline.unet)
720712
s_pipeline.enable_attention_slicing()
713+
s_pipeline.unet = torch2ify(s_pipeline.unet)
721714
xformerify(s_pipeline)
722715

723716
s_pipeline.scheduler = scheduler_class.from_config(s_pipeline.scheduler.config)
@@ -735,6 +728,7 @@ def save_weights(save_image, save_model, save_snapshot, save_checkpoint, save_lo
735728
pbar.update()
736729
try:
737730
out_file = None
731+
# Loras resume from pt
738732
if not args.use_lora:
739733
if save_snapshot:
740734
pbar.set_description("Saving Snapshot")
@@ -756,27 +750,43 @@ def save_weights(save_image, save_model, save_snapshot, save_checkpoint, save_lo
756750

757751
elif save_lora:
758752
pbar.set_description("Saving Lora Weights...")
753+
# setup directory
754+
loras_dir = os.path.join(args.model_dir, "loras")
755+
os.makedirs(loras_dir, exist_ok=True)
756+
# setup pt path
759757
lora_model_name = args.model_name if args.custom_model_name == "" else args.custom_model_name
760-
model_dir = os.path.dirname(shared.lora_models_path)
761-
out_file = os.path.join(model_dir, "lora")
762-
os.makedirs(out_file, exist_ok=True)
763-
out_file = os.path.join(out_file, f"{lora_model_name}_{args.revision}.pt")
758+
lora_file_prefix = f"{lora_model_name}_{args.revision}"
759+
out_file = os.path.join(loras_dir, f"{lora_file_prefix}.pt")
760+
# create pt
764761
tgt_module = get_target_module("module", args.use_lora_extended)
765-
d_type = torch.float16 if args.half_lora else torch.float32
762+
save_lora_weight(s_pipeline.unet, out_file, tgt_module)
766763

767-
save_lora_weight(s_pipeline.unet, out_file, tgt_module, d_type=d_type)
764+
modelmap = {"unet": (s_pipeline.unet, tgt_module)}
765+
# save text_encoder
768766
if stop_text_percentage != 0:
769767
out_txt = out_file.replace(".pt", "_txt.pt")
768+
modelmap["text_encoder"] = (s_pipeline.text_encoder, TEXT_ENCODER_DEFAULT_TARGET_REPLACE)
770769
save_lora_weight(s_pipeline.text_encoder,
771770
out_txt,
772-
target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
773-
d_type=d_type)
771+
target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE)
774772
pbar.update()
775-
773+
# save extra_net
774+
if args.save_lora_for_extra_net:
775+
if args.use_lora_extended:
776+
import sys
777+
has_locon = len([path for path in sys.path if 'a1111-sd-webui-locon' in path]) > 0
778+
if not has_locon:
779+
raise Exception(r"a1111-sd-webui-locon extension is required to save "
780+
r"extra net for extended lora. Please install "
781+
r"https://github.com/KohakuBlueleaf/a1111-sd-webui-locon")
782+
os.makedirs(shared.ui_lora_models_path, exist_ok=True)
783+
out_safe = os.path.join(shared.ui_lora_models_path, f"{lora_file_prefix}.safetensors")
784+
save_extra_networks(modelmap, out_safe)
785+
# package pt into checkpoint
776786
if save_checkpoint:
777787
pbar.set_description("Compiling Checkpoint")
778788
snap_rev = str(args.revision) if save_snapshot else ""
779-
compile_checkpoint(args.model_name, reload_models=False, lora_path=out_file, log=False,
789+
compile_checkpoint(args.model_name, reload_models=False, lora_file_name=out_file, log=False,
780790
snap_rev=snap_rev)
781791
pbar.update()
782792
printm("Restored, moved to acc.device.")

0 commit comments

Comments (0)