@@ -68,25 +68,21 @@ def test_all_is_compatible_variant(self):
6868            "unet/diffusion_pytorch_model.fp16.bin" ,
6969            "unet/diffusion_pytorch_model.fp16.safetensors" ,
7070        ]
71-         variant  =  "fp16" 
72-         self .assertTrue (is_safetensors_compatible (filenames , variant = variant ))
71+         self .assertTrue (is_safetensors_compatible (filenames ))
7372
7473    def  test_diffusers_model_is_compatible_variant (self ):
7574        filenames  =  [
7675            "unet/diffusion_pytorch_model.fp16.bin" ,
7776            "unet/diffusion_pytorch_model.fp16.safetensors" ,
7877        ]
79-         variant  =  "fp16" 
80-         self .assertTrue (is_safetensors_compatible (filenames , variant = variant ))
78+         self .assertTrue (is_safetensors_compatible (filenames ))
8179
82-     def  test_diffusers_model_is_compatible_variant_partial (self ):
83-         # pass variant but use the non-variant filenames 
80+     def  test_diffusers_model_is_compatible_variant_mixed (self ):
8481        filenames  =  [
8582            "unet/diffusion_pytorch_model.bin" ,
86-             "unet/diffusion_pytorch_model.safetensors" ,
83+             "unet/diffusion_pytorch_model.fp16. safetensors" ,
8784        ]
88-         variant  =  "fp16" 
89-         self .assertTrue (is_safetensors_compatible (filenames , variant = variant ))
85+         self .assertTrue (is_safetensors_compatible (filenames ))
9086
9187    def  test_diffusers_model_is_not_compatible_variant (self ):
9288        filenames  =  [
@@ -99,25 +95,14 @@ def test_diffusers_model_is_not_compatible_variant(self):
9995            "unet/diffusion_pytorch_model.fp16.bin" ,
10096            # Removed: 'unet/diffusion_pytorch_model.fp16.safetensors', 
10197        ]
102-         variant  =  "fp16" 
103-         self .assertFalse (is_safetensors_compatible (filenames , variant = variant ))
98+         self .assertFalse (is_safetensors_compatible (filenames ))
10499
105100    def  test_transformer_model_is_compatible_variant (self ):
106101        filenames  =  [
107102            "text_encoder/pytorch_model.fp16.bin" ,
108103            "text_encoder/model.fp16.safetensors" ,
109104        ]
110-         variant  =  "fp16" 
111-         self .assertTrue (is_safetensors_compatible (filenames , variant = variant ))
112- 
113-     def  test_transformer_model_is_compatible_variant_partial (self ):
114-         # pass variant but use the non-variant filenames 
115-         filenames  =  [
116-             "text_encoder/pytorch_model.bin" ,
117-             "text_encoder/model.safetensors" ,
118-         ]
119-         variant  =  "fp16" 
120-         self .assertTrue (is_safetensors_compatible (filenames , variant = variant ))
105+         self .assertTrue (is_safetensors_compatible (filenames ))
121106
122107    def  test_transformer_model_is_not_compatible_variant (self ):
123108        filenames  =  [
@@ -126,9 +111,45 @@ def test_transformer_model_is_not_compatible_variant(self):
126111            "vae/diffusion_pytorch_model.fp16.bin" ,
127112            "vae/diffusion_pytorch_model.fp16.safetensors" ,
128113            "text_encoder/pytorch_model.fp16.bin" ,
129-             # 'text_encoder/model.fp16.safetensors', 
130114            "unet/diffusion_pytorch_model.fp16.bin" ,
131115            "unet/diffusion_pytorch_model.fp16.safetensors" ,
132116        ]
133-         variant  =  "fp16" 
134-         self .assertFalse (is_safetensors_compatible (filenames , variant = variant ))
117+         self .assertFalse (is_safetensors_compatible (filenames ))
118+ 
119+     def  test_transformers_is_compatible_sharded (self ):
120+         filenames  =  [
121+             "text_encoder/pytorch_model.bin" ,
122+             "text_encoder/model-00001-of-00002.safetensors" ,
123+             "text_encoder/model-00002-of-00002.safetensors" ,
124+         ]
125+         self .assertTrue (is_safetensors_compatible (filenames ))
126+ 
127+     def  test_transformers_is_compatible_variant_sharded (self ):
128+         filenames  =  [
129+             "text_encoder/pytorch_model.bin" ,
130+             "text_encoder/model.fp16-00001-of-00002.safetensors" ,
131+             "text_encoder/model.fp16-00001-of-00002.safetensors" ,
132+         ]
133+         self .assertTrue (is_safetensors_compatible (filenames ))
134+ 
135+     def  test_diffusers_is_compatible_sharded (self ):
136+         filenames  =  [
137+             "unet/diffusion_pytorch_model.bin" ,
138+             "unet/diffusion_pytorch_model-00001-of-00002.safetensors" ,
139+             "unet/diffusion_pytorch_model-00002-of-00002.safetensors" ,
140+         ]
141+         self .assertTrue (is_safetensors_compatible (filenames ))
142+ 
143+     def  test_diffusers_is_compatible_variant_sharded (self ):
144+         filenames  =  [
145+             "unet/diffusion_pytorch_model.bin" ,
146+             "unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors" ,
147+             "unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors" ,
148+         ]
149+         self .assertTrue (is_safetensors_compatible (filenames ))
150+ 
151+     def  test_diffusers_is_compatible_only_variants (self ):
152+         filenames  =  [
153+             "unet/diffusion_pytorch_model.fp16.safetensors" ,
154+         ]
155+         self .assertTrue (is_safetensors_compatible (filenames ))
0 commit comments