-
Notifications
You must be signed in to change notification settings - Fork 6.4k
[Core] fix variant-identification. #9253
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
6b379a9
f155ec7
3f36e59
91253e8
dd5941e
564b8b4
fdd0435
d5cad9e
c0b1ceb
247dd93
b024a6d
fdfdc5f
dcf1852
3a71ad9
ab91852
aa631c5
453bfa5
11e4b71
dbdf0f9
671038a
57382f2
ea5ecdb
a510a9b
f583dad
dc0255a
f2ab3de
10baa9d
25ac01f
bac62ac
b6794ed
fcb4e39
4c0c5d2
0b1c2a6
8ad6b23
1190f7d
59cfefb
d72f5c1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -50,7 +50,6 @@ | |||
| DEPRECATED_REVISION_ARGS, | ||||
| BaseOutput, | ||||
| PushToHubMixin, | ||||
| deprecate, | ||||
| is_accelerate_available, | ||||
| is_accelerate_version, | ||||
| is_torch_npu_available, | ||||
|
|
@@ -722,6 +721,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P | |||
| ) | ||||
| else: | ||||
| cached_folder = pretrained_model_name_or_path | ||||
| filenames = [] | ||||
|
||||
| for _, _, files in os.walk(cached_folder): | ||||
| for file in files: | ||||
| filenames.append(os.path.basename(file)) | ||||
|
|
||||
| model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) | ||||
| if len(variant_filenames) == 0 and variant is not None: | ||||
| error_message = ( | ||||
| f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." | ||||
| f" Available ones are: {model_filenames}." | ||||
| ) | ||||
| raise ValueError(error_message) | ||||
|
|
||||
| config_dict = cls.load_config(cached_folder) | ||||
|
|
||||
|
|
@@ -1239,6 +1250,15 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: | |||
| model_info_call_error = e # save error to reraise it if model is not cached locally | ||||
|
|
||||
| if not local_files_only: | ||||
| filenames = {sibling.rfilename for sibling in info.siblings} | ||||
| model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) | ||||
| if len(variant_filenames) == 0 and variant is not None: | ||||
| error_message = ( | ||||
| f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." | ||||
| f" Available ones are: {model_filenames}." | ||||
| ) | ||||
| raise ValueError(error_message) | ||||
|
|
||||
| config_file = hf_hub_download( | ||||
| pretrained_model_name, | ||||
| cls.config_name, | ||||
|
|
@@ -1255,9 +1275,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: | |||
| # retrieve all folder_names that contain relevant files | ||||
| folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"] | ||||
|
|
||||
| filenames = {sibling.rfilename for sibling in info.siblings} | ||||
| model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) | ||||
|
|
||||
|
Comment on lines
-1270
to
-1272
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was moved up to raise error earlier in code. |
||||
| diffusers_module = importlib.import_module(__name__.split(".")[0]) | ||||
| pipelines = getattr(diffusers_module, "pipelines") | ||||
|
|
||||
|
|
@@ -1279,15 +1296,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: | |||
| f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'." | ||||
| ) | ||||
|
|
||||
| if len(variant_filenames) == 0 and variant is not None: | ||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let's not remove this error in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not an error, though. It's a deprecation. Do we exactly want to keep it that way? If so, we will have to remove it anyway because the deprecation is supposed to expire after "0.24.0" version. Instead, we are erroring out now from
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah got it. I think this should be resolved now. WDYT about catching these errors without having to download the actual files and leveraging This could live in a future PR.
sayakpaul marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
| deprecation_message = ( | ||||
| f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." | ||||
| f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`" | ||||
| "if such variant modeling files are not available. Doing so will lead to an error in v0.24.0 as defaulting to non-variant" | ||||
| "modeling files is deprecated." | ||||
| ) | ||||
| deprecate("no variant default", "0.24.0", deprecation_message, standard_warn=False) | ||||
|
|
||||
sayakpaul marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
| # remove ignored filenames | ||||
| model_filenames = set(model_filenames) - set(ignore_filenames) | ||||
| variant_filenames = set(variant_filenames) - set(ignore_filenames) | ||||
|
|
@@ -1356,6 +1364,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: | |||
| ) | ||||
| expected_components, _ = cls._get_signature_keys(pipeline_class) | ||||
| passed_components = [k for k in expected_components if k in kwargs] | ||||
| is_sharded = any("index.json" in f and f != "model_index.json" for f in filenames) | ||||
|
|
||||
| if ( | ||||
| use_safetensors | ||||
|
|
@@ -1380,9 +1389,13 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: | |||
|
|
||||
| safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")} | ||||
| safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")} | ||||
| # `not is_sharded` because sharded checkpoints with a variant | ||||
sayakpaul marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||
| # ("fp16") for example may have lesser shards actually. Consider | ||||
| # https://huggingface.co/fal/AuraFlow/tree/main/transformer, for example. | ||||
sayakpaul marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||
| if ( | ||||
| len(safetensors_variant_filenames) > 0 | ||||
| and safetensors_model_filenames != safetensors_variant_filenames | ||||
| and not is_sharded | ||||
| ): | ||||
| logger.warning( | ||||
| f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure." | ||||
|
|
@@ -1431,6 +1444,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: | |||
|
|
||||
| # download all allow_patterns - ignore_patterns | ||||
| try: | ||||
| # print(f"{sorted(allow_patterns)=}") | ||||
sayakpaul marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||
| cached_folder = snapshot_download( | ||||
| pretrained_model_name, | ||||
| cache_dir=cache_dir, | ||||
|
|
||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -551,6 +551,29 @@ def test_download_variant_partly(self): | |
| assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3 | ||
| assert not any(f.endswith(other_format) for f in files) | ||
|
|
||
| def test_download_variants_with_sharded_checkpoints(self): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. LMK if someone has a better idea to test it out. |
||
| # Here we test for downloading of "variant" files belonging to the `unet` and | ||
| # the `text_encoder`. Their checkpoints can be sharded. | ||
| for use_safetensors in [True, False]: | ||
| for variant in ["fp16", None]: | ||
| with tempfile.TemporaryDirectory() as tmpdirname: | ||
| tmpdirname = DiffusionPipeline.download( | ||
| "hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds", | ||
| safety_checker=None, | ||
| cache_dir=tmpdirname, | ||
| variant=variant, | ||
| use_safetensors=use_safetensors, | ||
| ) | ||
|
|
||
| all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))] | ||
| files = [item for sublist in all_root_files for item in sublist] | ||
|
|
||
| # Check for `model_ext` and `variant`. | ||
| model_ext = ".safetensors" if use_safetensors else ".bin" | ||
| unexpected_ext = ".bin" if use_safetensors else ".safetensors" | ||
| assert not any(f.endswith(unexpected_ext) for f in files) | ||
| assert all(variant in f for f in files if f.endswith(model_ext) and variant is not None) | ||
|
|
||
| def test_download_safetensors_only_variant_exists_for_model(self): | ||
| variant = None | ||
| use_safetensors = True | ||
|
|
@@ -655,7 +678,7 @@ def test_local_save_load_index(self): | |
| out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="np").images | ||
|
|
||
| with tempfile.TemporaryDirectory() as tmpdirname: | ||
| pipe.save_pretrained(tmpdirname) | ||
| pipe.save_pretrained(tmpdirname, variant=variant, safe_serialization=use_safe) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should have been serialized with |
||
| pipe_2 = StableDiffusionPipeline.from_pretrained( | ||
| tmpdirname, safe_serialization=use_safe, variant=variant | ||
| ) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1859,6 +1859,74 @@ def callback_increase_guidance(pipe, i, t, callback_kwargs): | |
| # accounts for models that modify the number of inference steps based on strength | ||
| assert pipe.guidance_scale == (inputs["guidance_scale"] + pipe.num_timesteps) | ||
|
|
||
| def test_serialization_with_variants(self): | ||
| components = self.get_dummy_components() | ||
| pipe = self.pipeline_class(**components) | ||
| model_components = [ | ||
| component_name for component_name, component in pipe.components.items() if isinstance(component, nn.Module) | ||
| ] | ||
| variant = "fp16" | ||
|
|
||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False) | ||
|
|
||
| with open(f"{tmpdir}/model_index.json", "r") as f: | ||
| config = json.load(f) | ||
|
|
||
| for subfolder in os.listdir(tmpdir): | ||
| if not os.path.isfile(subfolder) and subfolder in model_components: | ||
| folder_path = os.path.join(tmpdir, subfolder) | ||
| is_folder = os.path.isdir(folder_path) and subfolder in config | ||
| assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)) | ||
|
|
||
| def test_loading_with_variants(self): | ||
| components = self.get_dummy_components() | ||
| pipe = self.pipeline_class(**components) | ||
| variant = "fp16" | ||
|
|
||
| def is_nan(tensor): | ||
| if tensor.ndimension() == 0: | ||
| has_nan = torch.isnan(tensor).item() | ||
| else: | ||
| has_nan = torch.isnan(tensor).any() | ||
| return has_nan | ||
|
|
||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False) | ||
| pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, variant=variant) | ||
|
|
||
| model_components_pipe = { | ||
| component_name: component | ||
| for component_name, component in pipe.components.items() | ||
| if isinstance(component, nn.Module) | ||
| } | ||
| model_components_pipe_loaded = { | ||
| component_name: component | ||
| for component_name, component in pipe_loaded.components.items() | ||
| if isinstance(component, nn.Module) | ||
| } | ||
| for component_name in model_components_pipe: | ||
| pipe_component = model_components_pipe[component_name] | ||
| pipe_loaded_component = model_components_pipe_loaded[component_name] | ||
| for p1, p2 in zip(pipe_component.parameters(), pipe_loaded_component.parameters()): | ||
| # nan check for luminanext (mps). | ||
| if not (is_nan(p1) and is_nan(p2)): | ||
| self.assertTrue(torch.equal(p1, p2)) | ||
|
|
||
| def test_loading_with_incorrect_variants_raises_error(self): | ||
| components = self.get_dummy_components() | ||
| pipe = self.pipeline_class(**components) | ||
| variant = "fp16" | ||
|
|
||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| # Don't save with variants. | ||
| pipe.save_pretrained(tmpdir, safe_serialization=False) | ||
|
|
||
| with self.assertRaises(ValueError) as error: | ||
| _ = self.pipeline_class.from_pretrained(tmpdir, variant=variant) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This would have failed with the fixes from this PR rightfully complaining: We didn't have it because we never tested it. But we should be all good now. |
||
|
|
||
| assert f"You are trying to load the model files of the `variant={variant}`" in str(error.exception) | ||
|
|
||
| def test_StableDiffusionMixin_component(self): | ||
| """Any pipeline that have LDMFuncMixin should have vae and unet components.""" | ||
| if not issubclass(self.pipeline_class, StableDiffusionMixin): | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.