
Commit 278cbc2

updates.patch
1 parent 49ac665 commit 278cbc2

File tree

2 files changed (+22, -21 lines)


src/diffusers/hooks/group_offloading.py

Lines changed: 10 additions & 13 deletions
@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 from contextlib import contextmanager, nullcontext
 from typing import Dict, List, Optional, Set, Tuple, Union
-import os
 
-import torch
 import safetensors.torch
+import torch
+
 from ..utils import get_logger, is_accelerate_available
 from .hooks import HookRegistry, ModelHook
 
@@ -165,9 +166,10 @@ def onload_(self):
                         tensor_obj.data.record_stream(current_stream)
             else:
                 # Load directly to the target device (synchronous)
-                loaded_tensors = safetensors.torch.load_file(
-                    self.safetensors_file_path, device=self.onload_device
+                onload_device = (
+                    self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
                 )
+                loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
                 for key, tensor_obj in self.key_to_tensor.items():
                     tensor_obj.data = loaded_tensors[key]
                 return
@@ -265,16 +267,12 @@ class GroupOffloadingHook(ModelHook):
 
     _is_stateful = False
 
-    def __init__(
-        self,
-        group: ModuleGroup,
-        next_group: Optional[ModuleGroup] = None
-    ) -> None:
+    def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None:
         self.group = group
         self.next_group = next_group
         # map param/buffer name -> file path
-        self.param_to_path: Dict[str,str] = {}
-        self.buffer_to_path: Dict[str,str] = {}
+        self.param_to_path: Dict[str, str] = {}
+        self.buffer_to_path: Dict[str, str] = {}
 
     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         if self.group.offload_leader == module:
@@ -516,7 +514,6 @@ def apply_group_offloading(
         stream = torch.Stream()
     else:
         raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")
-
     if offload_to_disk and offload_path is None:
         raise ValueError("`offload_path` must be set when `offload_to_disk=True`.")
 
@@ -899,4 +896,4 @@ def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
     for submodule in module.modules():
         if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
            return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device
-    raise ValueError("Group offloading is not enabled for the provided module.")
+    raise ValueError("Group offloading is not enabled for the provided module.")
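
For context on the onload_ change above: safetensors.torch.load_file expects its device argument as a string (for example "cpu" or "cuda") rather than a torch.device object, which is presumably why the patch converts the device to its type string before the call. A minimal sketch of the same pattern, with an illustrative helper name and file path that are not part of the diff:

import torch
import safetensors.torch

def load_group_tensors(path: str, onload_device):
    # Accept either a torch.device or a plain string, and hand safetensors a string.
    device = onload_device.type if isinstance(onload_device, torch.device) else onload_device
    return safetensors.torch.load_file(path, device=device)

# Example (assumes a previously offloaded group file exists at this path):
# tensors = load_group_tensors("group_0.safetensors", torch.device("cuda"))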

src/diffusers/models/modeling_utils.py

Lines changed: 12 additions & 8 deletions
@@ -543,6 +543,8 @@ def enable_group_offload(
         onload_device: torch.device,
         offload_device: torch.device = torch.device("cpu"),
         offload_type: str = "block_level",
+        offload_to_disk: bool = False,
+        offload_path: Optional[str] = None,
         num_blocks_per_group: Optional[int] = None,
         non_blocking: bool = False,
         use_stream: bool = False,
@@ -588,15 +590,17 @@ def enable_group_offload(
                 f"open an issue at https://github.com/huggingface/diffusers/issues."
             )
         apply_group_offloading(
-            self,
-            onload_device,
-            offload_device,
-            offload_type,
-            num_blocks_per_group,
-            non_blocking,
-            use_stream,
-            record_stream,
+            module=self,
+            onload_device=onload_device,
+            offload_device=offload_device,
+            offload_type=offload_type,
+            num_blocks_per_group=num_blocks_per_group,
+            non_blocking=non_blocking,
+            use_stream=use_stream,
+            record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            offload_to_disk=offload_to_disk,
+            offload_path=offload_path,
         )
 
     def save_pretrained(
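
With the two new keyword arguments threaded through to apply_group_offloading, group offloading can spill to disk from the public method. A hedged usage sketch, assuming any diffusers ModelMixin subclass; the checkpoint name and offload path below are placeholders, not part of the commit:

import torch
from diffusers import AutoencoderKL  # any ModelMixin subclass exposes enable_group_offload

model = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # placeholder checkpoint
model.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    offload_to_disk=True,          # new flag added in this commit
    offload_path="/tmp/offload",   # required when offload_to_disk=True
)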
