
Commit be98308

cuda stream prefetch

1 parent c426a34

2 files changed: +82 −14 lines


src/diffusers/hooks/group_offloading.py (81 additions, 14 deletions)
```diff
@@ -13,7 +13,7 @@
 # limitations under the License.

 import re
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union

 import torch

```

```diff
@@ -65,10 +65,21 @@ class GroupOffloadingHook(ModelHook):
     encounter such an error.
     """

-    def __init__(self, group: ModuleGroup, offload_on_init: bool = True, non_blocking: bool = False) -> None:
+    def __init__(
+        self,
+        group: ModuleGroup,
+        offload_on_init: bool = True,
+        non_blocking: bool = False,
+        stream: Optional[torch.cuda.Stream] = None,
+        next_group: Optional[ModuleGroup] = None,
+        cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
+    ) -> None:
         self.group = group
         self.offload_on_init = offload_on_init
         self.non_blocking = non_blocking
+        self.stream = stream
+        self.next_group = next_group
+        self.cpu_param_dict = cpu_param_dict

     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         if self.offload_on_init:
```
```diff
@@ -87,16 +98,34 @@ def post_forward(self, module: torch.nn.Module, output):

     def onload_(self, module: torch.nn.Module) -> None:
         if self.group.onload_leader == module:
-            for group_module in self.group.modules:
-                group_module.to(self.group.onload_device, non_blocking=self.non_blocking)
+            breakpoint()
+            if self.stream is not None:
+                # Wait for previous Host->Device transfer to complete
+                self.stream.synchronize()
+
+                if self.next_group is None:
+                    return
+
+                # Start Host->Device transfer for next group
+                with torch.cuda.stream(self.stream):
+                    for group_module in self.next_group.modules:
+                        group_module.to(self.next_group.onload_device, non_blocking=True)
+            else:
+                for group_module in self.group.modules:
+                    group_module.to(self.group.onload_device, non_blocking=self.non_blocking)

     def offload_(self, module: torch.nn.Module) -> None:
         if self.group.offload_leader == module:
-            for group_module in self.group.modules:
-                group_module.to(self.group.offload_device, non_blocking=self.non_blocking)
-            # TODO: do we need to sync here because of GPU->CPU transfer?
-            if self.non_blocking and self.group.offload_device.type == "cpu":
-                torch.cpu.synchronize()
+            if self.stream is not None:
+                for group_module in self.group.modules:
+                    for param in group_module.parameters():
+                        param.data = self.cpu_param_dict[param]
+            else:
+                for group_module in self.group.modules:
+                    group_module.to(self.group.offload_device, non_blocking=self.non_blocking)
+                # TODO: do we need to sync here because of GPU->CPU transfer?
+                if self.non_blocking and self.group.offload_device.type == "cpu":
+                    torch.cpu.synchronize()


 def apply_group_offloading(
```
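The onload path above is the core of the change: when a stream is supplied, the hook no longer moves its own group (that transfer was enqueued earlier), it only waits for the transfer to finish and then prefetches the next group on the side stream. Here is a minimal standalone sketch of the same pattern, not part of this commit, with hypothetical names and independent of the hook machinery; it assumes pinned host memory so the copies are truly asynchronous:

```python
import torch

prefetch_stream = torch.cuda.Stream()

def onload_with_prefetch(current_group, next_group):
    # Block until the Host->Device copy previously enqueued for
    # `current_group` on the side stream has completed, so compute on
    # the default stream sees fully transferred weights.
    prefetch_stream.synchronize()

    if next_group is None:
        return

    # Enqueue the next group's Host->Device copies on the side stream;
    # with pinned source memory they overlap with ongoing compute.
    with torch.cuda.stream(prefetch_stream):
        for module in next_group:
            module.to("cuda", non_blocking=True)
```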
```diff
@@ -107,12 +136,22 @@ def apply_group_offloading(
     onload_device: torch.device = torch.device("cuda"),
     force_offload: bool = True,
     non_blocking: bool = False,
+    cuda_stream: bool = False,
 ) -> None:
+    stream = None
+    if cuda_stream:
+        stream = torch.cuda.Stream()
     if offload_group_patterns == "diffusers_block":
         if num_blocks_per_group is None:
             raise ValueError("num_blocks_per_group must be provided when using GroupOffloadingType.DIFFUSERS_BLOCK.")
         _apply_group_offloading_diffusers_block(
-            module, num_blocks_per_group, offload_device, onload_device, force_offload, non_blocking
+            module,
+            num_blocks_per_group,
+            offload_device,
+            onload_device,
+            force_offload,
+            non_blocking,
+            stream,
         )
     else:
         _apply_group_offloading_group_patterns(
```
```diff
@@ -127,7 +166,14 @@ def _apply_group_offloading_diffusers_block(
     onload_device: torch.device,
     force_offload: bool,
     non_blocking: bool,
+    stream: Optional[torch.cuda.Stream] = None,
 ) -> None:
+    cpu_param_dict = None
+    if stream is not None:
+        for param in module.parameters():
+            param.data = param.data.cpu().pin_memory()
+        cpu_param_dict = {param: param.data for param in module.parameters()}
+
     # Handle device offloading/onloading for unet/transformer stack modules
     for stack_identifier in _COMMON_STACK_IDENTIFIERS:
         if not hasattr(module, stack_identifier) or not isinstance(
```
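The pinning above is what makes the non-blocking copies effective: only page-locked (pinned) host memory can be transferred to the GPU asynchronously, and `cpu_param_dict` keeps a handle to each pinned CPU tensor so that `offload_` can restore `param.data` by reassignment instead of issuing a Device->Host copy. A small illustrative sketch (the `Linear` model is an arbitrary stand-in, not from the commit):

```python
import torch

model = torch.nn.Linear(8, 8)

# Pin every parameter's CPU storage and remember it.
for param in model.parameters():
    param.data = param.data.cpu().pin_memory()
cpu_param_dict = {param: param.data for param in model.parameters()}

# Onload: asynchronous Host->Device copy (pinned memory required).
for param in model.parameters():
    param.data = param.data.to("cuda", non_blocking=True)

# Offload: no copy at all, just point back at the pinned CPU tensor.
for param in model.parameters():
    param.data = cpu_param_dict[param]
```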
```diff
@@ -137,14 +183,29 @@ def _apply_group_offloading_diffusers_block(

         stack = getattr(module, stack_identifier)
         num_blocks = len(stack)
+        module_groups = []

         for i in range(0, num_blocks, num_blocks_per_group):
             blocks = stack[i : i + num_blocks_per_group]
             group = ModuleGroup(
                 blocks, offload_device, onload_device, offload_leader=blocks[-1], onload_leader=blocks[0]
             )
+            module_groups.append(group)
+
+        for i, group in enumerate(module_groups):
+            next_group = module_groups[i + 1] if i + 1 < len(module_groups) and stream is not None else None
             should_offload = force_offload or i > 0
-            _apply_group_offloading(group, should_offload, non_blocking)
+            _apply_group_offloading(group, should_offload, non_blocking, stream, next_group, cpu_param_dict)
+
+        if stream is not None:
+            # Start Host->Device transfer for the first group
+            with torch.cuda.stream(stream):
+                for group_module in module_groups[0].modules:
+                    group_module.to(onload_device, non_blocking=True)
+            if len(module_groups) > 1:
+                # Assign the first module_group as the next_group for the last module_group
+                hook_registry = HookRegistry.check_if_exists_or_initialize(module_groups[-1].onload_leader)
+                hook_registry.hooks["group_offloading"].next_group = module_groups[0]

     # Handle device offloading/onloading for non-stack modules
     for name, submodule in module.named_modules():
```
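Two setup details are worth noting in the hunk above: the first group's transfer is kicked off eagerly on the stream, since no earlier forward pass exists to prefetch it, and the last group's hook gets the first group as its `next_group`, so the wrap-around prefetch readies the first blocks again for the next model invocation (presumably the next denoising step).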
```diff
@@ -154,7 +215,6 @@ def _apply_group_offloading_diffusers_block(
             # for enabling offloading.
             continue
         layer_name = name_split[0]
-        print(layer_name)
         if layer_name in _COMMON_STACK_IDENTIFIERS:
             continue
         group = ModuleGroup(
```
```diff
@@ -211,8 +271,15 @@ def _apply_group_offloading_group_patterns(
         _apply_group_offloading(group, force_offload, non_blocking)


-def _apply_group_offloading(group: ModuleGroup, offload_on_init: bool, non_blocking: bool) -> None:
+def _apply_group_offloading(
+    group: ModuleGroup,
+    offload_on_init: bool,
+    non_blocking: bool,
+    stream: Optional[torch.cuda.Stream] = None,
+    next_group: Optional[ModuleGroup] = None,
+    cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
+) -> None:
     for module in group.modules:
-        hook = GroupOffloadingHook(group, offload_on_init, non_blocking)
+        hook = GroupOffloadingHook(group, offload_on_init, non_blocking, stream, next_group, cpu_param_dict)
         registry = HookRegistry.check_if_exists_or_initialize(module)
         registry.register_hook(hook, "group_offloading")
```
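With that plumbing in place, enabling prefetch is a one-flag change at the call site. A hypothetical usage sketch, assuming the signature shown in this diff; `transformer` and the exact argument names outside the hunks are assumptions, not confirmed by the commit:

```python
from diffusers.hooks.group_offloading import apply_group_offloading

apply_group_offloading(
    transformer,                               # any model with a supported block stack
    offload_group_patterns="diffusers_block",  # assumed, per the branch in apply_group_offloading
    num_blocks_per_group=2,
    cuda_stream=True,                          # create a side stream and prefetch group i+1
)
```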

src/diffusers/models/transformers/transformer_ltx.py (1 addition, 0 deletions)
```diff
@@ -240,6 +240,7 @@ def forward(
         norm_hidden_states = self.norm1(hidden_states)

         num_ada_params = self.scale_shift_table.shape[0]
+        breakpoint()
         ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
         norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
```

0 commit comments
