Skip to content

Commit 7e38eed

Browse files
authored
Merge pull request #1386 from d8ahazard/dev
1.1.0 - Shake Them Haters Off
2 parents d8e01e1 + e27f707 commit 7e38eed

File tree

13 files changed

+2629
-2440
lines changed

13 files changed

+2629
-2440
lines changed

dreambooth/diff_to_sd.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from dreambooth.shared import status
2121
from dreambooth.utils.model_utils import unload_system_models, \
2222
reload_system_models, \
23-
disable_safe_unpickle, enable_safe_unpickle, import_model_class_from_model_name_or_path
23+
safe_unpickle_disabled, import_model_class_from_model_name_or_path
2424
from dreambooth.utils.utils import printi
2525
from helpers.mytqdm import mytqdm
2626
from lora_diffusion.lora import merge_lora_to_model
@@ -562,9 +562,8 @@ def load_model(model_path: str, map_location: str):
562562
if ".safetensors" in model_path:
563563
return safetensors.torch.load_file(model_path, device=map_location)
564564
else:
565-
disable_safe_unpickle()
566-
loaded = torch.load(model_path, map_location=map_location)
567-
enable_safe_unpickle()
565+
with safe_unpickle_disabled():
566+
loaded = torch.load(model_path, map_location=map_location)
568567
return loaded
569568

570569

dreambooth/sd_to_diff.py

Lines changed: 53 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
from dreambooth import shared
2727
from dreambooth.dataclasses.db_config import DreamboothConfig
28-
from dreambooth.utils.model_utils import enable_safe_unpickle, disable_safe_unpickle, unload_system_models, \
28+
from dreambooth.utils.model_utils import safe_unpickle_disabled, unload_system_models, \
2929
reload_system_models
3030

3131

@@ -131,7 +131,6 @@ def extract_checkpoint(
131131
# sh.update_status(status)
132132
# else:
133133
# modules.shared.status.update(status)
134-
disable_safe_unpickle()
135134
if image_size is None:
136135
image_size = 512
137136
if model_type == "v2x":
@@ -162,59 +161,60 @@ def extract_checkpoint(
162161
db_config.resolution = image_size
163162
db_config.save()
164163
try:
165-
if from_safetensors:
166-
if model_type == "SDXL":
167-
pipe = StableDiffusionXLPipeline.from_single_file(
168-
pretrained_model_link_or_path=checkpoint_file,
164+
with safe_unpickle_disabled():
165+
if from_safetensors:
166+
if model_type == "SDXL":
167+
pipe = StableDiffusionXLPipeline.from_single_file(
168+
pretrained_model_link_or_path=checkpoint_file,
169+
)
170+
else:
171+
pipe = StableDiffusionPipeline.from_single_file(
172+
pretrained_model_link_or_path=checkpoint_file,
173+
)
174+
elif model_type == "SDXL":
175+
pipe = StableDiffusionXLPipeline.from_pretrained(
176+
checkpoint_path_or_dict=checkpoint_file,
177+
original_config_file=original_config_file,
178+
image_size=image_size,
179+
prediction_type=prediction_type,
180+
model_type=pipeline_type,
181+
extract_ema=extract_ema,
182+
scheduler_type=scheduler_type,
183+
num_in_channels=num_in_channels,
184+
upcast_attention=upcast_attention,
185+
from_safetensors=from_safetensors,
186+
device=device,
187+
pretrained_model_name_or_path=checkpoint_file,
188+
stable_unclip=stable_unclip,
189+
stable_unclip_prior=stable_unclip_prior,
190+
clip_stats_path=clip_stats_path,
191+
controlnet=controlnet,
192+
vae_path=vae_path,
193+
pipeline_class=pipeline_class,
194+
half=half
169195
)
170196
else:
171-
pipe = StableDiffusionPipeline.from_single_file(
172-
pretrained_model_link_or_path=checkpoint_file,
197+
pipe = StableDiffusionPipeline.from_pretrained(
198+
checkpoint_path_or_dict=checkpoint_file,
199+
original_config_file=original_config_file,
200+
image_size=image_size,
201+
prediction_type=prediction_type,
202+
model_type=pipeline_type,
203+
extract_ema=extract_ema,
204+
scheduler_type=scheduler_type,
205+
num_in_channels=num_in_channels,
206+
upcast_attention=upcast_attention,
207+
from_safetensors=from_safetensors,
208+
device=device,
209+
pretrained_model_name_or_path=checkpoint_file,
210+
stable_unclip=stable_unclip,
211+
stable_unclip_prior=stable_unclip_prior,
212+
clip_stats_path=clip_stats_path,
213+
controlnet=controlnet,
214+
vae_path=vae_path,
215+
pipeline_class=pipeline_class,
216+
half=half
173217
)
174-
elif model_type == "SDXL":
175-
pipe = StableDiffusionXLPipeline.from_pretrained(
176-
checkpoint_path_or_dict=checkpoint_file,
177-
original_config_file=original_config_file,
178-
image_size=image_size,
179-
prediction_type=prediction_type,
180-
model_type=pipeline_type,
181-
extract_ema=extract_ema,
182-
scheduler_type=scheduler_type,
183-
num_in_channels=num_in_channels,
184-
upcast_attention=upcast_attention,
185-
from_safetensors=from_safetensors,
186-
device=device,
187-
pretrained_model_name_or_path=checkpoint_file,
188-
stable_unclip=stable_unclip,
189-
stable_unclip_prior=stable_unclip_prior,
190-
clip_stats_path=clip_stats_path,
191-
controlnet=controlnet,
192-
vae_path=vae_path,
193-
pipeline_class=pipeline_class,
194-
half=half
195-
)
196-
else:
197-
pipe = StableDiffusionPipeline.from_pretrained(
198-
checkpoint_path_or_dict=checkpoint_file,
199-
original_config_file=original_config_file,
200-
image_size=image_size,
201-
prediction_type=prediction_type,
202-
model_type=pipeline_type,
203-
extract_ema=extract_ema,
204-
scheduler_type=scheduler_type,
205-
num_in_channels=num_in_channels,
206-
upcast_attention=upcast_attention,
207-
from_safetensors=from_safetensors,
208-
device=device,
209-
pretrained_model_name_or_path=checkpoint_file,
210-
stable_unclip=stable_unclip,
211-
stable_unclip_prior=stable_unclip_prior,
212-
clip_stats_path=clip_stats_path,
213-
controlnet=controlnet,
214-
vae_path=vae_path,
215-
pipeline_class=pipeline_class,
216-
half=half
217-
)
218218

219219
dump_path = db_config.get_pretrained_model_name_or_path()
220220
if controlnet:
@@ -246,7 +246,7 @@ def extract_checkpoint(
246246
print(f"Couldn't find {full_path}")
247247
break
248248
remove_dirs = ["logging", "samples"]
249-
enable_safe_unpickle()
249+
250250
reload_system_models()
251251
if success:
252252
for rd in remove_dirs:

0 commit comments

Comments
 (0)