Skip to content
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch

from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler
from ...schedulers import DDPMScheduler, UnCLIPScheduler
from ...utils import deprecate, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
Expand Down Expand Up @@ -83,7 +83,7 @@ class KandinskyV22Pipeline(DiffusionPipeline):
def __init__(
self,
unet: UNet2DConditionModel,
scheduler: DDPMScheduler,
scheduler: Union[DDPMScheduler, UnCLIPScheduler],
movq: VQModel,
):
super().__init__()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,13 @@ class KandinskyV22CombinedPipeline(DiffusionPipeline):
def __init__(
self,
unet: UNet2DConditionModel,
scheduler: DDPMScheduler,
scheduler: Union[DDPMScheduler, UnCLIPScheduler],
movq: VQModel,
prior_prior: PriorTransformer,
prior_image_encoder: CLIPVisionModelWithProjection,
prior_text_encoder: CLIPTextModelWithProjection,
prior_tokenizer: CLIPTokenizer,
prior_scheduler: UnCLIPScheduler,
prior_scheduler: Union[DDPMScheduler, UnCLIPScheduler],
prior_image_processor: CLIPImageProcessor,
):
super().__init__()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection

from ...models import PriorTransformer
from ...schedulers import UnCLIPScheduler
from ...schedulers import DDPMScheduler, UnCLIPScheduler
from ...utils import (
logging,
replace_example_docstring,
Expand Down Expand Up @@ -114,7 +114,7 @@ def __init__(
image_encoder: CLIPVisionModelWithProjection,
text_encoder: CLIPTextModelWithProjection,
tokenizer: CLIPTokenizer,
scheduler: UnCLIPScheduler,
scheduler: Union[DDPMScheduler, UnCLIPScheduler],
image_processor: CLIPImageProcessor,
):
super().__init__()
Expand Down
24 changes: 24 additions & 0 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import fnmatch
import importlib
import inspect
Expand Down Expand Up @@ -811,6 +812,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# in this case they are already instantiated in `kwargs`
# extract them here
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
expected_types = pipeline_class._get_signature_types()
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
Expand All @@ -832,6 +834,28 @@ def load_module(name, value):
return True

init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
scheduler_types = None
if "scheduler" in expected_types:
scheduler_types = []
for scheduler_type in expected_types["scheduler"]:
if isinstance(scheduler_type, enum.EnumMeta):
scheduler_types.extend(list(scheduler_type))
else:
scheduler_types.extend([str(scheduler_type)])
scheduler_types = [str(scheduler).split(".")[-1].strip("'>") for scheduler in scheduler_types]

for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()):
Copy link
Collaborator

@DN6 DN6 Dec 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we already have the types extracted in `expected_types`, can't we fetch them using the key and then check whether the passed object is an instance of that type? If the expected type is an enum, we can instead check whether the passed object's class name exists in the enum's keys.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

@DN6 DN6 Dec 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it might be better to make this check more agnostic to the component names.

We have a few pipelines with Union types on non-scheduler components (mostly AnimateDiff). So this snippet would fail even though it's valid, because init_dict is based on the model_index.json which doesn't support multiple types.

from diffusers import (
	AnimateDiffPipeline,
	UNetMotionModel,
)

unet = UNetMotionModel()
pipe = AnimateDiffPipeline.from_pretrained(
	"hf-internal-testing/tiny-sd-pipe", unet=unet
)

Enforcing scheduler types might be a breaking change (cc: @yiyixuxu). For example, using DDIM with Kandinsky is currently valid, but with this change any downstream code doing this would break. It would be good to enforce this on the pipelines with Flow-based schedulers though (perhaps via a new Enum)?

I would try something like:

        for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()):
            if key not in passed_class_obj:
                continue

            class_obj = passed_class_obj[key]
            _expected_class_types = []
            for expected_type in expected_types[key]:
                if isinstance(expected_type, enum.EnumMeta):
                    _expected_class_types.extend(expected_type.__members__.keys())
                else:
                    _expected_class_types.append(expected_type.__name__)

            _is_valid_type = class_obj.__class__.__name__ in _expected_class_types
            if isinstance(class_obj, SchedulerMixin) and not _is_valid_type:
                # Handle case where scheduler is still valid 
                # raise if scheduler is meant to be a Flow based scheduler?
            elif not _is_valid_type:
                raise ValueError(f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}.")

Copy link
Contributor Author

@hlky hlky Dec 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added this for scheduler

                _requires_flow_match = any("FlowMatch" in class_type for class_type in _expected_class_types)
                _is_flow_match = "FlowMatch" in class_obj.__class__.__name__
                if _requires_flow_match and not _is_flow_match:
                    raise ValueError(f"Expected FlowMatch scheduler, got {class_obj.__class__.__name__}.")
                elif not _requires_flow_match and _is_flow_match:
                    raise ValueError(f"Expected non-FlowMatch scheduler, got {class_obj.__class__.__name__}.")

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we don't need a value error here, a warning is enough, no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A warning should be sufficient; this check is mainly for the situation in #10093 (comment), where the wrong text encoder is passed and the resulting error is uninformative.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's do a warning then:)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just chiming in here a bit to share a perspective as a user (not a strong opinion). Related to #10189 (comment).

Here

https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_loading_utils.py#L287-L290

if there's an unexpected module passed we raise a value error. I think this check is along similar lines -- users are passing components that are unexpected / incompatible. We probably cannot predict the consequences of allowing the loading without raising any errors, but if we raise an error, users would know what to do to fix the incorrect behaviour.

if key not in passed_class_obj:
continue
class_name = passed_class_obj[key].__class__.__name__
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
expected_class_name = (
expected_class_name[4:] if expected_class_name.startswith("Flax") else expected_class_name
)
if key == "scheduler" and scheduler_types is not None and class_name not in scheduler_types:
raise ValueError(f"Expected {scheduler_types} for {key}, got {class_name}.")
elif key != "scheduler" and class_name != expected_class_name:
raise ValueError(f"Expected {expected_class_name} for {key}, got {class_name}.")

# Special case: safety_checker must be loaded separately when using `from_flax`
if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/schedulers/scheduling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class KarrasDiffusionSchedulers(Enum):
UniPCMultistepScheduler = 13
DPMSolverSDEScheduler = 14
EDMEulerScheduler = 15
LCMScheduler = 16


AysSchedules = {
Expand Down
11 changes: 11 additions & 0 deletions tests/pipelines/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -1802,6 +1802,17 @@ def test_pipe_same_device_id_offload(self):
sd.maybe_free_model_hooks()
assert sd._offload_gpu_id == 5

def test_wrong_model(self):
    """Check that passing a tokenizer as `text_encoder` raises a descriptive ValueError.

    The error message is expected to mention the word "Expected", the
    component key (`text_encoder`) and the class that was actually passed
    (`CLIPTokenizer`).
    """
    # Deliberately load the wrong kind of component for the `text_encoder` slot.
    wrong_component = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

    with self.assertRaises(ValueError) as error_context:
        StableDiffusionPipeline.from_pretrained(
            "hf-internal-testing/diffusers-stable-diffusion-tiny-all", text_encoder=wrong_component
        )

    # Render the exception once and verify every fragment the message must contain.
    message = str(error_context.exception)
    for fragment in ("Expected", "text_encoder", "CLIPTokenizer"):
        assert fragment in message


@slow
@require_torch_gpu
Expand Down
Loading