-
Couldn't load subscription status.
- Fork 6.5k
Check correct model type is passed to from_pretrained
#10189
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 22 commits
2fd4508
78c6e68
185a78f
679c18c
6aad7a7
c1db3bd
50b740a
b8fa81a
44f24a4
5af8c7f
baea141
c5e1e2d
dba12b6
99b0f92
3a43c8a
803e33f
c81415b
13a824e
5679067
24d79a3
3f841d5
f18687f
87f8f03
56ac8b4
87dcf54
e193563
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,6 +13,7 @@ | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import enum | ||
| import fnmatch | ||
| import importlib | ||
| import inspect | ||
|
|
@@ -45,7 +46,7 @@ | |
| from ..models.attention_processor import FusedAttnProcessor2_0 | ||
| from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin | ||
| from ..quantizers.bitsandbytes.utils import _check_bnb_status | ||
| from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME | ||
| from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin | ||
| from ..utils import ( | ||
| CONFIG_NAME, | ||
| DEPRECATED_REVISION_ARGS, | ||
|
|
@@ -811,6 +812,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P | |
| # in this case they are already instantiated in `kwargs` | ||
| # extract them here | ||
| expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) | ||
| expected_types = pipeline_class._get_signature_types() | ||
| passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} | ||
| passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} | ||
| init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) | ||
|
|
@@ -833,6 +835,31 @@ def load_module(name, value): | |
|
|
||
| init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} | ||
|
|
||
| for key in init_dict.keys(): | ||
| if key not in passed_class_obj: | ||
| continue | ||
|
|
||
| class_obj = passed_class_obj[key] | ||
| _expected_class_types = [] | ||
| for expected_type in expected_types[key]: | ||
| if isinstance(expected_type, enum.EnumMeta): | ||
| _expected_class_types.extend(expected_type.__members__.keys()) | ||
| else: | ||
| _expected_class_types.append(expected_type.__name__) | ||
|
|
||
| _is_valid_type = class_obj.__class__.__name__ in _expected_class_types | ||
| if isinstance(class_obj, SchedulerMixin) and not _is_valid_type: | ||
| _requires_flow_match = any("FlowMatch" in class_type for class_type in _expected_class_types) | ||
|
||
| _is_flow_match = "FlowMatch" in class_obj.__class__.__name__ | ||
| if _requires_flow_match and not _is_flow_match: | ||
| logger.warning(f"Expected FlowMatch scheduler, got {class_obj.__class__.__name__}.") | ||
|
||
| elif not _requires_flow_match and _is_flow_match: | ||
| logger.warning(f"Expected non-FlowMatch scheduler, got {class_obj.__class__.__name__}.") | ||
| elif not _is_valid_type: | ||
| logger.warning( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think if it's not a scheduler and the types don't match it's okay to raise an error. I think it would break in the model loading step anyway in this case. wdyt @yiyixuxu? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I prefer a warning because:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we just added There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've removed scheduler related changes for now, I think we can revisit that later, as @yiyixuxu mentioned above type hints haven't been strictly enforced there are probably some missing/wrong, especially for schedulers. Warning is better because of that too, if there is some wrong type hint that makes its way into a release we'd have to issue a hotfix release to fix it, that just creates headaches and issue reports. |
||
| f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}." | ||
| ) | ||
|
|
||
| # Special case: safety_checker must be loaded separately when using `from_flax` | ||
| if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj: | ||
| raise NotImplementedError( | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -47,6 +47,8 @@ | |||||||||||
| DPMSolverMultistepScheduler, | ||||||||||||
| EulerAncestralDiscreteScheduler, | ||||||||||||
| EulerDiscreteScheduler, | ||||||||||||
| FlowMatchEulerDiscreteScheduler, | ||||||||||||
| FluxPipeline, | ||||||||||||
| LMSDiscreteScheduler, | ||||||||||||
| ModelMixin, | ||||||||||||
| PNDMScheduler, | ||||||||||||
|
|
@@ -1802,6 +1804,42 @@ def test_pipe_same_device_id_offload(self): | |||||||||||
| sd.maybe_free_model_hooks() | ||||||||||||
| assert sd._offload_gpu_id == 5 | ||||||||||||
|
|
||||||||||||
| def test_wrong_model(self): | ||||||||||||
| tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") | ||||||||||||
| with self.assertRaises(ValueError) as error_context: | ||||||||||||
| _ = StableDiffusionPipeline.from_pretrained( | ||||||||||||
| "hf-internal-testing/diffusers-stable-diffusion-tiny-all", text_encoder=tokenizer | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| assert "is of type" in str(error_context.exception) | ||||||||||||
| assert "but should be" in str(error_context.exception) | ||||||||||||
|
Comment on lines
+1806
to
+1813
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We're now using warning, but for this case, diffusers/src/diffusers/pipelines/pipeline_utils.py Lines 893 to 897 in 6324340
So it's a little inconsistent and needs further testing to determine which other cases this already applies to. |
||||||||||||
|
|
||||||||||||
| def test_wrong_model_scheduler_type(self): | ||||||||||||
| scheduler = EulerDiscreteScheduler.from_pretrained("hf-internal-testing/tiny-flux-pipe", subfolder="scheduler") | ||||||||||||
| with self.assertLogs( | ||||||||||||
| logging.get_logger("diffusers.pipelines.pipeline_utils"), level="WARNING" | ||||||||||||
| ) as warning_context: | ||||||||||||
| _ = FluxPipeline.from_pretrained("hf-internal-testing/tiny-flux-pipe", scheduler=scheduler) | ||||||||||||
|
|
||||||||||||
| assert any("Expected" in message for message in warning_context.output) | ||||||||||||
| assert any("scheduler" in message for message in warning_context.output) | ||||||||||||
| assert any("EulerDiscreteScheduler" in message for message in warning_context.output) | ||||||||||||
|
|
||||||||||||
| def test_wrong_model_scheduler_enum(self): | ||||||||||||
| scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( | ||||||||||||
| "hf-internal-testing/diffusers-stable-diffusion-tiny-all", subfolder="scheduler" | ||||||||||||
| ) | ||||||||||||
| with self.assertLogs( | ||||||||||||
| logging.get_logger("diffusers.pipelines.pipeline_utils"), level="WARNING" | ||||||||||||
| ) as warning_context: | ||||||||||||
| _ = StableDiffusionPipeline.from_pretrained( | ||||||||||||
| "hf-internal-testing/diffusers-stable-diffusion-tiny-all", scheduler=scheduler | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| assert any("Expected" in message for message in warning_context.output) | ||||||||||||
| assert any("scheduler" in message for message in warning_context.output) | ||||||||||||
| assert any("FlowMatchEulerDiscreteScheduler" in message for message in warning_context.output) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| @slow | ||||||||||||
| @require_torch_gpu | ||||||||||||
|
|
||||||||||||
Uh oh!
There was an error while loading. Please reload this page.