
Commit 574fe74

add compute_module_persistent_sizes

1 parent 13c5954 commit 574fe74


tests/models/test_modeling_common.py

Lines changed: 78 additions & 26 deletions
@@ -22,12 +22,14 @@
 import unittest
 import unittest.mock as mock
 import uuid
-from typing import Dict, List, Tuple
+from collections import defaultdict
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
 import requests_mock
 import torch
-from accelerate.utils import compute_module_sizes
+import torch.nn as nn
+from accelerate.utils.modeling import _get_proper_dtype, dtype_byte_size
 from huggingface_hub import ModelCard, delete_repo, snapshot_download
 from huggingface_hub.utils import is_jinja_available
 from parameterized import parameterized
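
The new imports pull two low-level helpers from accelerate instead of the high-level `compute_module_sizes`. A rough illustration of their assumed behavior (values are not taken from this commit):

# _get_proper_dtype normalizes a str/dtype input to a torch.dtype, and
# dtype_byte_size returns the per-element size in bytes (assumed behavior):
#   _get_proper_dtype("float16")    -> torch.float16
#   dtype_byte_size(torch.float32)  -> 4
#   dtype_byte_size(torch.float16)  -> 2
#   dtype_byte_size(torch.bool)     -> 0.125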
@@ -113,6 +115,72 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
     out_queue.join()


+def named_persistent_module_tensors(
+    module: nn.Module,
+    recurse: bool = False,
+):
+    """
+    A helper function that gathers all the tensors (parameters + persistent buffers) of a given module.
+
+    Args:
+        module (`torch.nn.Module`):
+            The module we want the tensors on.
+        recurse (`bool`, *optional*, defaults to `False`):
+            Whether or not to go look in every submodule or just return the direct parameters and buffers.
+    """
+    yield from module.named_parameters(recurse=recurse)
+
+    for named_buffer in module.named_buffers(recurse=recurse):
+        name, _ = named_buffer
+        # Get parent by splitting on dots and traversing the model
+        parent = module
+        if "." in name:
+            parent_name = name.rsplit(".", 1)[0]
+            for part in parent_name.split("."):
+                parent = getattr(parent, part)
+            name = name.split(".")[-1]
+        if name not in parent._non_persistent_buffers_set:
+            yield named_buffer
+
+
+def compute_module_persistent_sizes(
+    model: nn.Module,
+    dtype: Optional[Union[str, torch.device]] = None,
+    special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
+):
+    """
+    Compute the size of each submodule of a given model (parameters + persistent buffers).
+    """
+    if dtype is not None:
+        dtype = _get_proper_dtype(dtype)
+        dtype_size = dtype_byte_size(dtype)
+    if special_dtypes is not None:
+        special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
+        special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
+    module_sizes = defaultdict(int)
+
+    module_list = []
+
+    module_list = named_persistent_module_tensors(model, recurse=True)
+
+    for name, tensor in module_list:
+        if special_dtypes is not None and name in special_dtypes:
+            size = tensor.numel() * special_dtypes_size[name]
+        elif dtype is None:
+            size = tensor.numel() * dtype_byte_size(tensor.dtype)
+        elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
+            # According to the code in set_module_tensor_to_device, these types won't be converted
+            # so use their original size here
+            size = tensor.numel() * dtype_byte_size(tensor.dtype)
+        else:
+            size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
+        name_parts = name.split(".")
+        for idx in range(len(name_parts) + 1):
+            module_sizes[".".join(name_parts[:idx])] += size
+
+    return module_sizes
+
+
 class ModelUtilsTest(unittest.TestCase):
     def tearDown(self):
         super().tearDown()
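
For context, a minimal sketch (not part of the commit) of how the two new helpers are expected to behave. `ToyModule` and its buffer names are hypothetical, and `named_persistent_module_tensors` / `compute_module_persistent_sizes` are assumed to be in scope as defined above:

import torch
import torch.nn as nn


class ToyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2, bias=False)  # 4 float32 parameters -> 16 bytes
        self.register_buffer("running_mean", torch.zeros(2))  # persistent buffer -> 8 bytes
        self.register_buffer("scratch", torch.zeros(2), persistent=False)  # non-persistent, skipped


toy = ToyModule()

# Parameters and persistent buffers are yielded; the non-persistent "scratch" buffer is not.
names = [name for name, _ in named_persistent_module_tensors(toy, recurse=True)]
assert names == ["linear.weight", "running_mean"]

# Sizes are accumulated per submodule prefix; the root key "" holds the total in bytes.
sizes = compute_module_persistent_sizes(toy)
assert sizes[""] == 16 + 8
assert sizes["linear"] == 16

The non-persistent buffer is excluded from both the iterator and the size totals, which is what the offloading and sharding tests below rely on.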
@@ -1012,9 +1080,7 @@ def test_cpu_offload(self):
         torch.manual_seed(0)
         base_output = model(**inputs_dict)

-        model_size = compute_module_sizes(model)[""]
-        buffer_size = compute_module_sizes(model, buffers_only=True)[""]
-        model_size = model_size - buffer_size
+        model_size = compute_module_persistent_sizes(model)[""]
         # We test several splits of sizes to make sure it works.
         max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
         with tempfile.TemporaryDirectory() as tmp_dir:
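
Each of the following test hunks makes the same substitution. The old two-call expression yields a parameters-only size (total minus all buffers), whereas the new helper counts parameters plus persistent buffers, so `model_size` can grow slightly for models that register persistent buffers. A hedged before/after sketch, assuming `buffers_only=True` in accelerate's `compute_module_sizes` counts only buffers:

# Before (relies on accelerate's compute_module_sizes):
#   total_size  = compute_module_sizes(model)[""]                     # parameters + all buffers
#   buffer_size = compute_module_sizes(model, buffers_only=True)[""]  # all buffers
#   model_size  = total_size - buffer_size                            # parameters only
#
# After (defined in this test file):
#   model_size = compute_module_persistent_sizes(model)[""]           # parameters + persistent buffers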
@@ -1044,9 +1110,7 @@ def test_disk_offload_without_safetensors(self):
         torch.manual_seed(0)
         base_output = model(**inputs_dict)

-        model_size = compute_module_sizes(model)[""]
-        buffer_size = compute_module_sizes(model, buffers_only=True)[""]
-        model_size = model_size - buffer_size
+        model_size = compute_module_persistent_sizes(model)[""]
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir, safe_serialization=False)

@@ -1080,9 +1144,7 @@ def test_disk_offload_with_safetensors(self):
         torch.manual_seed(0)
         base_output = model(**inputs_dict)

-        model_size = compute_module_sizes(model)[""]
-        buffer_size = compute_module_sizes(model, buffers_only=True)[""]
-        model_size = model_size - buffer_size
+        model_size = compute_module_persistent_sizes(model)[""]
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir)

@@ -1110,9 +1172,7 @@ def test_model_parallelism(self):
         torch.manual_seed(0)
         base_output = model(**inputs_dict)

-        model_size = compute_module_sizes(model)[""]
-        buffer_size = compute_module_sizes(model, buffers_only=True)[""]
-        model_size = model_size - buffer_size
+        model_size = compute_module_persistent_sizes(model)[""]
         # We test several splits of sizes to make sure it works.
         max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1140,9 +1200,7 @@ def test_sharded_checkpoints(self):

         base_output = model(**inputs_dict)

-        model_size = compute_module_sizes(model)[""]
-        buffer_size = compute_module_sizes(model, buffers_only=True)[""]
-        model_size = model_size - buffer_size
+        model_size = compute_module_persistent_sizes(model)[""]
         max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
@@ -1174,9 +1232,7 @@ def test_sharded_checkpoints_with_variant(self):

         base_output = model(**inputs_dict)

-        model_size = compute_module_sizes(model)[""]
-        buffer_size = compute_module_sizes(model, buffers_only=True)[""]
-        model_size = model_size - buffer_size
+        model_size = compute_module_persistent_sizes(model)[""]
         max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
         variant = "fp16"
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1216,9 +1272,7 @@ def test_sharded_checkpoints_device_map(self):
         torch.manual_seed(0)
         base_output = model(**inputs_dict)

-        model_size = compute_module_sizes(model)[""]
-        buffer_size = compute_module_sizes(model, buffers_only=True)[""]
-        model_size = model_size - buffer_size
+        model_size = compute_module_persistent_sizes(model)[""]
         max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
@@ -1247,9 +1301,7 @@ def test_variant_sharded_ckpt_right_format(self):
         config, _ = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()

-        model_size = compute_module_sizes(model)[""]
-        buffer_size = compute_module_sizes(model, buffers_only=True)[""]
-        model_size = model_size - buffer_size
+        model_size = compute_module_persistent_sizes(model)[""]
         max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
         variant = "fp16"
         with tempfile.TemporaryDirectory() as tmp_dir:
