|
22 | 22 | import unittest
|
23 | 23 | import unittest.mock as mock
|
24 | 24 | import uuid
|
25 |
| -from typing import Dict, List, Tuple |
| 25 | +from collections import defaultdict |
| 26 | +from typing import Dict, List, Optional, Tuple, Union |
26 | 27 |
|
27 | 28 | import numpy as np
|
28 | 29 | import requests_mock
|
29 | 30 | import torch
|
30 |
| -from accelerate.utils import compute_module_sizes |
| 31 | +import torch.nn as nn |
| 32 | +from accelerate.utils.modeling import _get_proper_dtype, dtype_byte_size |
31 | 33 | from huggingface_hub import ModelCard, delete_repo, snapshot_download
|
32 | 34 | from huggingface_hub.utils import is_jinja_available
|
33 | 35 | from parameterized import parameterized
|
@@ -113,6 +115,72 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
|
113 | 115 | out_queue.join()
|
114 | 116 |
|
115 | 117 |
|
| 118 | +def named_persistent_module_tensors( |
| 119 | + module: nn.Module, |
| 120 | + recurse: bool = False, |
| 121 | +): |
| 122 | + """ |
| 123 | + A helper function that gathers all the tensors (parameters + persistent buffers) of a given module. |
| 124 | +
|
| 125 | + Args: |
| 126 | + module (`torch.nn.Module`): |
| 127 | + The module we want the tensors on. |
| 128 | + recurse (`bool`, *optional`, defaults to `False`): |
| 129 | + Whether or not to go look in every submodule or just return the direct parameters and buffers. |
| 130 | + """ |
| 131 | + yield from module.named_parameters(recurse=recurse) |
| 132 | + |
| 133 | + for named_buffer in module.named_buffers(recurse=recurse): |
| 134 | + name, _ = named_buffer |
| 135 | + # Get parent by splitting on dots and traversing the model |
| 136 | + parent = module |
| 137 | + if "." in name: |
| 138 | + parent_name = name.rsplit(".", 1)[0] |
| 139 | + for part in parent_name.split("."): |
| 140 | + parent = getattr(parent, part) |
| 141 | + name = name.split(".")[-1] |
| 142 | + if name not in parent._non_persistent_buffers_set: |
| 143 | + yield named_buffer |
| 144 | + |
| 145 | + |
| 146 | +def compute_module_persistent_sizes( |
| 147 | + model: nn.Module, |
| 148 | + dtype: Optional[Union[str, torch.device]] = None, |
| 149 | + special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None, |
| 150 | +): |
| 151 | + """ |
| 152 | + Compute the size of each submodule of a given model (parameters + persistent buffers). |
| 153 | + """ |
| 154 | + if dtype is not None: |
| 155 | + dtype = _get_proper_dtype(dtype) |
| 156 | + dtype_size = dtype_byte_size(dtype) |
| 157 | + if special_dtypes is not None: |
| 158 | + special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()} |
| 159 | + special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()} |
| 160 | + module_sizes = defaultdict(int) |
| 161 | + |
| 162 | + module_list = [] |
| 163 | + |
| 164 | + module_list = named_persistent_module_tensors(model, recurse=True) |
| 165 | + |
| 166 | + for name, tensor in module_list: |
| 167 | + if special_dtypes is not None and name in special_dtypes: |
| 168 | + size = tensor.numel() * special_dtypes_size[name] |
| 169 | + elif dtype is None: |
| 170 | + size = tensor.numel() * dtype_byte_size(tensor.dtype) |
| 171 | + elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): |
| 172 | + # According to the code in set_module_tensor_to_device, these types won't be converted |
| 173 | + # so use their original size here |
| 174 | + size = tensor.numel() * dtype_byte_size(tensor.dtype) |
| 175 | + else: |
| 176 | + size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype)) |
| 177 | + name_parts = name.split(".") |
| 178 | + for idx in range(len(name_parts) + 1): |
| 179 | + module_sizes[".".join(name_parts[:idx])] += size |
| 180 | + |
| 181 | + return module_sizes |
| 182 | + |
| 183 | + |
116 | 184 | class ModelUtilsTest(unittest.TestCase):
|
117 | 185 | def tearDown(self):
|
118 | 186 | super().tearDown()
|
@@ -1012,7 +1080,7 @@ def test_cpu_offload(self):
|
1012 | 1080 | torch.manual_seed(0)
|
1013 | 1081 | base_output = model(**inputs_dict)
|
1014 | 1082 |
|
1015 |
| - model_size = compute_module_sizes(model)[""] |
| 1083 | + model_size = compute_module_persistent_sizes(model)[""] |
1016 | 1084 | # We test several splits of sizes to make sure it works.
|
1017 | 1085 | max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
|
1018 | 1086 | with tempfile.TemporaryDirectory() as tmp_dir:
|
@@ -1042,7 +1110,7 @@ def test_disk_offload_without_safetensors(self):
|
1042 | 1110 | torch.manual_seed(0)
|
1043 | 1111 | base_output = model(**inputs_dict)
|
1044 | 1112 |
|
1045 |
| - model_size = compute_module_sizes(model)[""] |
| 1113 | + model_size = compute_module_persistent_sizes(model)[""] |
1046 | 1114 | with tempfile.TemporaryDirectory() as tmp_dir:
|
1047 | 1115 | model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
|
1048 | 1116 |
|
@@ -1076,7 +1144,7 @@ def test_disk_offload_with_safetensors(self):
|
1076 | 1144 | torch.manual_seed(0)
|
1077 | 1145 | base_output = model(**inputs_dict)
|
1078 | 1146 |
|
1079 |
| - model_size = compute_module_sizes(model)[""] |
| 1147 | + model_size = compute_module_persistent_sizes(model)[""] |
1080 | 1148 | with tempfile.TemporaryDirectory() as tmp_dir:
|
1081 | 1149 | model.cpu().save_pretrained(tmp_dir)
|
1082 | 1150 |
|
@@ -1104,7 +1172,7 @@ def test_model_parallelism(self):
|
1104 | 1172 | torch.manual_seed(0)
|
1105 | 1173 | base_output = model(**inputs_dict)
|
1106 | 1174 |
|
1107 |
| - model_size = compute_module_sizes(model)[""] |
| 1175 | + model_size = compute_module_persistent_sizes(model)[""] |
1108 | 1176 | # We test several splits of sizes to make sure it works.
|
1109 | 1177 | max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
|
1110 | 1178 | with tempfile.TemporaryDirectory() as tmp_dir:
|
@@ -1132,7 +1200,7 @@ def test_sharded_checkpoints(self):
|
1132 | 1200 |
|
1133 | 1201 | base_output = model(**inputs_dict)
|
1134 | 1202 |
|
1135 |
| - model_size = compute_module_sizes(model)[""] |
| 1203 | + model_size = compute_module_persistent_sizes(model)[""] |
1136 | 1204 | max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
|
1137 | 1205 | with tempfile.TemporaryDirectory() as tmp_dir:
|
1138 | 1206 | model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
|
@@ -1164,7 +1232,7 @@ def test_sharded_checkpoints_with_variant(self):
|
1164 | 1232 |
|
1165 | 1233 | base_output = model(**inputs_dict)
|
1166 | 1234 |
|
1167 |
| - model_size = compute_module_sizes(model)[""] |
| 1235 | + model_size = compute_module_persistent_sizes(model)[""] |
1168 | 1236 | max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
|
1169 | 1237 | variant = "fp16"
|
1170 | 1238 | with tempfile.TemporaryDirectory() as tmp_dir:
|
@@ -1204,7 +1272,7 @@ def test_sharded_checkpoints_device_map(self):
|
1204 | 1272 | torch.manual_seed(0)
|
1205 | 1273 | base_output = model(**inputs_dict)
|
1206 | 1274 |
|
1207 |
| - model_size = compute_module_sizes(model)[""] |
| 1275 | + model_size = compute_module_persistent_sizes(model)[""] |
1208 | 1276 | max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
|
1209 | 1277 | with tempfile.TemporaryDirectory() as tmp_dir:
|
1210 | 1278 | model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
|
@@ -1233,7 +1301,7 @@ def test_variant_sharded_ckpt_right_format(self):
|
1233 | 1301 | config, _ = self.prepare_init_args_and_inputs_for_common()
|
1234 | 1302 | model = self.model_class(**config).eval()
|
1235 | 1303 |
|
1236 |
| - model_size = compute_module_sizes(model)[""] |
| 1304 | + model_size = compute_module_persistent_sizes(model)[""] |
1237 | 1305 | max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
|
1238 | 1306 | variant = "fp16"
|
1239 | 1307 | with tempfile.TemporaryDirectory() as tmp_dir:
|
|
0 commit comments