Skip to content

Commit 56ac8b4

Browse files
committed
import FlaxSchedulerMixin
1 parent 87f8f03 commit 56ac8b4

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
4848
from ..quantizers.bitsandbytes.utils import _check_bnb_status
4949
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin
50+
from ..schedulers.scheduling_utils_flax import FlaxSchedulerMixin
5051
from ..utils import (
5152
CONFIG_NAME,
5253
DEPRECATED_REVISION_ARGS,
@@ -848,7 +849,9 @@ def load_module(name, value):
848849
_expected_class_types.append(expected_type.__name__)
849850

850851
_is_valid_type = class_obj.__class__.__name__ in _expected_class_types
851-
if (isinstance(class_obj, SchedulerMixin) or isinstance(class_obj, FlaxSchedulerMixin)) and not _is_valid_type:
852+
if (
853+
isinstance(class_obj, SchedulerMixin) or isinstance(class_obj, FlaxSchedulerMixin)
854+
) and not _is_valid_type:
852855
_requires_flow_match = any("FlowMatch" in class_type for class_type in _expected_class_types)
853856
_is_flow_match = "FlowMatch" in class_obj.__class__.__name__
854857
if _requires_flow_match and not _is_flow_match:

0 commit comments

Comments
 (0)