- 
                Notifications
    You must be signed in to change notification settings 
- Fork 6.4k
[Core] fix variant-identification. #9253
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
6b379a9
              f155ec7
              3f36e59
              91253e8
              dd5941e
              564b8b4
              fdd0435
              d5cad9e
              c0b1ceb
              247dd93
              b024a6d
              fdfdc5f
              dcf1852
              3a71ad9
              ab91852
              aa631c5
              453bfa5
              11e4b71
              dbdf0f9
              671038a
              57382f2
              ea5ecdb
              a510a9b
              f583dad
              dc0255a
              f2ab3de
              10baa9d
              25ac01f
              bac62ac
              b6794ed
              fcb4e39
              4c0c5d2
              0b1c2a6
              8ad6b23
              1190f7d
              59cfefb
              d72f5c1
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|  | @@ -50,7 +50,6 @@ | |||
| DEPRECATED_REVISION_ARGS, | ||||
| BaseOutput, | ||||
| PushToHubMixin, | ||||
| deprecate, | ||||
| is_accelerate_available, | ||||
| is_accelerate_version, | ||||
| is_torch_npu_available, | ||||
|  | @@ -722,6 +721,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P | |||
| ) | ||||
| else: | ||||
| cached_folder = pretrained_model_name_or_path | ||||
| filenames = [] | ||||
| for _, _, files in os.walk(cached_folder): | ||||
| for file in files: | ||||
| filenames.append(os.path.basename(file)) | ||||
|  | ||||
| model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) | ||||
| if len(variant_filenames) == 0 and variant is not None: | ||||
| error_message = ( | ||||
| f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." | ||||
| f"Available ones are: {model_filenames}." | ||||
| ) | ||||
| raise ValueError(error_message) | ||||
|  | ||||
| config_dict = cls.load_config(cached_folder) | ||||
|  | ||||
|  | @@ -1239,6 +1250,15 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: | |||
| model_info_call_error = e # save error to reraise it if model is not cached locally | ||||
|  | ||||
| if not local_files_only: | ||||
| filenames = {sibling.rfilename for sibling in info.siblings} | ||||
| model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) | ||||
| if len(variant_filenames) == 0 and variant is not None: | ||||
| error_message = ( | ||||
| f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." | ||||
| f"Available ones are: {model_filenames}." | ||||
| ) | ||||
| raise ValueError(error_message) | ||||
|  | ||||
| config_file = hf_hub_download( | ||||
| pretrained_model_name, | ||||
| cls.config_name, | ||||
|  | @@ -1255,9 +1275,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: | |||
| # retrieve all folder_names that contain relevant files | ||||
| folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"] | ||||
|  | ||||
| filenames = {sibling.rfilename for sibling in info.siblings} | ||||
| model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) | ||||
|  | ||||
| 
      Comment on lines
    
      -1270
     to 
      -1272
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was moved up to raise error earlier in code. | ||||
| diffusers_module = importlib.import_module(__name__.split(".")[0]) | ||||
| pipelines = getattr(diffusers_module, "pipelines") | ||||
|  | ||||
|  | @@ -1279,15 +1296,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: | |||
| f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'." | ||||
| ) | ||||
|  | ||||
| if len(variant_filenames) == 0 and variant is not None: | ||||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let's not remove this error in  There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not an error, though. It's a deprecation. Do we exactly want to keep it that way? If so, we will have to remove it anyway because the deprecation is supposed to expire after "0.24.0" version. Instead, we are erroring out now from  
 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah got it. I think this should be resolved now. WDYT about catching these errors without having to download the actual files and leveraging  This could live in a future PR.         
                  sayakpaul marked this conversation as resolved.
              Show resolved
            Hide resolved | ||||
| deprecation_message = ( | ||||
| f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." | ||||
| f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`" | ||||
| "if such variant modeling files are not available. Doing so will lead to an error in v0.24.0 as defaulting to non-variant" | ||||
| "modeling files is deprecated." | ||||
| ) | ||||
| deprecate("no variant default", "0.24.0", deprecation_message, standard_warn=False) | ||||
|  | ||||
|         
                  sayakpaul marked this conversation as resolved.
              Show resolved
            Hide resolved | ||||
| # remove ignored filenames | ||||
| model_filenames = set(model_filenames) - set(ignore_filenames) | ||||
| variant_filenames = set(variant_filenames) - set(ignore_filenames) | ||||
|  | ||||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -1859,6 +1859,65 @@ def callback_increase_guidance(pipe, i, t, callback_kwargs): | |
| # accounts for models that modify the number of inference steps based on strength | ||
| assert pipe.guidance_scale == (inputs["guidance_scale"] + pipe.num_timesteps) | ||
|  | ||
| def test_serialization_with_variants(self): | ||
| components = self.get_dummy_components() | ||
| pipe = self.pipeline_class(**components) | ||
| model_components = [ | ||
| component_name for component_name, component in pipe.components.items() if isinstance(component, nn.Module) | ||
| ] | ||
| variant = "fp16" | ||
|  | ||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| pipe.save_pretrained(tmpdir, variant=variant) | ||
|  | ||
| with open(f"{tmpdir}/model_index.json", "r") as f: | ||
| config = json.load(f) | ||
|  | ||
| for subfolder in os.listdir(tmpdir): | ||
| if not os.path.isfile(subfolder) and subfolder in model_components: | ||
| folder_path = os.path.join(tmpdir, subfolder) | ||
| is_folder = os.path.isdir(folder_path) and subfolder in config | ||
| assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)) | ||
|  | ||
| def test_loading_with_variants(self): | ||
| components = self.get_dummy_components() | ||
| pipe = self.pipeline_class(**components) | ||
| variant = "fp16" | ||
|  | ||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| pipe.save_pretrained(tmpdir, variant=variant) | ||
| pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, variant=variant) | ||
|  | ||
| model_components_pipe = { | ||
| component_name: component | ||
| for component_name, component in pipe.components.items() | ||
| if isinstance(component, nn.Module) | ||
| } | ||
| model_components_pipe_loaded = { | ||
| component_name: component | ||
| for component_name, component in pipe_loaded.components.items() | ||
| if isinstance(component, nn.Module) | ||
| } | ||
| for component_name in model_components_pipe: | ||
| pipe_component = model_components_pipe[component_name] | ||
| pipe_loaded_component = model_components_pipe_loaded[component_name] | ||
| for p1, p2 in zip(pipe_component.parameters(), pipe_loaded_component.parameters()): | ||
| self.assertTrue(torch.equal(p1, p2)) | ||
|  | ||
| def test_loading_with_incorrect_variants_raises_error(self): | ||
| components = self.get_dummy_components() | ||
| pipe = self.pipeline_class(**components) | ||
| variant = "fp16" | ||
|  | ||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| # Don't save with variants. | ||
| pipe.save_pretrained(tmpdir) | ||
|  | ||
| with self.assertRaises(ValueError) as error: | ||
| _ = self.pipeline_class.from_pretrained(tmpdir, variant=variant) | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This would have failed with the fixes from this PR rightfully complaining: We didn't have it because we never tested it. But we should be all good now. | ||
|  | ||
| assert f"You are trying to load the model files of the `variant={variant}`" in str(error.exception) | ||
|  | ||
| def test_StableDiffusionMixin_component(self): | ||
| """Any pipeline that have LDMFuncMixin should have vae and unet components.""" | ||
| if not issubclass(self.pipeline_class, StableDiffusionMixin): | ||
|  | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think maybe we shoud just update the
_identify_model_variantsfunction usingvariant_compatible_siblingsand it is still not able to load variants with shared checkpoints from pipeline level
i.e. we should be able to load the fp16 variant in the transformer folder too but it is currently not
you get
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @DN6 @a-r-r-o-w here too