@@ -264,28 +264,61 @@ def disable_xformers_memory_efficient_attention(self) -> None:
264264 self .set_use_memory_efficient_attention_xformers (False )
265265
266266 def enable_layerwise_upcasting (self , upcast_dtype = None ):
267+ r"""
268+ Enable layerwise dynamic upcasting. This allows models to be loaded into the GPU in a low memory dtype e.g.
269+ torch.float8_e4m3fn, but perform inference using a dtype that is supported by the GPU, by upcasting the
270+ individual modules in the model to the appropriate dtype right before the foward pass.
271+
272+ The module is then moved back to the low memory dtype after the foward pass.
273+ """
274+
267275 upcast_dtype = upcast_dtype or torch .float32
268- downcast_dtype = self .dtype
276+ original_dtype = self .dtype
269277
270- def upcast_hook_fn (module ):
278+ def upcast_dtype_hook_fn (module , * args , ** kwargs ):
271279 module = module .to (upcast_dtype )
272280
273- def downcast_hook_fn (module ):
274- module = module .to (downcast_dtype )
281+ def cast_to_original_dtype_hook_fn (module , * args , ** kwargs ):
282+ module = module .to (original_dtype )
275283
276284 def fn_recursive_upcast (module ):
285+ """In certain cases modules will apply casting internally or reference the dtype of internal blocks.
286+
287+ e.g.
288+
289+ ```
290+ class MyModel(nn.Module):
291+ def forward(self, x):
292+ dtype = next(iter(self.blocks.parameters())).dtype
293+ x = self.blocks(x) + torch.ones(x.size()).to(dtype)
294+ ```
295+ Layerwise upcasting will not work here, since the internal blocks remain in the low memory dtype until
296+ their `forward` method is called. We need to add the upcast hook on the entire module in order for the
297+ operation to work.
298+
299+ The `_always_upcast_modules` class attribute is a list of modules within the model that we must upcast
300+ entirely, rather than layerwise.
301+
302+ """
303+ if hasattr (self , "_always_upcast_modules" ) and module .__class__ .__name__ in self ._always_upcast_modules :
304+ # Upcast entire module and exist recursion
305+ module .register_forward_pre_hook (upcast_dtype_hook_fn )
306+ module .register_forward_hook (cast_to_original_dtype_hook_fn )
307+
308+ return
309+
277310 has_children = list (module .children ())
278311 if not has_children :
279- module .register_forward_pre_hook (upcast_hook_fn )
280- module .register_forward_hook (downcast_hook_fn )
312+ module .register_forward_pre_hook (upcast_dtype_hook_fn )
313+ module .register_forward_hook (cast_to_original_dtype_hook_fn )
281314
282315 for child in module .children ():
283316 fn_recursive_upcast (child )
284317
285318 for module in self .children ():
286319 fn_recursive_upcast (module )
287320
288- def disable_dynamic_upcasting (self ):
321+ def disable_layerwise_upcasting (self ):
289322 def fn_recursive_upcast (module ):
290323 has_children = list (module .children ())
291324 if not has_children :
0 commit comments