2323import unittest
2424from typing import List
2525
26- import torch
26+ import safetensors
2727from accelerate .utils import write_basic_config
2828
2929from diffusers import DiffusionPipeline , UNet2DConditionModel
@@ -93,7 +93,7 @@ def test_train_unconditional(self):
9393
9494 run_command (self ._launch_args + test_args , return_stdout = True )
9595 # save_pretrained smoke test
96- self .assertTrue (os .path .isfile (os .path .join (tmpdir , "unet" , "diffusion_pytorch_model.bin " )))
96+ self .assertTrue (os .path .isfile (os .path .join (tmpdir , "unet" , "diffusion_pytorch_model.safetensors " )))
9797 self .assertTrue (os .path .isfile (os .path .join (tmpdir , "scheduler" , "scheduler_config.json" )))
9898
9999 def test_textual_inversion (self ):
@@ -144,7 +144,7 @@ def test_dreambooth(self):
144144
145145 run_command (self ._launch_args + test_args )
146146 # save_pretrained smoke test
147- self .assertTrue (os .path .isfile (os .path .join (tmpdir , "unet" , "diffusion_pytorch_model.bin " )))
147+ self .assertTrue (os .path .isfile (os .path .join (tmpdir , "unet" , "diffusion_pytorch_model.safetensors " )))
148148 self .assertTrue (os .path .isfile (os .path .join (tmpdir , "scheduler" , "scheduler_config.json" )))
149149
150150 def test_dreambooth_if (self ):
@@ -170,7 +170,7 @@ def test_dreambooth_if(self):
170170
171171 run_command (self ._launch_args + test_args )
172172 # save_pretrained smoke test
173- self .assertTrue (os .path .isfile (os .path .join (tmpdir , "unet" , "diffusion_pytorch_model.bin " )))
173+ self .assertTrue (os .path .isfile (os .path .join (tmpdir , "unet" , "diffusion_pytorch_model.safetensors " )))
174174 self .assertTrue (os .path .isfile (os .path .join (tmpdir , "scheduler" , "scheduler_config.json" )))
175175
176176 def test_dreambooth_checkpointing (self ):
@@ -272,10 +272,10 @@ def test_dreambooth_lora(self):
272272
273273 run_command (self ._launch_args + test_args )
274274 # save_pretrained smoke test
275- self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.bin " )))
275+ self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.safetensors " )))
276276
277277 # make sure the state_dict has the correct naming in the parameters.
278- lora_state_dict = torch .load (os .path .join (tmpdir , "pytorch_lora_weights.bin " ))
278+ lora_state_dict = safetensors . torch .load_file (os .path .join (tmpdir , "pytorch_lora_weights.safetensors " ))
279279 is_lora = all ("lora" in k for k in lora_state_dict .keys ())
280280 self .assertTrue (is_lora )
281281
@@ -305,10 +305,10 @@ def test_dreambooth_lora_with_text_encoder(self):
305305
306306 run_command (self ._launch_args + test_args )
307307 # save_pretrained smoke test
308- self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.bin " )))
308+ self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.safetensors " )))
309309
310310 # check `text_encoder` is present at all.
311- lora_state_dict = torch .load (os .path .join (tmpdir , "pytorch_lora_weights.bin " ))
311+ lora_state_dict = safetensors . torch .load_file (os .path .join (tmpdir , "pytorch_lora_weights.safetensors " ))
312312 keys = lora_state_dict .keys ()
313313 is_text_encoder_present = any (k .startswith ("text_encoder" ) for k in keys )
314314 self .assertTrue (is_text_encoder_present )
@@ -341,10 +341,10 @@ def test_dreambooth_lora_if_model(self):
341341
342342 run_command (self ._launch_args + test_args )
343343 # save_pretrained smoke test
344- self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.bin " )))
344+ self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.safetensors " )))
345345
346346 # make sure the state_dict has the correct naming in the parameters.
347- lora_state_dict = torch .load (os .path .join (tmpdir , "pytorch_lora_weights.bin " ))
347+ lora_state_dict = safetensors . torch .load_file (os .path .join (tmpdir , "pytorch_lora_weights.safetensors " ))
348348 is_lora = all ("lora" in k for k in lora_state_dict .keys ())
349349 self .assertTrue (is_lora )
350350
@@ -373,10 +373,10 @@ def test_dreambooth_lora_sdxl(self):
373373
374374 run_command (self ._launch_args + test_args )
375375 # save_pretrained smoke test
376- self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.bin " )))
376+ self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.safetensors " )))
377377
378378 # make sure the state_dict has the correct naming in the parameters.
379- lora_state_dict = torch .load (os .path .join (tmpdir , "pytorch_lora_weights.bin " ))
379+ lora_state_dict = safetensors . torch .load_file (os .path .join (tmpdir , "pytorch_lora_weights.safetensors " ))
380380 is_lora = all ("lora" in k for k in lora_state_dict .keys ())
381381 self .assertTrue (is_lora )
382382
@@ -406,10 +406,10 @@ def test_dreambooth_lora_sdxl_with_text_encoder(self):
406406
407407 run_command (self ._launch_args + test_args )
408408 # save_pretrained smoke test
409- self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.bin " )))
409+ self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.safetensors " )))
410410
411411 # make sure the state_dict has the correct naming in the parameters.
412- lora_state_dict = torch .load (os .path .join (tmpdir , "pytorch_lora_weights.bin " ))
412+ lora_state_dict = safetensors . torch .load_file (os .path .join (tmpdir , "pytorch_lora_weights.safetensors " ))
413413 is_lora = all ("lora" in k for k in lora_state_dict .keys ())
414414 self .assertTrue (is_lora )
415415
@@ -437,6 +437,7 @@ def test_custom_diffusion(self):
437437 --lr_scheduler constant
438438 --lr_warmup_steps 0
439439 --modifier_token <new1>
440+ --no_safe_serialization
440441 --output_dir { tmpdir }
441442 """ .split ()
442443
@@ -466,7 +467,7 @@ def test_text_to_image(self):
466467
467468 run_command (self ._launch_args + test_args )
468469 # save_pretrained smoke test
469- self .assertTrue (os .path .isfile (os .path .join (tmpdir , "unet" , "diffusion_pytorch_model.bin " )))
470+ self .assertTrue (os .path .isfile (os .path .join (tmpdir , "unet" , "diffusion_pytorch_model.safetensors " )))
470471 self .assertTrue (os .path .isfile (os .path .join (tmpdir , "scheduler" , "scheduler_config.json" )))
471472
472473 def test_text_to_image_checkpointing (self ):
@@ -778,7 +779,7 @@ def test_text_to_image_sdxl(self):
778779
779780 run_command (self ._launch_args + test_args )
780781 # save_pretrained smoke test
781- self .assertTrue (os .path .isfile (os .path .join (tmpdir , "unet" , "diffusion_pytorch_model.bin " )))
782+ self .assertTrue (os .path .isfile (os .path .join (tmpdir , "unet" , "diffusion_pytorch_model.safetensors " )))
782783 self .assertTrue (os .path .isfile (os .path .join (tmpdir , "scheduler" , "scheduler_config.json" )))
783784
784785 def test_text_to_image_lora_checkpointing_checkpoints_total_limit (self ):
@@ -1373,7 +1374,7 @@ def test_controlnet_sdxl(self):
13731374
13741375 run_command (self ._launch_args + test_args )
13751376
1376- self .assertTrue (os .path .isfile (os .path .join (tmpdir , "diffusion_pytorch_model.bin " )))
1377+ self .assertTrue (os .path .isfile (os .path .join (tmpdir , "diffusion_pytorch_model.safetensors " )))
13771378
13781379 def test_custom_diffusion_checkpointing_checkpoints_total_limit (self ):
13791380 with tempfile .TemporaryDirectory () as tmpdir :
@@ -1390,6 +1391,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
13901391 --max_train_steps=6
13911392 --checkpoints_total_limit=2
13921393 --checkpointing_steps=2
1394+ --no_safe_serialization
13931395 """ .split ()
13941396
13951397 run_command (self ._launch_args + test_args )
@@ -1413,6 +1415,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple
14131415 --dataloader_num_workers=0
14141416 --max_train_steps=9
14151417 --checkpointing_steps=2
1418+ --no_safe_serialization
14161419 """ .split ()
14171420
14181421 run_command (self ._launch_args + test_args )
@@ -1436,6 +1439,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple
14361439 --checkpointing_steps=2
14371440 --resume_from_checkpoint=checkpoint-8
14381441 --checkpoints_total_limit=3
1442+ --no_safe_serialization
14391443 """ .split ()
14401444
14411445 run_command (self ._launch_args + resume_run_args )
@@ -1464,10 +1468,10 @@ def test_text_to_image_lora_sdxl(self):
14641468
14651469 run_command (self ._launch_args + test_args )
14661470 # save_pretrained smoke test
1467- self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.bin " )))
1471+ self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.safetensors " )))
14681472
14691473 # make sure the state_dict has the correct naming in the parameters.
1470- lora_state_dict = torch .load (os .path .join (tmpdir , "pytorch_lora_weights.bin " ))
1474+ lora_state_dict = safetensors . torch .load_file (os .path .join (tmpdir , "pytorch_lora_weights.safetensors " ))
14711475 is_lora = all ("lora" in k for k in lora_state_dict .keys ())
14721476 self .assertTrue (is_lora )
14731477
@@ -1491,10 +1495,10 @@ def test_text_to_image_lora_sdxl_with_text_encoder(self):
14911495
14921496 run_command (self ._launch_args + test_args )
14931497 # save_pretrained smoke test
1494- self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.bin " )))
1498+ self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.safetensors " )))
14951499
14961500 # make sure the state_dict has the correct naming in the parameters.
1497- lora_state_dict = torch .load (os .path .join (tmpdir , "pytorch_lora_weights.bin " ))
1501+ lora_state_dict = safetensors . torch .load_file (os .path .join (tmpdir , "pytorch_lora_weights.safetensors " ))
14981502 is_lora = all ("lora" in k for k in lora_state_dict .keys ())
14991503 self .assertTrue (is_lora )
15001504
0 commit comments