| 
18 | 18 |     StableDiffusionPipeline,  | 
19 | 19 |     UNet2DConditionModel,  | 
20 | 20 | )  | 
21 |  | -from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible  | 
 | 21 | +from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings  | 
22 | 22 | from diffusers.utils.testing_utils import torch_device  | 
23 | 23 | 
 
  | 
24 | 24 | 
 
  | 
@@ -210,6 +210,135 @@ def test_diffusers_is_compatible_no_components_only_variants(self):  | 
210 | 210 |         self.assertFalse(is_safetensors_compatible(filenames))  | 
211 | 211 | 
 
  | 
212 | 212 | 
 
  | 
 | 213 | +class VariantCompatibleSiblingsTest(unittest.TestCase):  | 
 | 214 | +    def test_only_non_variants_downloaded(self):  | 
 | 215 | +        variant = "fp16"  | 
 | 216 | +        filenames = [  | 
 | 217 | +            f"vae/diffusion_pytorch_model.{variant}.safetensors",  | 
 | 218 | +            "vae/diffusion_pytorch_model.safetensors",  | 
 | 219 | +            f"text_encoder/model.{variant}.safetensors",  | 
 | 220 | +            "text_encoder/model.safetensors",  | 
 | 221 | +            f"unet/diffusion_pytorch_model.{variant}.safetensors",  | 
 | 222 | +            "unet/diffusion_pytorch_model.safetensors",  | 
 | 223 | +        ]  | 
 | 224 | + | 
 | 225 | +        model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)  | 
 | 226 | +        assert all(variant not in f for f in model_filenames)  | 
 | 227 | + | 
 | 228 | +    def test_only_variants_downloaded(self):  | 
 | 229 | +        variant = "fp16"  | 
 | 230 | +        filenames = [  | 
 | 231 | +            f"vae/diffusion_pytorch_model.{variant}.safetensors",  | 
 | 232 | +            "vae/diffusion_pytorch_model.safetensors",  | 
 | 233 | +            f"text_encoder/model.{variant}.safetensors",  | 
 | 234 | +            "text_encoder/model.safetensors",  | 
 | 235 | +            f"unet/diffusion_pytorch_model.{variant}.safetensors",  | 
 | 236 | +            "unet/diffusion_pytorch_model.safetensors",  | 
 | 237 | +        ]  | 
 | 238 | + | 
 | 239 | +        model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)  | 
 | 240 | +        assert all(variant in f for f in model_filenames)  | 
 | 241 | + | 
 | 242 | +    def test_mixed_variants_downloaded(self):  | 
 | 243 | +        variant = "fp16"  | 
 | 244 | +        non_variant_file = "text_encoder/model.safetensors"  | 
 | 245 | +        filenames = [  | 
 | 246 | +            f"vae/diffusion_pytorch_model.{variant}.safetensors",  | 
 | 247 | +            "vae/diffusion_pytorch_model.safetensors",  | 
 | 248 | +            "text_encoder/model.safetensors",  | 
 | 249 | +            f"unet/diffusion_pytorch_model.{variant}.safetensors",  | 
 | 250 | +            "unet/diffusion_pytorch_model.safetensors",  | 
 | 251 | +        ]  | 
 | 252 | +        model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)  | 
 | 253 | +        assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)  | 
 | 254 | + | 
 | 255 | +    def test_non_variants_in_main_dir_downloaded(self):  | 
 | 256 | +        variant = "fp16"  | 
 | 257 | +        filenames = [  | 
 | 258 | +            f"diffusion_pytorch_model.{variant}.safetensors",  | 
 | 259 | +            "diffusion_pytorch_model.safetensors",  | 
 | 260 | +            "model.safetensors",  | 
 | 261 | +            f"model.{variant}.safetensors",  | 
 | 262 | +            f"diffusion_pytorch_model.{variant}.safetensors",  | 
 | 263 | +            "diffusion_pytorch_model.safetensors",  | 
 | 264 | +        ]  | 
 | 265 | +        model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)  | 
 | 266 | +        assert all(variant not in f for f in model_filenames)  | 
 | 267 | + | 
 | 268 | +    def test_variants_in_main_dir_downloaded(self):  | 
 | 269 | +        variant = "fp16"  | 
 | 270 | +        filenames = [  | 
 | 271 | +            f"diffusion_pytorch_model.{variant}.safetensors",  | 
 | 272 | +            "diffusion_pytorch_model.safetensors",  | 
 | 273 | +            "model.safetensors",  | 
 | 274 | +            f"model.{variant}.safetensors",  | 
 | 275 | +            f"diffusion_pytorch_model.{variant}.safetensors",  | 
 | 276 | +            "diffusion_pytorch_model.safetensors",  | 
 | 277 | +        ]  | 
 | 278 | +        model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)  | 
 | 279 | +        assert all(variant in f for f in model_filenames)  | 
 | 280 | + | 
 | 281 | +    def test_mixed_variants_in_main_dir_downloaded(self):  | 
 | 282 | +        variant = "fp16"  | 
 | 283 | +        non_variant_file = "model.safetensors"  | 
 | 284 | +        filenames = [  | 
 | 285 | +            f"diffusion_pytorch_model.{variant}.safetensors",  | 
 | 286 | +            "diffusion_pytorch_model.safetensors",  | 
 | 287 | +            "model.safetensors",  | 
 | 288 | +            f"diffusion_pytorch_model.{variant}.safetensors",  | 
 | 289 | +            "diffusion_pytorch_model.safetensors",  | 
 | 290 | +        ]  | 
 | 291 | +        model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)  | 
 | 292 | +        assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)  | 
 | 293 | + | 
 | 294 | +    def test_sharded_non_variants_downloaded(self):  | 
 | 295 | +        variant = "fp16"  | 
 | 296 | +        filenames = [  | 
 | 297 | +            f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",  | 
 | 298 | +            "unet/diffusion_pytorch_model.safetensors.index.json",  | 
 | 299 | +            "unet/diffusion_pytorch_model-00001-of-00003.safetensors",  | 
 | 300 | +            "unet/diffusion_pytorch_model-00002-of-00003.safetensors",  | 
 | 301 | +            "unet/diffusion_pytorch_model-00003-of-00003.safetensors",  | 
 | 302 | +            f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",  | 
 | 303 | +            f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",  | 
 | 304 | +        ]  | 
 | 305 | +        model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None)  | 
 | 306 | +        assert all(variant not in f for f in model_filenames)  | 
 | 307 | + | 
 | 308 | +    def test_sharded_variants_downloaded(self):  | 
 | 309 | +        variant = "fp16"  | 
 | 310 | +        filenames = [  | 
 | 311 | +            f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",  | 
 | 312 | +            "unet/diffusion_pytorch_model.safetensors.index.json",  | 
 | 313 | +            "unet/diffusion_pytorch_model-00001-of-00003.safetensors",  | 
 | 314 | +            "unet/diffusion_pytorch_model-00002-of-00003.safetensors",  | 
 | 315 | +            "unet/diffusion_pytorch_model-00003-of-00003.safetensors",  | 
 | 316 | +            f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",  | 
 | 317 | +            f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",  | 
 | 318 | +        ]  | 
 | 319 | +        model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)  | 
 | 320 | +        assert all(variant in f for f in model_filenames)  | 
 | 321 | + | 
 | 322 | +    def test_sharded_mixed_variants_downloaded(self):  | 
 | 323 | +        variant = "fp16"  | 
 | 324 | +        allowed_non_variant = "unet"  | 
 | 325 | +        filenames = [  | 
 | 326 | +            f"vae/diffusion_pytorch_model.safetensors.index.{variant}.json",  | 
 | 327 | +            "vae/diffusion_pytorch_model.safetensors.index.json",  | 
 | 328 | +            "unet/diffusion_pytorch_model.safetensors.index.json",  | 
 | 329 | +            "unet/diffusion_pytorch_model-00001-of-00003.safetensors",  | 
 | 330 | +            "unet/diffusion_pytorch_model-00002-of-00003.safetensors",  | 
 | 331 | +            "unet/diffusion_pytorch_model-00003-of-00003.safetensors",  | 
 | 332 | +            f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",  | 
 | 333 | +            f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",  | 
 | 334 | +            "vae/diffusion_pytorch_model-00001-of-00003.safetensors",  | 
 | 335 | +            "vae/diffusion_pytorch_model-00002-of-00003.safetensors",  | 
 | 336 | +            "vae/diffusion_pytorch_model-00003-of-00003.safetensors",  | 
 | 337 | +        ]  | 
 | 338 | +        model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)  | 
 | 339 | +        assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)  | 
 | 340 | + | 
 | 341 | + | 
213 | 342 | class ProgressBarTests(unittest.TestCase):  | 
214 | 343 |     def get_dummy_components_image_generation(self):  | 
215 | 344 |         cross_attention_dim = 8  | 
 | 
0 commit comments