Skip to content

Commit fe447ba

Browse files
committed
address review comments
1 parent 9ec70f0 commit fe447ba

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from ..models.attention_processor import FusedAttnProcessor2_0
4646
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
4747
from ..quantizers.bitsandbytes.utils import _check_bnb_status
48+
from ..quantizers.torchao.utils import _check_torchao_status
4849
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
4950
from ..utils import (
5051
CONFIG_NAME,
@@ -388,6 +389,7 @@ def to(self, *args, **kwargs):
388389

389390
device = device or device_arg
390391
pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())
392+
pipeline_has_torchao = any(_check_torchao_status(module) for _, module in self.components.items())
391393

392394
# throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
393395
def module_is_sequentially_offloaded(module):
@@ -411,7 +413,7 @@ def module_is_offloaded(module):
411413
module_is_sequentially_offloaded(module) for _, module in self.components.items()
412414
)
413415
if device and torch.device(device).type == "cuda":
414-
if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
416+
if pipeline_is_sequentially_offloaded and not (pipeline_has_bnb or pipeline_has_torchao):
415417
raise ValueError(
416418
"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
417419
)
@@ -420,6 +422,12 @@ def module_is_offloaded(module):
420422
raise ValueError(
421423
"You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
422424
)
425+
elif pipeline_has_torchao:
426+
raise ValueError(
427+
"You are trying to call `.to('cuda')` on a pipeline that has models quantized with `torchao`. This is not supported. There are two options on what could be done to fix this error:\n"
428+
"1. Move the individual components of the model to the desired device directly using `.to()` on each.\n"
429+
'2. Pass `device_map="balanced"` when initializing the pipeline to let `accelerate` handle the device placement.'
430+
)
423431

424432
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
425433
if is_pipeline_device_mapped:
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from ..quantization_config import QuantizationMethod
16+
17+
18+
def _check_torchao_status(module) -> bool:
    """Return ``True`` when *module* was quantized with the ``torchao`` backend.

    The check reads the ``quantization_method`` attribute that the quantizer
    machinery stamps onto quantized models; modules without the attribute
    (or quantized by another backend) yield ``False``.
    """
    return getattr(module, "quantization_method", None) == QuantizationMethod.TORCHAO

0 commit comments

Comments
 (0)