 # limitations under the License.
 
 import re
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union
 
 import torch
 
@@ -65,10 +65,21 @@ class GroupOffloadingHook(ModelHook):
     encounter such an error.
     """
 
-    def __init__(self, group: ModuleGroup, offload_on_init: bool = True, non_blocking: bool = False) -> None:
+    def __init__(
+        self,
+        group: ModuleGroup,
+        offload_on_init: bool = True,
+        non_blocking: bool = False,
+        stream: Optional[torch.cuda.Stream] = None,
+        next_group: Optional[ModuleGroup] = None,
+        cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
+    ) -> None:
         self.group = group
         self.offload_on_init = offload_on_init
         self.non_blocking = non_blocking
+        self.stream = stream
+        self.next_group = next_group
+        self.cpu_param_dict = cpu_param_dict
 
     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         if self.offload_on_init:
@@ -87,16 +98,34 @@ def post_forward(self, module: torch.nn.Module, output):
 
     def onload_(self, module: torch.nn.Module) -> None:
         if self.group.onload_leader == module:
-            for group_module in self.group.modules:
-                group_module.to(self.group.onload_device, non_blocking=self.non_blocking)
+            if self.stream is not None:
+                # Wait for previous Host->Device transfer to complete
+                self.stream.synchronize()
+
+                if self.next_group is None:
+                    return
+
+                # Start Host->Device transfer for next group
+                with torch.cuda.stream(self.stream):
+                    for group_module in self.next_group.modules:
+                        group_module.to(self.next_group.onload_device, non_blocking=True)
+            else:
+                for group_module in self.group.modules:
+                    group_module.to(self.group.onload_device, non_blocking=self.non_blocking)
 
     def offload_(self, module: torch.nn.Module) -> None:
         if self.group.offload_leader == module:
-            for group_module in self.group.modules:
-                group_module.to(self.group.offload_device, non_blocking=self.non_blocking)
-            # TODO: do we need to sync here because of GPU->CPU transfer?
-            if self.non_blocking and self.group.offload_device.type == "cpu":
-                torch.cpu.synchronize()
+            if self.stream is not None:
+                for group_module in self.group.modules:
+                    for param in group_module.parameters():
+                        param.data = self.cpu_param_dict[param]
+            else:
+                for group_module in self.group.modules:
+                    group_module.to(self.group.offload_device, non_blocking=self.non_blocking)
+                # TODO: do we need to sync here because of GPU->CPU transfer?
+                if self.non_blocking and self.group.offload_device.type == "cpu":
+                    torch.cpu.synchronize()
 
 
 def apply_group_offloading(
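Reviewer note: the new `onload_`/`offload_` branches implement a prefetch pattern: synchronize on the side stream to finish the pending copy, then immediately queue the next group's Host->Device transfer so it overlaps with the current group's compute on the default stream. A minimal, self-contained sketch of that pattern, with toy `nn.Linear` groups standing in for `ModuleGroup` (not diffusers code):

```python
import torch
import torch.nn as nn

if torch.cuda.is_available():
    # Toy stand-ins for two module groups, kept in pinned CPU memory so that
    # non_blocking Host->Device copies are truly asynchronous.
    group_a, group_b = nn.Linear(1024, 1024), nn.Linear(1024, 1024)
    for p in list(group_a.parameters()) + list(group_b.parameters()):
        p.data = p.data.pin_memory()

    transfer_stream = torch.cuda.Stream()
    x = torch.randn(64, 1024, device="cuda")

    # Setup: queue the first group's transfer (mirrors the block this commit
    # adds after the group-creation loop below).
    with torch.cuda.stream(transfer_stream):
        group_a.to("cuda", non_blocking=True)

    # Per-forward: wait for the pending copy, queue the next one, then compute.
    transfer_stream.synchronize()
    with torch.cuda.stream(transfer_stream):
        group_b.to("cuda", non_blocking=True)
    x = group_a(x)  # default-stream compute overlaps group_b's copy

    transfer_stream.synchronize()
    x = group_b(x)
```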
@@ -107,12 +136,22 @@ def apply_group_offloading(
     onload_device: torch.device = torch.device("cuda"),
     force_offload: bool = True,
     non_blocking: bool = False,
+    cuda_stream: bool = False,
 ) -> None:
+    stream = None
+    if cuda_stream:
+        stream = torch.cuda.Stream()
     if offload_group_patterns == "diffusers_block":
         if num_blocks_per_group is None:
             raise ValueError("num_blocks_per_group must be provided when using GroupOffloadingType.DIFFUSERS_BLOCK.")
         _apply_group_offloading_diffusers_block(
-            module, num_blocks_per_group, offload_device, onload_device, force_offload, non_blocking
+            module,
+            num_blocks_per_group,
+            offload_device,
+            onload_device,
+            force_offload,
+            non_blocking,
+            stream,
         )
     else:
         _apply_group_offloading_group_patterns(
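For reviewers, a hedged usage sketch of the new flag. All keyword names appear in the hunks above, but the full signature is not shown in this diff, so treat the argument set as an assumption; `model` is a placeholder for any diffusers model whose block stack is listed in `_COMMON_STACK_IDENTIFIERS`:

```python
apply_group_offloading(
    model,  # placeholder: a UNet/transformer with a supported block stack
    offload_group_patterns="diffusers_block",
    num_blocks_per_group=2,
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    force_offload=True,
    non_blocking=True,
    cuda_stream=True,  # new in this commit: use a side stream for prefetching
)
```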
@@ -127,7 +166,14 @@ def _apply_group_offloading_diffusers_block(
     onload_device: torch.device,
     force_offload: bool,
     non_blocking: bool,
+    stream: Optional[torch.cuda.Stream] = None,
 ) -> None:
+    cpu_param_dict = None
+    if stream is not None:
+        for param in module.parameters():
+            param.data = param.data.cpu().pin_memory()
+        cpu_param_dict = {param: param.data for param in module.parameters()}
+
     # Handle device offloading/onloading for unet/transformer stack modules
     for stack_identifier in _COMMON_STACK_IDENTIFIERS:
         if not hasattr(module, stack_identifier) or not isinstance(
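The pinning block above is the key enabler for the stream path: each parameter's CPU copy is moved into page-locked memory once, and `cpu_param_dict` caches those tensors so `offload_` can restore them by reassigning `param.data` rather than issuing a Device->Host copy. This is valid because inference never mutates the weights. A standalone sketch of the same bookkeeping (toy layer, not diffusers code):

```python
import torch
import torch.nn as nn

layer = nn.Linear(8, 8)

# One-time setup: pin the CPU weights and remember each pinned tensor.
for param in layer.parameters():
    param.data = param.data.cpu().pin_memory()
cpu_param_dict = {param: param.data for param in layer.parameters()}

if torch.cuda.is_available():
    # Onload: a pinned source allows a genuinely asynchronous copy.
    layer.to("cuda", non_blocking=True)

    # Offload: repoint param.data at the cached pinned tensor; the CUDA
    # copy becomes unreferenced and its memory can be freed.
    for param in layer.parameters():
        param.data = cpu_param_dict[param]
```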
@@ -137,14 +183,29 @@ def _apply_group_offloading_diffusers_block(
 
         stack = getattr(module, stack_identifier)
         num_blocks = len(stack)
+        module_groups = []
 
         for i in range(0, num_blocks, num_blocks_per_group):
             blocks = stack[i : i + num_blocks_per_group]
             group = ModuleGroup(
                 blocks, offload_device, onload_device, offload_leader=blocks[-1], onload_leader=blocks[0]
             )
+            module_groups.append(group)
+
+        for i, group in enumerate(module_groups):
+            next_group = module_groups[i + 1] if i + 1 < len(module_groups) and stream is not None else None
             should_offload = force_offload or i > 0
-            _apply_group_offloading(group, should_offload, non_blocking)
+            _apply_group_offloading(group, should_offload, non_blocking, stream, next_group, cpu_param_dict)
+
+        if stream is not None:
+            # Start Host->Device transfer for the first group
+            with torch.cuda.stream(stream):
+                for group_module in module_groups[0].modules:
+                    group_module.to(onload_device, non_blocking=True)
+            if len(module_groups) > 1:
+                # Assign the first module_group as the next_group for the last module_group
+                hook_registry = HookRegistry.check_if_exists_or_initialize(module_groups[-1].onload_leader)
+                hook_registry.hooks["group_offloading"].next_group = module_groups[0]
 
     # Handle device offloading/onloading for non-stack modules
     for name, submodule in module.named_modules():
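Reviewer note on the wiring above: inside the loop, group i receives group i + 1 as its `next_group`; the block after the loop then queues the first group's transfer up front and patches the last hook to point back at group 0, so the next forward pass finds its first group already in flight. The resulting neighbor relation is circular, as this toy restatement shows (strings stand in for `ModuleGroup` objects):

```python
module_groups = ["group0", "group1", "group2"]
for i in range(len(module_groups)):
    print(module_groups[i], "prefetches", module_groups[(i + 1) % len(module_groups)])
# group0 prefetches group1
# group1 prefetches group2
# group2 prefetches group0
```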
@@ -154,7 +215,6 @@ def _apply_group_offloading_diffusers_block(
             # for enabling offloading.
             continue
         layer_name = name_split[0]
-        print(layer_name)
         if layer_name in _COMMON_STACK_IDENTIFIERS:
             continue
         group = ModuleGroup(
@@ -211,8 +271,15 @@ def _apply_group_offloading_group_patterns(
         _apply_group_offloading(group, force_offload, non_blocking)
 
 
-def _apply_group_offloading(group: ModuleGroup, offload_on_init: bool, non_blocking: bool) -> None:
+def _apply_group_offloading(
+    group: ModuleGroup,
+    offload_on_init: bool,
+    non_blocking: bool,
+    stream: Optional[torch.cuda.Stream] = None,
+    next_group: Optional[ModuleGroup] = None,
+    cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
+) -> None:
     for module in group.modules:
-        hook = GroupOffloadingHook(group, offload_on_init, non_blocking)
+        hook = GroupOffloadingHook(group, offload_on_init, non_blocking, stream, next_group, cpu_param_dict)
         registry = HookRegistry.check_if_exists_or_initialize(module)
         registry.register_hook(hook, "group_offloading")