diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index dd542145d3fa..c4085f6f2055 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -33,6 +33,7 @@ ONNX_WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, + _maybe_remap_transformers_class, deprecate, get_class_from_dynamic_module, is_accelerate_available, @@ -356,6 +357,11 @@ def maybe_raise_or_warn( """Simple helper method to raise or warn in case incorrect module has been passed""" if not is_pipeline_module: library = importlib.import_module(library_name) + + # Handle deprecated Transformers classes + if library_name == "transformers": + class_name = _maybe_remap_transformers_class(class_name) or class_name + class_obj = getattr(library, class_name) class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} @@ -390,6 +396,11 @@ def simple_get_class_obj(library_name, class_name): class_obj = getattr(pipeline_module, class_name) else: library = importlib.import_module(library_name) + + # Handle deprecated Transformers classes + if library_name == "transformers": + class_name = _maybe_remap_transformers_class(class_name) or class_name + class_obj = getattr(library, class_name) return class_obj @@ -416,6 +427,10 @@ def get_class_obj_and_candidates( # else we just import it from the library. library = importlib.import_module(library_name) + # Handle deprecated Transformers classes + if library_name == "transformers": + class_name = _maybe_remap_transformers_class(class_name) or class_name + class_obj = getattr(library, class_name) class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 63932221b207..d8e1a5540100 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -38,7 +38,7 @@ WEIGHTS_INDEX_NAME, WEIGHTS_NAME, ) -from .deprecation_utils import deprecate +from .deprecation_utils import _maybe_remap_transformers_class, deprecate from .doc_utils import replace_example_docstring from .dynamic_modules_utils import get_class_from_dynamic_module from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video diff --git a/src/diffusers/utils/deprecation_utils.py b/src/diffusers/utils/deprecation_utils.py index 4f001b3047d6..d76623541b9f 100644 --- a/src/diffusers/utils/deprecation_utils.py +++ b/src/diffusers/utils/deprecation_utils.py @@ -4,6 +4,54 @@ from packaging import version +from ..utils import logging + + +logger = logging.get_logger(__name__) + +# Mapping for deprecated Transformers classes to their replacements +# This is used to handle models that reference deprecated class names in their configs +# Reference: https://github.com/huggingface/transformers/issues/40822 +# Format: { +# "DeprecatedClassName": { +# "new_class": "NewClassName", +# "transformers_version": (">=", "5.0.0"), # (operation, version) tuple +# } +# } +_TRANSFORMERS_CLASS_REMAPPING = { + "CLIPFeatureExtractor": { + "new_class": "CLIPImageProcessor", + "transformers_version": (">", "4.57.0"), + }, +} + + +def _maybe_remap_transformers_class(class_name: str) -> Optional[str]: + """ + Check if a Transformers class should be remapped to a newer version. + + Args: + class_name: The name of the class to check + + Returns: + The new class name if remapping should occur, None otherwise + """ + if class_name not in _TRANSFORMERS_CLASS_REMAPPING: + return None + + from .import_utils import is_transformers_version + + mapping = _TRANSFORMERS_CLASS_REMAPPING[class_name] + operation, required_version = mapping["transformers_version"] + + # Only remap if the transformers version meets the requirement + if is_transformers_version(operation, required_version): + new_class = mapping["new_class"] + logger.warning(f"{class_name} appears to have been deprecated in transformers. Using {new_class} instead.") + return mapping["new_class"] + + return None + def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2): from .. import __version__