- 
                Notifications
    You must be signed in to change notification settings 
- Fork 6.4k
[Core] fix variant-identification. #9253
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
6b379a9
              f155ec7
              3f36e59
              91253e8
              dd5941e
              564b8b4
              fdd0435
              d5cad9e
              c0b1ceb
              247dd93
              b024a6d
              fdfdc5f
              dcf1852
              3a71ad9
              ab91852
              aa631c5
              453bfa5
              11e4b71
              dbdf0f9
              671038a
              57382f2
              ea5ecdb
              a510a9b
              f583dad
              dc0255a
              f2ab3de
              10baa9d
              25ac01f
              bac62ac
              b6794ed
              fcb4e39
              4c0c5d2
              0b1c2a6
              8ad6b23
              1190f7d
              59cfefb
              d72f5c1
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|  | @@ -50,7 +50,6 @@ | |||
| DEPRECATED_REVISION_ARGS, | ||||
| BaseOutput, | ||||
| PushToHubMixin, | ||||
| deprecate, | ||||
| is_accelerate_available, | ||||
| is_accelerate_version, | ||||
| is_torch_npu_available, | ||||
|  | @@ -722,6 +721,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P | |||
| ) | ||||
| else: | ||||
| cached_folder = pretrained_model_name_or_path | ||||
| filenames = [] | ||||
|          | ||||
| for _, _, files in os.walk(cached_folder): | ||||
| for file in files: | ||||
| filenames.append(os.path.basename(file)) | ||||
|  | ||||
| model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) | ||||
| if len(variant_filenames) == 0 and variant is not None: | ||||
| error_message = ( | ||||
| f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." | ||||
| f" Available ones are: {model_filenames}." | ||||
| ) | ||||
| raise ValueError(error_message) | ||||
|  | ||||
| config_dict = cls.load_config(cached_folder) | ||||
|  | ||||
|  | @@ -1239,6 +1250,15 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: | |||
| model_info_call_error = e # save error to reraise it if model is not cached locally | ||||
|  | ||||
| if not local_files_only: | ||||
| filenames = {sibling.rfilename for sibling in info.siblings} | ||||
| model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) | ||||
| if len(variant_filenames) == 0 and variant is not None: | ||||
| error_message = ( | ||||
| f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." | ||||
| f" Available ones are: {model_filenames}." | ||||
| ) | ||||
| raise ValueError(error_message) | ||||
|  | ||||
| config_file = hf_hub_download( | ||||
| pretrained_model_name, | ||||
| cls.config_name, | ||||
|  | @@ -1255,9 +1275,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: | |||
| # retrieve all folder_names that contain relevant files | ||||
| folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"] | ||||
|  | ||||
| filenames = {sibling.rfilename for sibling in info.siblings} | ||||
| model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) | ||||
|  | ||||
| 
      Comment on lines
    
      -1270
     to 
      -1272
    
   There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This was moved up so that the error is raised earlier in the code. | ||||
| diffusers_module = importlib.import_module(__name__.split(".")[0]) | ||||
| pipelines = getattr(diffusers_module, "pipelines") | ||||
|  | ||||
|  | @@ -1279,15 +1296,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: | |||
| f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'." | ||||
| ) | ||||
|  | ||||
| if len(variant_filenames) == 0 and variant is not None: | ||||
| There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's not remove this error in [code reference missing]. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's not an error, though. It's a deprecation. Do we really want to keep it that way? If so, we will have to remove it anyway because the deprecation is supposed to expire after version "0.24.0". Instead, we are now erroring out from [code reference missing].
 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ah, got it. I think this should be resolved now. What do you think about catching these errors without having to download the actual files and leveraging [code reference missing]? This could live in a future PR.         
                  sayakpaul marked this conversation as resolved.
              Show resolved
            Hide resolved | ||||
| deprecation_message = ( | ||||
| f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." | ||||
| f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`" | ||||
| "if such variant modeling files are not available. Doing so will lead to an error in v0.24.0 as defaulting to non-variant" | ||||
| "modeling files is deprecated." | ||||
| ) | ||||
| deprecate("no variant default", "0.24.0", deprecation_message, standard_warn=False) | ||||
|  | ||||
|         
                  sayakpaul marked this conversation as resolved.
              Show resolved
            Hide resolved | ||||
| # remove ignored filenames | ||||
| model_filenames = set(model_filenames) - set(ignore_filenames) | ||||
| variant_filenames = set(variant_filenames) - set(ignore_filenames) | ||||
|  | @@ -1356,6 +1364,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: | |||
| ) | ||||
| expected_components, _ = cls._get_signature_keys(pipeline_class) | ||||
| passed_components = [k for k in expected_components if k in kwargs] | ||||
| is_sharded = any("index.json" in f and f != "model_index.json" for f in filenames) | ||||
|  | ||||
| if ( | ||||
| use_safetensors | ||||
|  | @@ -1380,9 +1389,13 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: | |||
|  | ||||
| safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")} | ||||
| safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")} | ||||
| # `not is_sharded` because sharded checkpoints with a variant | ||||
|         
                  sayakpaul marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||||
| # ("fp16"), for example, may actually have fewer shards. Consider | ||||
| # https://huggingface.co/fal/AuraFlow/tree/main/transformer, for example. | ||||
|         
                  sayakpaul marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||||
| if ( | ||||
| len(safetensors_variant_filenames) > 0 | ||||
| and safetensors_model_filenames != safetensors_variant_filenames | ||||
| and not is_sharded | ||||
| ): | ||||
| logger.warning( | ||||
| f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure." | ||||
|  | @@ -1431,6 +1444,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: | |||
|  | ||||
| # download all allow_patterns - ignore_patterns | ||||
| try: | ||||
| # print(f"{sorted(allow_patterns)=}") | ||||
|         
                  sayakpaul marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||||
| cached_folder = snapshot_download( | ||||
| pretrained_model_name, | ||||
| cache_dir=cache_dir, | ||||
|  | ||||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -551,6 +551,29 @@ def test_download_variant_partly(self): | |
| assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3 | ||
| assert not any(f.endswith(other_format) for f in files) | ||
|  | ||
| def test_download_variants_with_sharded_checkpoints(self): | ||
| There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let me know if anyone has a better idea for how to test this. | ||
| # Here we test for downloading of "variant" files belonging to the `unet` and | ||
| # the `text_encoder`. Their checkpoints can be sharded. | ||
| for use_safetensors in [True, False]: | ||
| for variant in ["fp16", None]: | ||
| with tempfile.TemporaryDirectory() as tmpdirname: | ||
| tmpdirname = DiffusionPipeline.download( | ||
| "hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds", | ||
| safety_checker=None, | ||
| cache_dir=tmpdirname, | ||
| variant=variant, | ||
| use_safetensors=use_safetensors, | ||
| ) | ||
|  | ||
| all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))] | ||
| files = [item for sublist in all_root_files for item in sublist] | ||
|  | ||
| # Check for `model_ext` and `variant`. | ||
| model_ext = ".safetensors" if use_safetensors else ".bin" | ||
| unexpected_ext = ".bin" if use_safetensors else ".safetensors" | ||
| assert not any(f.endswith(unexpected_ext) for f in files) | ||
| assert all(variant in f for f in files if f.endswith(model_ext) and variant is not None) | ||
|  | ||
| def test_download_safetensors_only_variant_exists_for_model(self): | ||
| variant = None | ||
| use_safetensors = True | ||
|  | @@ -655,7 +678,7 @@ def test_local_save_load_index(self): | |
| out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="np").images | ||
|  | ||
| with tempfile.TemporaryDirectory() as tmpdirname: | ||
| pipe.save_pretrained(tmpdirname) | ||
| pipe.save_pretrained(tmpdirname, variant=variant, safe_serialization=use_safe) | ||
| There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This should have been serialized with [code reference missing]. | ||
| pipe_2 = StableDiffusionPipeline.from_pretrained( | ||
| tmpdirname, safe_serialization=use_safe, variant=variant | ||
| ) | ||
|  | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -1859,6 +1859,74 @@ def callback_increase_guidance(pipe, i, t, callback_kwargs): | |
| # accounts for models that modify the number of inference steps based on strength | ||
| assert pipe.guidance_scale == (inputs["guidance_scale"] + pipe.num_timesteps) | ||
|  | ||
| def test_serialization_with_variants(self): | ||
| components = self.get_dummy_components() | ||
| pipe = self.pipeline_class(**components) | ||
| model_components = [ | ||
| component_name for component_name, component in pipe.components.items() if isinstance(component, nn.Module) | ||
| ] | ||
| variant = "fp16" | ||
|  | ||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False) | ||
|  | ||
| with open(f"{tmpdir}/model_index.json", "r") as f: | ||
| config = json.load(f) | ||
|  | ||
| for subfolder in os.listdir(tmpdir): | ||
| if not os.path.isfile(subfolder) and subfolder in model_components: | ||
| folder_path = os.path.join(tmpdir, subfolder) | ||
| is_folder = os.path.isdir(folder_path) and subfolder in config | ||
| assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)) | ||
|  | ||
| def test_loading_with_variants(self): | ||
| components = self.get_dummy_components() | ||
| pipe = self.pipeline_class(**components) | ||
| variant = "fp16" | ||
|  | ||
| def is_nan(tensor): | ||
| if tensor.ndimension() == 0: | ||
| has_nan = torch.isnan(tensor).item() | ||
| else: | ||
| has_nan = torch.isnan(tensor).any() | ||
| return has_nan | ||
|  | ||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False) | ||
| pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, variant=variant) | ||
|  | ||
| model_components_pipe = { | ||
| component_name: component | ||
| for component_name, component in pipe.components.items() | ||
| if isinstance(component, nn.Module) | ||
| } | ||
| model_components_pipe_loaded = { | ||
| component_name: component | ||
| for component_name, component in pipe_loaded.components.items() | ||
| if isinstance(component, nn.Module) | ||
| } | ||
| for component_name in model_components_pipe: | ||
| pipe_component = model_components_pipe[component_name] | ||
| pipe_loaded_component = model_components_pipe_loaded[component_name] | ||
| for p1, p2 in zip(pipe_component.parameters(), pipe_loaded_component.parameters()): | ||
| # nan check for luminanext (mps). | ||
| if not (is_nan(p1) and is_nan(p2)): | ||
| self.assertTrue(torch.equal(p1, p2)) | ||
|  | ||
| def test_loading_with_incorrect_variants_raises_error(self): | ||
| components = self.get_dummy_components() | ||
| pipe = self.pipeline_class(**components) | ||
| variant = "fp16" | ||
|  | ||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| # Don't save with variants. | ||
| pipe.save_pretrained(tmpdir, safe_serialization=False) | ||
|  | ||
| with self.assertRaises(ValueError) as error: | ||
| _ = self.pipeline_class.from_pretrained(tmpdir, variant=variant) | ||
| There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This would have failed with the fixes from this PR, rightfully complaining: [error output missing]. We didn't have it because we never tested it. But we should be all good now. | ||
|  | ||
| assert f"You are trying to load the model files of the `variant={variant}`" in str(error.exception) | ||
|  | ||
| def test_StableDiffusionMixin_component(self): | ||
| """Any pipeline that have LDMFuncMixin should have vae and unet components.""" | ||
| if not issubclass(self.pipeline_class, StableDiffusionMixin): | ||
|  | ||
Uh oh!
There was an error while loading. Please reload this page.