Skip to content

Commit b6c0f20

Browse files
committed
fix memory address problem
1 parent 195926b commit b6c0f20

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 18 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import hashlib
1516
import os
1617
from contextlib import contextmanager, nullcontext
1718
from typing import Dict, List, Optional, Set, Tuple, Union
@@ -62,6 +63,7 @@ def __init__(
6263
low_cpu_mem_usage: bool = False,
6364
onload_self: bool = True,
6465
offload_to_disk_path: Optional[str] = None,
66+
_group_id: Optional[int] = None,
6567
) -> None:
6668
self.modules = modules
6769
self.offload_device = offload_device
@@ -80,7 +82,9 @@ def __init__(
8082
self._is_offloaded_to_disk = False
8183

8284
if self.offload_to_disk_path:
83-
self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
85+
self._group_id = _group_id
86+
short_hash = self._compute_group_hash(self._group_id)
87+
self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors")
8488

8589
all_tensors = []
8690
for module in self.modules:
@@ -260,6 +264,11 @@ def offload_(self):
260264
for buffer in self.buffers:
261265
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
262266

267+
def _compute_group_hash(self, group_id):
    """Return a short, deterministic hex digest derived from ``group_id``.

    The identifier is UTF-8 encoded, hashed with SHA-256, and the first 16
    hexadecimal characters of the digest are returned. The same identifier
    therefore always maps to the same value.

    Args:
        group_id: Identifier of the group. Values that are not strings (for
            example an ``int``) are converted with ``str()`` before hashing.

    Returns:
        A 16-character hexadecimal string.

    Raises:
        ValueError: If ``group_id`` is ``None``; there is nothing to hash.
    """
    if group_id is None:
        # Fail with an explicit message instead of an AttributeError raised
        # from ``None.encode``.
        raise ValueError("`_group_id` must be provided to compute a group hash, but got `None`.")

    # ``str()`` leaves strings untouched and lets integer ids be hashed too.
    digest = hashlib.sha256(str(group_id).encode("utf-8")).hexdigest()
    # first 16 characters for a reasonably short but unique name
    return digest[:16]
271+
263272

264273
class GroupOffloadingHook(ModelHook):
265274
r"""
@@ -603,6 +612,9 @@ def _apply_group_offloading_block_level(
603612

604613
for i in range(0, len(submodule), num_blocks_per_group):
605614
current_modules = submodule[i : i + num_blocks_per_group]
615+
start_idx = i
616+
end_idx = i + len(current_modules) - 1
617+
group_id = f"{name}.{start_idx}_to_{end_idx}"
606618
group = ModuleGroup(
607619
modules=current_modules,
608620
offload_device=offload_device,
@@ -615,6 +627,7 @@ def _apply_group_offloading_block_level(
615627
record_stream=record_stream,
616628
low_cpu_mem_usage=low_cpu_mem_usage,
617629
onload_self=True,
630+
_group_id=group_id,
618631
)
619632
matched_module_groups.append(group)
620633
for j in range(i, i + len(current_modules)):
@@ -649,6 +662,7 @@ def _apply_group_offloading_block_level(
649662
stream=None,
650663
record_stream=False,
651664
onload_self=True,
665+
_group_id="top_level_unmatched_modules",
652666
)
653667
if stream is None:
654668
_apply_group_offloading_hook(module, unmatched_group, None)
@@ -715,6 +729,7 @@ def _apply_group_offloading_leaf_level(
715729
record_stream=record_stream,
716730
low_cpu_mem_usage=low_cpu_mem_usage,
717731
onload_self=True,
732+
_group_id=name,
718733
)
719734
_apply_group_offloading_hook(submodule, group, None)
720735
modules_with_group_offloading.add(name)
@@ -762,6 +777,7 @@ def _apply_group_offloading_leaf_level(
762777
record_stream=record_stream,
763778
low_cpu_mem_usage=low_cpu_mem_usage,
764779
onload_self=True,
780+
_group_id=name,
765781
)
766782
_apply_group_offloading_hook(parent_module, group, None)
767783

@@ -783,6 +799,7 @@ def _apply_group_offloading_leaf_level(
783799
record_stream=False,
784800
low_cpu_mem_usage=low_cpu_mem_usage,
785801
onload_self=True,
802+
name="lazy_leafs",
786803
)
787804
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
788805

0 commit comments

Comments
 (0)