Merged
Changes from 7 commits
18 changes: 18 additions & 0 deletions src/diffusers/pipelines/pipeline_utils.py
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import fnmatch
import importlib
import inspect
@@ -811,6 +812,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# in this case they are already instantiated in `kwargs`
# extract them here
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
expected_types = pipeline_class._get_signature_types()
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
@@ -832,6 +834,22 @@ def load_module(name, value):
return True

init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
        scheduler_types = expected_types["scheduler"][0]
        if isinstance(scheduler_types, enum.EnumType):
            scheduler_types = list(scheduler_types)
        else:
            scheduler_types = [str(scheduler_types)]
        scheduler_types = [str(scheduler).split(".")[-1].strip("'>") for scheduler in scheduler_types]

        for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()):
Collaborator

@DN6 Dec 12, 2024

Since we already have the types extracted in expected_types, can't we fetch them using the key and then check whether the passed object is an instance of the type? If the expected type is an enum, we can check whether the passed object's class name exists in the enum's keys.

Collaborator

@DN6 Dec 13, 2024

I think it might be better to make this check more agnostic to the component names.

We have a few pipelines with Union types on non-scheduler components (mostly AnimateDiff). So this snippet would fail even though it's valid, because init_dict is based on the model_index.json which doesn't support multiple types.

from diffusers import (
	AnimateDiffPipeline,
	UNetMotionModel,
)

unet = UNetMotionModel()
pipe = AnimateDiffPipeline.from_pretrained(
	"hf-internal-testing/tiny-sd-pipe", unet=unet
)

Enforcing scheduler types might be a breaking change (cc: @yiyixuxu). For example, using DDIM with Kandinsky is currently valid, but with this change any downstream code doing this would break. It would be good to enforce it on the pipelines with Flow based schedulers, though (perhaps via a new Enum)?

I would try something like:

        for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()):
            if key not in passed_class_obj:
                continue

            class_obj = passed_class_obj[key]
            _expected_class_types = []
            for expected_type in expected_types[key]:
                if isinstance(expected_type, enum.EnumMeta):
                    _expected_class_types.extend(expected_type.__members__.keys())
                else:
                    _expected_class_types.append(expected_type.__name__)

            _is_valid_type = class_obj.__class__.__name__ in _expected_class_types
            if isinstance(class_obj, SchedulerMixin) and not _is_valid_type:
                # Handle case where scheduler is still valid
                # raise if scheduler is meant to be a Flow based scheduler?
                pass
            elif not _is_valid_type:
                raise ValueError(f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}.")
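
For readers who want to try the idea outside of `from_pretrained`, here is a minimal, self-contained sketch of the name-based check suggested above. The classes, the `DemoSchedulers` enum, and both helper functions are hypothetical stand-ins, not diffusers APIs; only the enum-or-class flattening logic is taken from the snippet.

```python
import enum


class DemoSchedulers(enum.Enum):
    # Hypothetical stand-in for an enum whose member names are scheduler class names.
    DDIMScheduler = 1
    PNDMScheduler = 2


class DDIMScheduler:  # hypothetical stand-in component classes
    pass


class DemoTextEncoder:
    pass


class DemoTokenizer:
    pass


def expected_class_names(expected_types):
    """Flatten a tuple of expected types (classes or enums) into class names."""
    names = []
    for expected_type in expected_types:
        if isinstance(expected_type, enum.EnumMeta):
            # An enum contributes one accepted name per member.
            names.extend(expected_type.__members__.keys())
        else:
            names.append(expected_type.__name__)
    return names


def check_passed_component(key, obj, expected_types):
    """Raise ValueError if the object's class name is not among the expected names."""
    names = expected_class_names(expected_types)
    class_name = obj.__class__.__name__
    if class_name not in names:
        raise ValueError(f"Expected types for {key}: {names}, got {class_name}.")


# A scheduler listed in the enum passes silently.
check_passed_component("scheduler", DDIMScheduler(), (DemoSchedulers,))
```

Because the comparison is by class name, a subclass with a different name would be rejected; that trade-off is part of what the thread is weighing.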

Contributor Author

@hlky Dec 13, 2024

Added this for scheduler

                _requires_flow_match = any("FlowMatch" in class_type for class_type in _expected_class_types)
                _is_flow_match = "FlowMatch" in class_obj.__class__.__name__
                if _requires_flow_match and not _is_flow_match:
                    raise ValueError(f"Expected FlowMatch scheduler, got {class_obj.__class__.__name__}.")
                elif not _requires_flow_match and _is_flow_match:
                    raise ValueError(f"Expected non-FlowMatch scheduler, got {class_obj.__class__.__name__}.")
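
The same FlowMatch rule can be exercised in isolation. The sketch below wraps it in a hypothetical helper, `check_flow_match`, that works purely on class-name strings; the function name and the way it is called are assumptions for illustration.

```python
def check_flow_match(class_name, expected_class_names):
    """Raise ValueError when a FlowMatch scheduler is required but not given, or vice versa."""
    requires_flow_match = any("FlowMatch" in name for name in expected_class_names)
    is_flow_match = "FlowMatch" in class_name
    if requires_flow_match and not is_flow_match:
        raise ValueError(f"Expected FlowMatch scheduler, got {class_name}.")
    if not requires_flow_match and is_flow_match:
        raise ValueError(f"Expected non-FlowMatch scheduler, got {class_name}.")


# Matching families pass silently, even if the exact scheduler differs.
check_flow_match("FlowMatchEulerDiscreteScheduler", ["FlowMatchEulerDiscreteScheduler"])
check_flow_match("DDIMScheduler", ["DDPMScheduler", "UnCLIPScheduler"])
```

Note that this only separates the two scheduler families by a substring match, so swapping DDIM for DDPM still passes, which keeps the behaviour described above from becoming a breaking change.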

Collaborator

I think we don't need a value error here, a warning is enough, no?

Contributor Author

A warning should be sufficient, it's mainly for the situation here #10093 (comment) where the wrong text encoder is given because the resulting error is uninformative.

Collaborator

let's do a warning then:)
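
A rough sketch of what the warning variant could look like, assuming the standard-library `warnings` module (the actual change may well use the library's logger instead). The helper name `warn_on_unexpected_component` and the stand-in class are hypothetical.

```python
import warnings


class DemoTokenizer:  # hypothetical stand-in for a wrongly passed component
    pass


def warn_on_unexpected_component(key, obj, expected_class_names):
    """Return True if the class name is expected; otherwise warn and return False."""
    class_name = type(obj).__name__
    if class_name in expected_class_names:
        return True
    warnings.warn(
        f"Expected types for {key}: {expected_class_names}, got {class_name}.",
        UserWarning,
        stacklevel=2,
    )
    return False


# Loading would continue, but the user sees which component looks wrong.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ok = warn_on_unexpected_component("text_encoder", DemoTokenizer(), ["CLIPTextModel"])
```

The message names both the component key and the class that was passed, which is the information missing from the uninformative error mentioned above.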

Member

Just chiming in here to share a perspective as a user (not a strong opinion). Related to #10189 (comment).

Here

https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_loading_utils.py#L287-L290

if an unexpected module is passed, we raise a value error. I think this check is along similar lines: users are passing components that are unexpected or incompatible. We probably cannot predict the consequences of allowing the loading to proceed without raising an error, but if we do raise one, users will know what to do to fix the incorrect behaviour.

            if key not in passed_class_obj:
                continue
            class_name = passed_class_obj[key].__class__.__name__
            class_name = class_name[4:] if class_name.startswith("Flax") else class_name
            if key == "scheduler" and class_name not in scheduler_types:
                raise ValueError(f"Expected {scheduler_types} for {key}, got {class_name}.")
            elif key != "scheduler" and class_name != expected_class_name:
                raise ValueError(f"Expected {expected_class_name} for {key}, got {class_name}.")

# Special case: safety_checker must be loaded separately when using `from_flax`
if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
11 changes: 11 additions & 0 deletions tests/pipelines/test_pipelines.py
@@ -1802,6 +1802,17 @@ def test_pipe_same_device_id_offload(self):
        sd.maybe_free_model_hooks()
        assert sd._offload_gpu_id == 5

    def test_wrong_model(self):
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        with self.assertRaises(ValueError) as error_context:
            _ = StableDiffusionPipeline.from_pretrained(
                "hf-internal-testing/diffusers-stable-diffusion-tiny-all", text_encoder=tokenizer
            )

        assert "Expected" in str(error_context.exception)
        assert "text_encoder" in str(error_context.exception)
        assert f"{tokenizer.__class__.__name__}" in str(error_context.exception)
Member

Maybe also a check for the scheduler as that is handled slightly differently?

Contributor Author

Will add it. For context this is what we're handling:

class KarrasDiffusionSchedulers(Enum):
    DDIMScheduler = 1
    DDPMScheduler = 2
    PNDMScheduler = 3
    LMSDiscreteScheduler = 4
    EulerDiscreteScheduler = 5
    HeunDiscreteScheduler = 6
    EulerAncestralDiscreteScheduler = 7
    DPMSolverMultistepScheduler = 8
    DPMSolverSinglestepScheduler = 9
    KDPM2DiscreteScheduler = 10
    KDPM2AncestralDiscreteScheduler = 11
    DEISMultistepScheduler = 12
    UniPCMultistepScheduler = 13
    DPMSolverSDEScheduler = 14
    EDMEulerScheduler = 15

Member

That's cool. But I don't see the flow matching schedulers here. So, if I do assign a text encoder to scheduler in an RF pipeline (FluxPipeline, for example), would it still work as expected?

Contributor Author

Yes that also works, for pipelines like Flux we're getting the type

<class 'diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler'>

For SD etc we get the enum

<enum 'KarrasDiffusionSchedulers'>
[<KarrasDiffusionSchedulers.DDIMScheduler: 1>, <KarrasDiffusionSchedulers.DDPMScheduler: 2>, <KarrasDiffusionSchedulers.PNDMScheduler: 3>, <KarrasDiffusionSchedulers.LMSDiscreteScheduler: 4>, <KarrasDiffusionSchedulers.EulerDiscreteScheduler: 5>, <KarrasDiffusionSchedulers.HeunDiscreteScheduler: 6>, <KarrasDiffusionSchedulers.EulerAncestralDiscreteScheduler: 7>, <KarrasDiffusionSchedulers.DPMSolverMultistepScheduler: 8>, <KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler: 9>, <KarrasDiffusionSchedulers.KDPM2DiscreteScheduler: 10>, <KarrasDiffusionSchedulers.KDPM2AncestralDiscreteScheduler: 11>, <KarrasDiffusionSchedulers.DEISMultistepScheduler: 12>, <KarrasDiffusionSchedulers.UniPCMultistepScheduler: 13>, <KarrasDiffusionSchedulers.DPMSolverSDEScheduler: 14>, <KarrasDiffusionSchedulers.EDMEulerScheduler: 15>]

So we apply the same processing to both (str, then split; the strip only matters in the type case) to get a list of scheduler names

['FlowMatchEulerDiscreteScheduler']
['DDIMScheduler', 'DDPMScheduler', 'PNDMScheduler', 'LMSDiscreteScheduler', 'EulerDiscreteScheduler', 'HeunDiscreteScheduler', 'EulerAncestralDiscreteScheduler', 'DPMSolverMultistepScheduler', 'DPMSolverSinglestepScheduler', 'KDPM2DiscreteScheduler', 'KDPM2AncestralDiscreteScheduler', 'DEISMultistepScheduler', 'UniPCMultistepScheduler', 'DPMSolverSDEScheduler', 'EDMEulerScheduler']

It will raise if the passed object is not a scheduler, or if it's the wrong type of scheduler.
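
To make the str/split/strip step concrete, here is a small sketch with stand-in types. `DemoSchedulers`, `FlowMatchDemoScheduler`, its module path, and the `scheduler_names` helper are all hypothetical; `enum.EnumMeta` is used for the enum check because that name is available on older Python versions as well.

```python
import enum


class DemoSchedulers(enum.Enum):
    # Hypothetical stand-in for an enum of accepted scheduler names.
    DDIMScheduler = 1
    DDPMScheduler = 2


class FlowMatchDemoScheduler:
    # Pretend this class lives in a package module so its repr looks like
    # "<class 'demo_pkg.scheduling_flow_match.FlowMatchDemoScheduler'>".
    __module__ = "demo_pkg.scheduling_flow_match"


def scheduler_names(expected):
    """Turn an enum of schedulers, or a single scheduler class, into a list of names."""
    if isinstance(expected, enum.EnumMeta):
        candidates = list(expected)  # enum members, e.g. DemoSchedulers.DDIMScheduler
    else:
        candidates = [expected]      # a plain class
    # str(member) is "DemoSchedulers.DDIMScheduler" and str(cls) is "<class 'pkg.mod.Name'>".
    # Taking the text after the last dot and stripping "'>" leaves the bare name in both cases.
    return [str(candidate).split(".")[-1].strip("'>") for candidate in candidates]


enum_names = scheduler_names(DemoSchedulers)
class_names = scheduler_names(FlowMatchDemoScheduler)
```

The string approach relies on the repr containing a dotted module path; reading `__name__` on classes and `__members__` on enums would give the same names without depending on the repr format.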

Member

Thanks for explaining! Works for me.

Contributor Author

@hlky Dec 11, 2024

We now also support Union. The context is the failing test test_load_connected_checkpoint_with_passed_obj for KandinskyV22CombinedPipeline; we also change the scheduler type to Union[DDPMScheduler, UnCLIPScheduler]. The test is actually about passing objects to submodels, but changing the scheduler is how that test works.
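
As an illustration of what supporting Union involves, the sketch below expands a `Union[...]` annotation into its member types with the standard `typing` helpers. The two scheduler classes and the `expand_annotation` helper are hypothetical, and this is not necessarily how `_get_signature_types` does it.

```python
import typing
from typing import Union


class DDPMDemoScheduler:  # hypothetical stand-in classes
    pass


class UnCLIPDemoScheduler:
    pass


def expand_annotation(annotation):
    """Return the accepted types for an annotation as a tuple."""
    if typing.get_origin(annotation) is Union:
        # Union[A, B] -> (A, B)
        return typing.get_args(annotation)
    # A plain class is its own single accepted type.
    return (annotation,)


accepted = expand_annotation(Union[DDPMDemoScheduler, UnCLIPDemoScheduler])
accepted_names = [accepted_type.__name__ for accepted_type in accepted]
```

With the annotation expanded this way, a name-based check can accept either scheduler rather than only the single class recorded in model_index.json.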

Contributor Author

Tests for wrong scheduler are added.



@slow
@require_torch_gpu