Commit 05276c4

revert changes to pipeline utils
1 parent e3ff590 commit 05276c4

1 file changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 23 additions & 9 deletions
```diff
@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import enum
 import fnmatch
 import importlib
 import inspect
@@ -45,7 +46,6 @@
 from ..models.attention_processor import FusedAttnProcessor2_0
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
 from ..quantizers.bitsandbytes.utils import _check_bnb_status
-from ..quantizers.torchao.utils import _check_torchao_status
 from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from ..utils import (
     CONFIG_NAME,
```
```diff
@@ -389,7 +389,6 @@ def to(self, *args, **kwargs):
 
         device = device or device_arg
         pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())
-        pipeline_has_torchao = any(_check_torchao_status(module) for _, module in self.components.items())
 
         # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
         def module_is_sequentially_offloaded(module):
@@ -413,7 +412,7 @@ def module_is_offloaded(module):
             module_is_sequentially_offloaded(module) for _, module in self.components.items()
         )
         if device and torch.device(device).type == "cuda":
-            if pipeline_is_sequentially_offloaded and not (pipeline_has_bnb or pipeline_has_torchao):
+            if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
                 raise ValueError(
                     "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
                 )
```
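Note on the surviving guard: after the revert, `.to("cuda")` is still rejected while sequential CPU offloading is active, unless bitsandbytes quantization is detected. A minimal repro sketch of that path (the checkpoint id is a hypothetical placeholder, not something from this commit):

```python
import torch
from diffusers import DiffusionPipeline

# "some/model-id" is a hypothetical placeholder checkpoint.
pipe = DiffusionPipeline.from_pretrained("some/model-id", torch_dtype=torch.float16)
pipe.enable_sequential_cpu_offload()

# With sequential offloading active and no bnb-quantized modules, this call
# hits the ValueError in the hunk above; use pipe.to("cpu") to move back instead.
pipe.to("cuda")
```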
```diff
@@ -422,12 +421,6 @@ def module_is_offloaded(module):
                 raise ValueError(
                     "You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
                 )
-            elif pipeline_has_torchao:
-                raise ValueError(
-                    "You are trying to call `.to('cuda')` on a pipeline that has models quantized with `torchao`. This is not supported. There are two options on what could be done to fix this error:\n"
-                    "1. Move the individual components of the model to the desired device directly using `.to()` on each.\n"
-                    '2. Pass `device_map="balanced"` when initializing the pipeline to let `accelerate` handle the device placement.'
-                )
 
         is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
         if is_pipeline_device_mapped:
```
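For context on what was reverted: the deleted branch raised on `.to('cuda')` for torchao-quantized pipelines, and its message named two workarounds. A sketch of both, with a placeholder checkpoint id and component names that vary by pipeline class:

```python
import torch
from diffusers import DiffusionPipeline

# Option 1 from the removed message: move components individually with .to().
pipe = DiffusionPipeline.from_pretrained("some/torchao-quantized-id", torch_dtype=torch.bfloat16)
pipe.transformer.to("cuda")  # hypothetical component names; inspect pipe.components
pipe.vae.to("cuda")

# Option 2: pass device_map="balanced" so accelerate handles placement at load time.
pipe = DiffusionPipeline.from_pretrained(
    "some/torchao-quantized-id", torch_dtype=torch.bfloat16, device_map="balanced"
)
```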
```diff
@@ -819,6 +812,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         # in this case they are already instantiated in `kwargs`
         # extract them here
         expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
+        expected_types = pipeline_class._get_signature_types()
         passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
         passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
         init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
```
```diff
@@ -841,6 +835,26 @@ def load_module(name, value):
 
         init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
 
+        for key in init_dict.keys():
+            if key not in passed_class_obj:
+                continue
+            if "scheduler" in key:
+                continue
+
+            class_obj = passed_class_obj[key]
+            _expected_class_types = []
+            for expected_type in expected_types[key]:
+                if isinstance(expected_type, enum.EnumMeta):
+                    _expected_class_types.extend(expected_type.__members__.keys())
+                else:
+                    _expected_class_types.append(expected_type.__name__)
+
+            _is_valid_type = class_obj.__class__.__name__ in _expected_class_types
+            if not _is_valid_type:
+                logger.warning(
+                    f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}."
+                )
+
         # Special case: safety_checker must be loaded separately when using `from_flax`
         if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
             raise NotImplementedError(
```
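The added loop only warns (it does not raise) when a user-passed component's class name is not among the names allowed by the pipeline signature; enum-typed slots contribute their member names via `__members__`. A standalone sketch of that name-expansion logic, using made-up stand-in types rather than real diffusers classes:

```python
import enum

class SamplerKind(enum.Enum):
    # Hypothetical enum-typed slot, standing in for an enum in a real signature.
    DDIM = "ddim"
    EULER = "euler"

class UNet2DConditionModel:
    # Bare stand-in for the real diffusers class of the same name.
    pass

expected_types = {"unet": (UNet2DConditionModel,), "sampler": (SamplerKind,)}

def allowed_names(key: str) -> list:
    # Same expansion as the added loop: enum classes contribute member names,
    # plain classes contribute their own __name__.
    names = []
    for expected_type in expected_types[key]:
        if isinstance(expected_type, enum.EnumMeta):
            names.extend(expected_type.__members__.keys())
        else:
            names.append(expected_type.__name__)
    return names

print(allowed_names("unet"))     # ['UNet2DConditionModel']
print(allowed_names("sampler"))  # ['DDIM', 'EULER']
```

Matching on class names rather than `isinstance` keeps the check cheap and tolerant of duck-typed replacements, which is presumably why a mismatch is only logged instead of raised.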
