Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
ONNX_WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
_should_remap_transformers_class,
deprecate,
get_class_from_dynamic_module,
is_accelerate_available,
Expand Down Expand Up @@ -356,6 +357,11 @@ def maybe_raise_or_warn(
"""Simple helper method to raise or warn in case incorrect module has been passed"""
if not is_pipeline_module:
library = importlib.import_module(library_name)

# Handle deprecated Transformers classes
if library_name == "transformers":
class_name = _should_remap_transformers_class(class_name) or class_name

class_obj = getattr(library, class_name)
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}

Expand Down Expand Up @@ -390,6 +396,11 @@ def simple_get_class_obj(library_name, class_name):
class_obj = getattr(pipeline_module, class_name)
else:
library = importlib.import_module(library_name)

# Handle deprecated Transformers classes
if library_name == "transformers":
class_name = _should_remap_transformers_class(class_name) or class_name

class_obj = getattr(library, class_name)

return class_obj
Expand All @@ -416,6 +427,10 @@ def get_class_obj_and_candidates(
# else we just import it from the library.
library = importlib.import_module(library_name)

# Handle deprecated Transformers classes
if library_name == "transformers":
class_name = _should_remap_transformers_class(class_name) or class_name

class_obj = getattr(library, class_name)
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}

Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
)
from .deprecation_utils import deprecate
from .deprecation_utils import _should_remap_transformers_class, deprecate
from .doc_utils import replace_example_docstring
from .dynamic_modules_utils import get_class_from_dynamic_module
from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video
Expand Down
42 changes: 42 additions & 0 deletions src/diffusers/utils/deprecation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,48 @@
from packaging import version


# Mapping for deprecated Transformers classes to their replacements
# This is used to handle models that reference deprecated class names in their configs
# Reference: https://github.com/huggingface/transformers/issues/40822
# Format: {
# "DeprecatedClassName": {
# "new_class": "NewClassName",
# "transformers_version": (">=", "5.0.0"), # (operation, version) tuple
# }
# }
_TRANSFORMERS_CLASS_REMAPPING = {
"CLIPFeatureExtractor": {
"new_class": "CLIPImageProcessor",
"transformers_version": (">", "4.57.0"),
},
}


def _should_remap_transformers_class(class_name: str) -> Optional[str]:
"""
Check if a Transformers class should be remapped to a newer version.

Args:
class_name: The name of the class to check

Returns:
The new class name if remapping should occur, None otherwise
"""
if class_name not in _TRANSFORMERS_CLASS_REMAPPING:
return None

from .import_utils import is_transformers_version

mapping = _TRANSFORMERS_CLASS_REMAPPING[class_name]
operation, required_version = mapping["transformers_version"]

# Only remap if the transformers version meets the requirement
if is_transformers_version(operation, required_version):
return mapping["new_class"]

return None


def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2):
from .. import __version__

Expand Down
Loading