diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index d33b80dba091..7a98fa3da14a 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -339,7 +339,8 @@ def offload_models( original_devices = [next(m.parameters()).device for m in modules] else: assert len(modules) == 1 - original_devices = modules[0].device + # For DiffusionPipeline, wrap the device in a list to make it iterable + original_devices = [modules[0].device] # move to target device for m in modules: m.to(device)