11import  gc 
22import  unittest 
33
4+ import  torch 
5+ 
46from  diffusers  import  (
57    SanaTransformer2DModel ,
68)
@require_torch_accelerator
class SanaTransformer2DModelSingleFileTests(unittest.TestCase):
    """Tests that SanaTransformer2DModel loaded from an original single-file
    checkpoint matches the diffusers-format checkpoint on the Hub.

    Requires a torch accelerator (see ``@require_torch_accelerator``);
    checkpoints are fetched from the Hugging Face Hub at test time.
    """

    model_class = SanaTransformer2DModel
    # Original (non-diffusers) .pth checkpoint, loaded via `from_single_file`.
    ckpt_path = "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
    # Checkpoints exercising alternate state-dict key layouts; currently the
    # same file as `ckpt_path` (presumably a placeholder until more variants
    # exist — NOTE(review): confirm).
    alternate_keys_ckpt_paths = [
        "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
    ]

    # Diffusers-format repo used as the reference for config comparison.
    repo_id = "Efficient-Large-Model/Sana_1600M_1024px_diffusers"
@@ -32,4 +38,22 @@ def tearDown(self):
3238        backend_empty_cache (torch_device )
3339
3440    def  test_single_file_components (self ):
35-         _  =  self .model_class .from_pretrained (self .repo_id , subfolder = "transformer" )
41+         model  =  self .model_class .from_pretrained (self .repo_id , subfolder = "transformer" )
42+         model_single_file  =  self .model_class .from_single_file (self .ckpt_path )
43+ 
44+         PARAMS_TO_IGNORE  =  ["torch_dtype" , "_name_or_path" , "_use_default_values" , "_diffusers_version" ]
45+         for  param_name , param_value  in  model_single_file .config .items ():
46+             if  param_name  in  PARAMS_TO_IGNORE :
47+                 continue 
48+             assert  (
49+                 model .config [param_name ] ==  param_value 
50+             ), f"{ param_name }  
51+ 
52+     def  test_checkpoint_loading (self ):
53+         for  ckpt_path  in  self .alternate_keys_ckpt_paths :
54+             torch .cuda .empty_cache ()
55+             model  =  self .model_class .from_single_file (ckpt_path )
56+ 
57+             del  model 
58+             gc .collect ()
59+             torch .cuda .empty_cache ()