Skip to content

Commit 6be43b8

Browse files
committed
handle .cuda()
1 parent a872e84 commit 6be43b8

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,6 +1229,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
12291229
# Adapted from `transformers`.
12301230
@wraps(torch.nn.Module.cuda)
12311231
def cuda(self, *args, **kwargs):
1232+
from ..hooks.group_offloading import _is_group_offload_enabled
1233+
12321234
# Checks if the model has been loaded in 4-bit or 8-bit with BNB
12331235
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
12341236
if getattr(self, "is_loaded_in_8bit", False):
@@ -1241,6 +1243,14 @@ def cuda(self, *args, **kwargs):
12411243
"Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
12421244
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
12431245
)
1246+
1247+
# Checks if group offloading is enabled
1248+
if _is_group_offload_enabled(self):
1249+
logger.warning(
1250+
f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.cuda()` is not supported."
1251+
)
1252+
return self
1253+
12441254
return super().cuda(*args, **kwargs)
12451255

12461256
# Adapted from `transformers`.

0 commit comments

Comments (0)