import os
import re
import sys
+from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin
@@ -1172,6 +1173,93 @@ def reset_device_map(self):
                    component.to("cpu")
            self.hf_device_map = None

+    def enable_dynamic_upcasting(
+        self,
+        components: Optional[List[str]] = None,
+        upcast_dtype: Optional[torch.dtype] = None,
+    ):
+        r"""
+        Enable module-wise dynamic upcasting. This allows models to be loaded onto the GPU in a low-memory dtype,
+        e.g. `torch.float8_e4m3fn`, while running inference in a dtype that the GPU supports: each leaf module is
+        cast to the appropriate dtype right before its forward pass, then moved back to the low-memory dtype
+        afterwards.
+
+        Args:
+            components (`List[str]`):
+                Names of the pipeline components to apply dynamic upcasting to.
+            upcast_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+                The dtype to cast modules to for the forward pass.
+        """
+        if components is None:
+            raise ValueError("Please provide a list of pipeline component names to apply dynamic upcasting to.")
+
+        if upcast_dtype is None:
+            upcast_dtype = torch.float32
+
+        def fn_recursive_upcast(module, dtype, original_dtype, keep_in_fp32_modules):
+            def upcast_hook_fn(module, inputs):
+                # Cast the module to the compute dtype right before its forward pass.
+                module.to(dtype)
+
+            def downcast_hook_fn(module, inputs, output):
+                # Cast the module back to the low-memory dtype after the forward pass.
+                module.to(original_dtype)
+
+            if not list(module.children()):
+                # Register hooks on leaf modules only, so each parameter is cast exactly once per forward.
+                module.register_forward_pre_hook(upcast_hook_fn)
+                module.register_forward_hook(downcast_hook_fn)
+
+            for name, child in module.named_children():
+                # Modules the model flags as fp32-only (e.g. certain norm layers) are upcast to float32 instead.
+                if any(fp32_module in name.split(".") for fp32_module in keep_in_fp32_modules):
+                    child_dtype = torch.float32
+                else:
+                    child_dtype = dtype
+
+                fn_recursive_upcast(child, child_dtype, original_dtype, keep_in_fp32_modules)
| 1213 | + |
| 1214 | + for component in components: |
| 1215 | + if not hasattr(self, component): |
| 1216 | + raise ValueError(f"Pipeline has no component named: {component}") |
| 1217 | + |
| 1218 | + component_module = getattr(self, component) |
| 1219 | + if not isinstance(component_module, torch.nn.Module): |
| 1220 | + raise ValueError( |
| 1221 | + f"Pipeline component: {component} is not a torch.nn.Module. Cannot apply dynamic upcasting." |
| 1222 | + ) |
| 1223 | + |
| 1224 | + use_keep_in_fp32_modules = ( |
| 1225 | + hasattr(component_module, "_keep_in_fp32_modules") |
| 1226 | + and (component_module._keep_in_fp32_modules is not None) |
| 1227 | + and (upcast_dtype != torch.float32) |
| 1228 | + ) |
| 1229 | + if use_keep_in_fp32_modules: |
| 1230 | + keep_in_fp32_modules = component_module._keep_in_fp32_modules |
| 1231 | + else: |
| 1232 | + keep_in_fp32_modules = [] |
| 1233 | + |
| 1234 | + original_dtype = component_module.dtype |
| 1235 | + for name, module in component_module.named_children(): |
| 1236 | + fn_recursive_upcast(module, upcast_dtype, original_dtype, keep_in_fp32_modules) |
| 1237 | + |
+
+    def disable_dynamic_upcasting(self):
+        r"""
+        Disable dynamic upcasting by removing the registered forward hooks from all leaf modules.
+        """
+
+        def fn_recursive_remove_hooks(module):
+            if not list(module.children()):
+                # Clear the hooks registered by `enable_dynamic_upcasting`. Note this removes all
+                # forward hooks on the leaf module, not just the upcasting ones.
+                module._forward_pre_hooks = OrderedDict()
+                module._forward_hooks = OrderedDict()
+
+            for child in module.children():
+                fn_recursive_remove_hooks(child)
+
+        for component in self.components:
+            component_module = getattr(self, component, None)
+            # Non-module components (tokenizers, schedulers, ...) never had hooks attached; skip them.
+            if not isinstance(component_module, torch.nn.Module):
+                continue
+
+            for module in component_module.children():
+                fn_recursive_remove_hooks(module)
+
    @classmethod
    @validate_hf_hub_args
    def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
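
For context, here is a rough usage sketch of the API this patch adds; the model id and the choice of `unet` as the component are illustrative assumptions, not taken from the patch. Note the ordering: the component's dtype at enable time becomes the downcast target, so it should already be in the low-memory storage dtype before `enable_dynamic_upcasting` is called.

```python
import torch
from diffusers import DiffusionPipeline

# Hypothetical checkpoint id and component name, for illustration only.
pipe = DiffusionPipeline.from_pretrained("org/some-pipeline")
pipe.unet.to(torch.float8_e4m3fn)  # store weights in the low-memory dtype
pipe.to("cuda")

# Each leaf module of the UNet is upcast to bfloat16 just before its forward
# pass, then cast back to float8_e4m3fn (its dtype when upcasting was enabled).
pipe.enable_dynamic_upcasting(components=["unet"], upcast_dtype=torch.bfloat16)
image = pipe("an astronaut riding a horse").images[0]

# Remove the hooks, leaving the UNet in the storage dtype.
pipe.disable_dynamic_upcasting()
```

Registering the hooks on leaf modules rather than on the component root means only the submodules currently executing hold their parameters in the larger compute dtype, which is what keeps peak memory low.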
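
And a minimal self-contained sketch of the underlying hook mechanism, assuming a PyTorch build with `torch.float8_e4m3fn` (2.1+). The hooks must return `None` (as the patch's do), since a non-`None` return value from a forward pre-hook or forward hook would replace the module's inputs or output:

```python
import torch

# A single leaf module stored in float8_e4m3fn between calls.
linear = torch.nn.Linear(4, 4).to(torch.float8_e4m3fn)

def upcast_hook_fn(module, inputs):
    # Upcast to the compute dtype right before forward; implicit None return.
    module.to(torch.bfloat16)

def downcast_hook_fn(module, inputs, output):
    # Downcast back to the storage dtype right after forward.
    module.to(torch.float8_e4m3fn)

linear.register_forward_pre_hook(upcast_hook_fn)
linear.register_forward_hook(downcast_hook_fn)

x = torch.randn(1, 4, dtype=torch.bfloat16)
y = linear(x)                  # the matmul runs in bfloat16
print(linear.weight.dtype)     # torch.float8_e4m3fn again after the call
```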