Skip to content

Commit 85c4a32

Browse files
authored
Fix saving text encoder weights and kohya weights in advanced dreambooth lora script (#8766)
* update * update * update
1 parent 0bab9d6 commit 85c4a32

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1290,6 +1290,7 @@ def save_model_hook(models, weights, output_dir):
12901290
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
12911291
get_peft_model_state_dict(model)
12921292
)
1293+
else:
12931294
raise ValueError(f"unexpected save model: {model.__class__}")
12941295

12951296
# make sure to pop weight so that corresponding model is not saved again
@@ -1981,7 +1982,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
19811982
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
19821983
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
19831984
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
1984-
save_file(kohya_state_dict, f"{args.output_dir}/{args.output_dir}.safetensors")
1985+
save_file(kohya_state_dict, f"{args.output_dir}/{Path(args.output_dir).name}.safetensors")
19851986

19861987
save_model_card(
19871988
model_id if not args.push_to_hub else repo_id,

examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2425,7 +2425,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
24252425
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
24262426
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
24272427
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
2428-
save_file(kohya_state_dict, f"{args.output_dir}/{args.output_dir}.safetensors")
2428+
save_file(kohya_state_dict, f"{args.output_dir}/{Path(args.output_dir).name}.safetensors")
24292429

24302430
save_model_card(
24312431
model_id if not args.push_to_hub else repo_id,

0 commit comments

Comments
 (0)