Commit be55fa6

committed
update
1 parent 413ca29 commit be55fa6

2 files changed: +123 -0 lines changed


src/diffusers/models/modeling_utils.py

Lines changed: 35 additions & 0 deletions
@@ -263,6 +263,41 @@ def disable_xformers_memory_efficient_attention(self) -> None:
         """
         self.set_use_memory_efficient_attention_xformers(False)
 
+    def enable_dynamic_upcasting(self, upcast_dtype=None):
+        upcast_dtype = upcast_dtype or torch.float32
+        downcast_dtype = self.dtype
+
+        # Forward pre-hooks are called as hook(module, args); cast the module up
+        # to the compute dtype just before its forward pass runs. Module.to is
+        # in-place for modules, so no reassignment is needed.
+        def upcast_hook_fn(module, args):
+            module.to(upcast_dtype)
+
+        # Forward hooks are called as hook(module, args, output); cast the module
+        # back down to the storage dtype once the forward pass has finished.
+        def downcast_hook_fn(module, args, output):
+            module.to(downcast_dtype)
+
+        def fn_recursive_upcast(module):
+            has_children = list(module.children())
+            if not has_children:
+                # Hook only leaf modules, so each submodule is cast individually
+                # rather than the whole model at once.
+                module.register_forward_pre_hook(upcast_hook_fn)
+                module.register_forward_hook(downcast_hook_fn)
+
+            for child in module.children():
+                fn_recursive_upcast(child)
+
+        for module in self.children():
+            fn_recursive_upcast(module)
+
+    def disable_dynamic_upcasting(self):
+        # NOTE: requires `from collections import OrderedDict` at the top of this
+        # file (the import is added explicitly in pipeline_utils.py below).
+        def fn_recursive_upcast(module):
+            has_children = list(module.children())
+            if not has_children:
+                # Resetting the hook dicts clears *all* forward hooks on the leaf
+                # module, not only the casting hooks registered above.
+                module._forward_pre_hooks = OrderedDict()
+                module._forward_hooks = OrderedDict()
+
+            for child in module.children():
+                fn_recursive_upcast(child)
+
+        for module in self.children():
+            fn_recursive_upcast(module)
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
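
As a usage illustration (not part of the commit): a minimal sketch of how the new model-level methods might be called. The checkpoint name and dtype choices are hypothetical, and it assumes the weights can be held in torch.float8_e4m3fn via Module.to, just as the implementation above assumes for downcasting.

    import torch
    from diffusers import UNet2DConditionModel

    # Load the model, move it to the GPU, then store weights in a low-memory dtype.
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"  # hypothetical checkpoint
    )
    unet.to("cuda").to(torch.float8_e4m3fn)

    # Register the hooks: each leaf module is cast to bfloat16 just before its
    # forward pass and back to float8 right after, so only one submodule occupies
    # high-precision memory at a time (float32 is the default if no dtype is given).
    unet.enable_dynamic_upcasting(upcast_dtype=torch.bfloat16)

    # ... run inference ...

    # Remove the hooks; note this clears all forward hooks on leaf modules.
    unet.disable_dynamic_upcasting()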

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 88 additions & 0 deletions
@@ -19,6 +19,7 @@
 import os
 import re
 import sys
+from collections import OrderedDict
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin
@@ -1172,6 +1173,93 @@ def reset_device_map(self):
         component.to("cpu")
         self.hf_device_map = None
 
+    def enable_dynamic_upcasting(
+        self,
+        components: Optional[List[str]] = None,
+        upcast_dtype: Optional[torch.dtype] = None,
+    ):
+        r"""
+        Enable module-wise dynamic upcasting. This allows models to be loaded onto the GPU in a low-memory dtype,
+        e.g. torch.float8_e4m3fn, while performing inference in a dtype that is supported on the GPU, by casting
+        each module to the appropriate dtype right before its forward pass. The module is then moved back to the
+        low-memory dtype after the forward pass.
+        """
+        if components is None:
+            raise ValueError("Please provide a list of pipeline component names to apply dynamic upcasting")
+
+        # Default to float32 compute, matching enable_dynamic_upcasting in modeling_utils.py.
+        upcast_dtype = upcast_dtype or torch.float32
+
+        def fn_recursive_upcast(module, dtype, original_dtype, keep_in_fp32_modules):
+            has_children = list(module.children())
+            upcast_dtype = dtype
+            downcast_dtype = original_dtype
+
+            # Forward pre-hooks are called as hook(module, args).
+            def upcast_hook_fn(module, inputs):
+                module.to(upcast_dtype)
+
+            # Forward hooks are called as hook(module, args, output).
+            def downcast_hook_fn(module, *args, **kwargs):
+                module.to(downcast_dtype)
+
+            if not has_children:
+                module.register_forward_pre_hook(upcast_hook_fn)
+                module.register_forward_hook(downcast_hook_fn)
+
+            for name, child in module.named_children():
+                # Modules listed in `_keep_in_fp32_modules` are always upcast to
+                # float32, regardless of the requested compute dtype.
+                if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
+                    child_dtype = torch.float32
+                else:
+                    child_dtype = upcast_dtype
+
+                fn_recursive_upcast(child, child_dtype, original_dtype, keep_in_fp32_modules)
+
+        for component in components:
+            if not hasattr(self, component):
+                raise ValueError(f"Pipeline has no component named: {component}")
+
+            component_module = getattr(self, component)
+            if not isinstance(component_module, torch.nn.Module):
+                raise ValueError(
+                    f"Pipeline component: {component} is not a torch.nn.Module. Cannot apply dynamic upcasting."
+                )
+
+            use_keep_in_fp32_modules = (
+                hasattr(component_module, "_keep_in_fp32_modules")
+                and (component_module._keep_in_fp32_modules is not None)
+                and (upcast_dtype != torch.float32)
+            )
+            if use_keep_in_fp32_modules:
+                keep_in_fp32_modules = component_module._keep_in_fp32_modules
+            else:
+                keep_in_fp32_modules = []
+
+            original_dtype = component_module.dtype
+            for name, module in component_module.named_children():
+                fn_recursive_upcast(module, upcast_dtype, original_dtype, keep_in_fp32_modules)
+
+    def disable_dynamic_upcasting(self):
+        def fn_recursive_upcast(module):
+            has_children = list(module.children())
+            if not has_children:
+                module._forward_pre_hooks = OrderedDict()
+                module._forward_hooks = OrderedDict()
+
+            for child in module.children():
+                fn_recursive_upcast(child)
+
+        for component in self.components:
+            component_module = getattr(self, component, None)
+            # Skip components that are not torch modules (tokenizers, schedulers,
+            # feature extractors); only nn.Modules carry forward hooks.
+            if not isinstance(component_module, torch.nn.Module):
+                continue
+
+            for module in component_module.children():
+                fn_recursive_upcast(module)
+
     @classmethod
     @validate_hf_hub_args
     def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
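
Similarly, a hypothetical pipeline-level sketch of the API this diff adds. The checkpoint and component names are illustrative, and the exact dtype handling of intermediate activations during inference is elided here.

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    pipe.to("cuda")

    # Keep the heavy components in float8 for storage.
    pipe.unet.to(torch.float8_e4m3fn)
    pipe.vae.to(torch.float8_e4m3fn)

    # Compute in float16: leaf modules of the named components are upcast just
    # before their forward pass and downcast right after; modules listed in a
    # component's `_keep_in_fp32_modules` are pinned to float32 instead.
    pipe.enable_dynamic_upcasting(components=["unet", "vae"], upcast_dtype=torch.float16)

    image = pipe("an astronaut riding a horse").images[0]

    # Clears the casting hooks on every component that is a torch.nn.Module.
    pipe.disable_dynamic_upcasting()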
