 from .hooks import HookRegistry, ModelHook
 
 
-_TRANSFORMER_STACK_IDENTIFIERS = [
+_COMMON_STACK_IDENTIFIERS = {
     "transformer_blocks",
     "single_transformer_blocks",
     "temporal_transformer_blocks",
     "transformer_layers",
     "layers",
     "blocks",
-]
+    "down_blocks",
+    "up_blocks",
+    "mid_blocks",
+}
 
 
 class ModuleGroup:
@@ -62,25 +65,16 @@ class GroupOffloadingHook(ModelHook):
     encounter such an error.
     """
 
-    def __init__(self, group: ModuleGroup, offload_on_init: bool = True) -> None:
+    def __init__(self, group: ModuleGroup, offload_on_init: bool = True, non_blocking: bool = False) -> None:
         self.group = group
         self.offload_on_init = offload_on_init
+        self.non_blocking = non_blocking
 
     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         if self.offload_on_init:
             self.offload_(module)
         return module
 
-    def onload_(self, module: torch.nn.Module) -> None:
-        if self.group.onload_leader is not None and self.group.onload_leader == module:
-            for group_module in self.group.modules:
-                group_module.to(self.group.onload_device)
-
-    def offload_(self, module: torch.nn.Module) -> None:
-        if self.group.offload_leader == module:
-            for group_module in self.group.modules:
-                group_module.to(self.group.offload_device)
-
     def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
         if self.group.onload_leader is None:
             self.group.onload_leader = module
@@ -91,6 +85,19 @@ def post_forward(self, module: torch.nn.Module, output):
         self.offload_(module)
         return output
 
+    def onload_(self, module: torch.nn.Module) -> None:
+        if self.group.onload_leader == module:
+            for group_module in self.group.modules:
+                group_module.to(self.group.onload_device, non_blocking=self.non_blocking)
+
+    def offload_(self, module: torch.nn.Module) -> None:
+        if self.group.offload_leader == module:
+            for group_module in self.group.modules:
+                group_module.to(self.group.offload_device, non_blocking=self.non_blocking)
+            # TODO: do we need to sync here because of GPU->CPU transfer?
+            if self.non_blocking and self.group.offload_device.type == "cpu":
+                torch.cpu.synchronize()
+
 
 def apply_group_offloading(
     module: torch.nn.Module,
@@ -99,14 +106,17 @@ def apply_group_offloading(
     offload_device: torch.device = torch.device("cpu"),
     onload_device: torch.device = torch.device("cuda"),
     force_offload: bool = True,
+    non_blocking: bool = False,
 ) -> None:
     if offload_group_patterns == "diffusers_block":
+        if num_blocks_per_group is None:
+            raise ValueError("num_blocks_per_group must be provided when using GroupOffloadingType.DIFFUSERS_BLOCK.")
         _apply_group_offloading_diffusers_block(
-            module, num_blocks_per_group, offload_device, onload_device, force_offload
+            module, num_blocks_per_group, offload_device, onload_device, force_offload, non_blocking
         )
     else:
         _apply_group_offloading_group_patterns(
-            module, offload_group_patterns, offload_device, onload_device, force_offload
+            module, offload_group_patterns, offload_device, onload_device, force_offload, non_blocking
         )
 
 
@@ -116,26 +126,47 @@ def _apply_group_offloading_diffusers_block(
     offload_device: torch.device,
     onload_device: torch.device,
     force_offload: bool,
+    non_blocking: bool,
 ) -> None:
-    if num_blocks_per_group is None:
-        raise ValueError("num_blocks_per_group must be provided when using GroupOffloadingType.DIFFUSERS_BLOCK.")
-
-    for transformer_stack_identifier in _TRANSFORMER_STACK_IDENTIFIERS:
-        if not hasattr(module, transformer_stack_identifier) or not isinstance(
-            getattr(module, transformer_stack_identifier), torch.nn.ModuleList
+    # Handle device offloading/onloading for unet/transformer stack modules
+    for stack_identifier in _COMMON_STACK_IDENTIFIERS:
+        if not hasattr(module, stack_identifier) or not isinstance(
+            getattr(module, stack_identifier), torch.nn.ModuleList
         ):
             continue
 
-        transformer_stack = getattr(module, transformer_stack_identifier)
-        num_blocks = len(transformer_stack)
+        stack = getattr(module, stack_identifier)
+        num_blocks = len(stack)
 
         for i in range(0, num_blocks, num_blocks_per_group):
-            blocks = transformer_stack[i : i + num_blocks_per_group]
+            blocks = stack[i : i + num_blocks_per_group]
             group = ModuleGroup(
                 blocks, offload_device, onload_device, offload_leader=blocks[-1], onload_leader=blocks[0]
             )
             should_offload = force_offload or i > 0
-            _apply_group_offloading(group, should_offload)
+            _apply_group_offloading(group, should_offload, non_blocking)
+
+    # Handle device offloading/onloading for non-stack modules
+    for name, submodule in module.named_modules():
+        name_split = name.split(".")
+        if not isinstance(submodule, torch.nn.Module) or name == "" or len(name_split) > 1:
+            # We only want the layers that are top-level in the module (encompass all the submodules)
+            # for enabling offloading.
+            continue
+        layer_name = name_split[0]
+        if layer_name in _COMMON_STACK_IDENTIFIERS:
+            continue
+        group = ModuleGroup(
+            [submodule], offload_device, onload_device, offload_leader=submodule, onload_leader=submodule
+        )
+        _apply_group_offloading(group, force_offload, non_blocking)
+
+    # Always keep parameters and buffers on onload_device
+    for name, param in module.named_parameters(recurse=False):
+        param.data = param.data.to(onload_device)
+    for name, buffer in module.named_buffers(recurse=False):
+        buffer.data = buffer.data.to(onload_device)
 
 
 def _apply_group_offloading_group_patterns(
@@ -144,6 +175,7 @@ def _apply_group_offloading_group_patterns(
     offload_device: torch.device,
     onload_device: torch.device,
     force_offload: bool,
+    non_blocking: bool,
 ) -> None:
     per_group_modules = []
     for i, offload_group_pattern in enumerate(offload_group_patterns):
@@ -174,11 +206,11 @@ def _apply_group_offloading_group_patterns(
     for group in per_group_modules:
         # TODO: handle offload leader correctly
         group = ModuleGroup(group["modules"], offload_device, onload_device, offload_leader=group["modules"][-1])
-        _apply_group_offloading(group, force_offload)
+        _apply_group_offloading(group, force_offload, non_blocking)
 
 
-def _apply_group_offloading(group: ModuleGroup, offload_on_init) -> None:
+def _apply_group_offloading(group: ModuleGroup, offload_on_init: bool, non_blocking: bool) -> None:
     for module in group.modules:
-        hook = GroupOffloadingHook(group, offload_on_init=offload_on_init)
+        hook = GroupOffloadingHook(group, offload_on_init, non_blocking)
         registry = HookRegistry.check_if_exists_or_initialize(module)
         registry.register_hook(hook, "group_offloading")
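
To illustrate how the updated entry point fits together: `apply_group_offloading` splits a model's block `ModuleList` into groups of `num_blocks_per_group`, wraps every remaining top-level submodule in its own group, and registers a `GroupOffloadingHook` on each member, so a group moves to `onload_device` when its onload leader enters `pre_forward` and back to `offload_device` after its offload leader's `post_forward`. Below is a minimal usage sketch; the import path and the toy `TinyModel` are assumptions (not shown in this diff), and keyword arguments are used because the exact positional order of the first parameters is not visible in the hunks above.

```python
# Minimal usage sketch. The import path is assumed, not confirmed by this diff:
# from diffusers.hooks.group_offloading import apply_group_offloading
import torch


class TinyModel(torch.nn.Module):
    # Toy stand-in: the "diffusers_block" path only requires a top-level ModuleList whose
    # attribute name appears in _COMMON_STACK_IDENTIFIERS (here, "transformer_blocks").
    def __init__(self):
        super().__init__()
        self.transformer_blocks = torch.nn.ModuleList([torch.nn.Linear(64, 64) for _ in range(4)])
        self.proj_out = torch.nn.Linear(64, 64)

    def forward(self, x):
        for block in self.transformer_blocks:
            x = block(x)
        return self.proj_out(x)


model = TinyModel()
apply_group_offloading(
    model,
    offload_group_patterns="diffusers_block",  # take the block-grouping code path
    num_blocks_per_group=2,                    # required for "diffusers_block"
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    force_offload=True,                        # every group starts offloaded
    non_blocking=True,                         # allow asynchronous .to() transfers
)
# Blocks 0-1 and 2-3 now form two groups and `proj_out` is its own group; each group is
# onloaded just before its onload leader's forward and offloaded again after its offload
# leader's forward, while top-level parameters and buffers stay on the onload device.
```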
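On the `TODO` about syncing after the GPU-to-CPU transfer in `offload_`: a `non_blocking=True` copy can return control to the host before the data has actually arrived (most visibly when the destination is pinned memory), so reading the CPU tensors without any synchronization risks observing stale values. The sketch below is illustrative only and not part of the diff; it synchronizes on the CUDA side, which is what guarantees the pending device-to-host copy has finished.

```python
# Illustrative sketch (not from the diff): making an async GPU->CPU copy safe to read.
import torch

if torch.cuda.is_available():
    src = torch.randn(1024, device="cuda")
    dst = torch.empty(1024, pin_memory=True)  # pinned host memory allows a truly async copy

    dst.copy_(src, non_blocking=True)  # may return before the copy has completed
    torch.cuda.synchronize()           # wait for pending CUDA work, including the D2H copy
    total = dst.sum()                  # now safe to read on the host
```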