|
25 | 25 |
|
26 | 26 | from dreambooth import shared |
27 | 27 | from dreambooth.dataclasses.db_config import DreamboothConfig |
28 | | -from dreambooth.utils.model_utils import enable_safe_unpickle, disable_safe_unpickle, unload_system_models, \ |
| 28 | +from dreambooth.utils.model_utils import safe_unpickle_disabled, unload_system_models, \ |
29 | 29 | reload_system_models |
30 | 30 |
|
31 | 31 |
|
@@ -131,7 +131,6 @@ def extract_checkpoint( |
131 | 131 | # sh.update_status(status) |
132 | 132 | # else: |
133 | 133 | # modules.shared.status.update(status) |
134 | | - disable_safe_unpickle() |
135 | 134 | if image_size is None: |
136 | 135 | image_size = 512 |
137 | 136 | if model_type == "v2x": |
@@ -162,59 +161,60 @@ def extract_checkpoint( |
162 | 161 | db_config.resolution = image_size |
163 | 162 | db_config.save() |
164 | 163 | try: |
165 | | - if from_safetensors: |
166 | | - if model_type == "SDXL": |
167 | | - pipe = StableDiffusionXLPipeline.from_single_file( |
168 | | - pretrained_model_link_or_path=checkpoint_file, |
| 164 | + with safe_unpickle_disabled(): |
| 165 | + if from_safetensors: |
| 166 | + if model_type == "SDXL": |
| 167 | + pipe = StableDiffusionXLPipeline.from_single_file( |
| 168 | + pretrained_model_link_or_path=checkpoint_file, |
| 169 | + ) |
| 170 | + else: |
| 171 | + pipe = StableDiffusionPipeline.from_single_file( |
| 172 | + pretrained_model_link_or_path=checkpoint_file, |
| 173 | + ) |
| 174 | + elif model_type == "SDXL": |
| 175 | + pipe = StableDiffusionXLPipeline.from_pretrained( |
| 176 | + checkpoint_path_or_dict=checkpoint_file, |
| 177 | + original_config_file=original_config_file, |
| 178 | + image_size=image_size, |
| 179 | + prediction_type=prediction_type, |
| 180 | + model_type=pipeline_type, |
| 181 | + extract_ema=extract_ema, |
| 182 | + scheduler_type=scheduler_type, |
| 183 | + num_in_channels=num_in_channels, |
| 184 | + upcast_attention=upcast_attention, |
| 185 | + from_safetensors=from_safetensors, |
| 186 | + device=device, |
| 187 | + pretrained_model_name_or_path=checkpoint_file, |
| 188 | + stable_unclip=stable_unclip, |
| 189 | + stable_unclip_prior=stable_unclip_prior, |
| 190 | + clip_stats_path=clip_stats_path, |
| 191 | + controlnet=controlnet, |
| 192 | + vae_path=vae_path, |
| 193 | + pipeline_class=pipeline_class, |
| 194 | + half=half |
169 | 195 | ) |
170 | 196 | else: |
171 | | - pipe = StableDiffusionPipeline.from_single_file( |
172 | | - pretrained_model_link_or_path=checkpoint_file, |
| 197 | + pipe = StableDiffusionPipeline.from_pretrained( |
| 198 | + checkpoint_path_or_dict=checkpoint_file, |
| 199 | + original_config_file=original_config_file, |
| 200 | + image_size=image_size, |
| 201 | + prediction_type=prediction_type, |
| 202 | + model_type=pipeline_type, |
| 203 | + extract_ema=extract_ema, |
| 204 | + scheduler_type=scheduler_type, |
| 205 | + num_in_channels=num_in_channels, |
| 206 | + upcast_attention=upcast_attention, |
| 207 | + from_safetensors=from_safetensors, |
| 208 | + device=device, |
| 209 | + pretrained_model_name_or_path=checkpoint_file, |
| 210 | + stable_unclip=stable_unclip, |
| 211 | + stable_unclip_prior=stable_unclip_prior, |
| 212 | + clip_stats_path=clip_stats_path, |
| 213 | + controlnet=controlnet, |
| 214 | + vae_path=vae_path, |
| 215 | + pipeline_class=pipeline_class, |
| 216 | + half=half |
173 | 217 | ) |
174 | | - elif model_type == "SDXL": |
175 | | - pipe = StableDiffusionXLPipeline.from_pretrained( |
176 | | - checkpoint_path_or_dict=checkpoint_file, |
177 | | - original_config_file=original_config_file, |
178 | | - image_size=image_size, |
179 | | - prediction_type=prediction_type, |
180 | | - model_type=pipeline_type, |
181 | | - extract_ema=extract_ema, |
182 | | - scheduler_type=scheduler_type, |
183 | | - num_in_channels=num_in_channels, |
184 | | - upcast_attention=upcast_attention, |
185 | | - from_safetensors=from_safetensors, |
186 | | - device=device, |
187 | | - pretrained_model_name_or_path=checkpoint_file, |
188 | | - stable_unclip=stable_unclip, |
189 | | - stable_unclip_prior=stable_unclip_prior, |
190 | | - clip_stats_path=clip_stats_path, |
191 | | - controlnet=controlnet, |
192 | | - vae_path=vae_path, |
193 | | - pipeline_class=pipeline_class, |
194 | | - half=half |
195 | | - ) |
196 | | - else: |
197 | | - pipe = StableDiffusionPipeline.from_pretrained( |
198 | | - checkpoint_path_or_dict=checkpoint_file, |
199 | | - original_config_file=original_config_file, |
200 | | - image_size=image_size, |
201 | | - prediction_type=prediction_type, |
202 | | - model_type=pipeline_type, |
203 | | - extract_ema=extract_ema, |
204 | | - scheduler_type=scheduler_type, |
205 | | - num_in_channels=num_in_channels, |
206 | | - upcast_attention=upcast_attention, |
207 | | - from_safetensors=from_safetensors, |
208 | | - device=device, |
209 | | - pretrained_model_name_or_path=checkpoint_file, |
210 | | - stable_unclip=stable_unclip, |
211 | | - stable_unclip_prior=stable_unclip_prior, |
212 | | - clip_stats_path=clip_stats_path, |
213 | | - controlnet=controlnet, |
214 | | - vae_path=vae_path, |
215 | | - pipeline_class=pipeline_class, |
216 | | - half=half |
217 | | - ) |
218 | 218 |
|
219 | 219 | dump_path = db_config.get_pretrained_model_name_or_path() |
220 | 220 | if controlnet: |
@@ -246,7 +246,7 @@ def extract_checkpoint( |
246 | 246 | print(f"Couldn't find {full_path}") |
247 | 247 | break |
248 | 248 | remove_dirs = ["logging", "samples"] |
249 | | - enable_safe_unpickle() |
| 249 | + |
250 | 250 | reload_system_models() |
251 | 251 | if success: |
252 | 252 | for rd in remove_dirs: |
|
0 commit comments