Skip to content

Commit 13c5954

Browse files
committed
subtract buffer size
1 parent b32bf00 commit 13c5954

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

tests/models/test_modeling_common.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,6 +1013,8 @@ def test_cpu_offload(self):
10131013
base_output = model(**inputs_dict)
10141014

10151015
model_size = compute_module_sizes(model)[""]
1016+
buffer_size = compute_module_sizes(model, buffers_only=True)[""]
1017+
model_size = model_size - buffer_size
10161018
# We test several splits of sizes to make sure it works.
10171019
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
10181020
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1043,6 +1045,8 @@ def test_disk_offload_without_safetensors(self):
10431045
base_output = model(**inputs_dict)
10441046

10451047
model_size = compute_module_sizes(model)[""]
1048+
buffer_size = compute_module_sizes(model, buffers_only=True)[""]
1049+
model_size = model_size - buffer_size
10461050
with tempfile.TemporaryDirectory() as tmp_dir:
10471051
model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
10481052

@@ -1077,6 +1081,8 @@ def test_disk_offload_with_safetensors(self):
10771081
base_output = model(**inputs_dict)
10781082

10791083
model_size = compute_module_sizes(model)[""]
1084+
buffer_size = compute_module_sizes(model, buffers_only=True)[""]
1085+
model_size = model_size - buffer_size
10801086
with tempfile.TemporaryDirectory() as tmp_dir:
10811087
model.cpu().save_pretrained(tmp_dir)
10821088

@@ -1105,6 +1111,8 @@ def test_model_parallelism(self):
11051111
base_output = model(**inputs_dict)
11061112

11071113
model_size = compute_module_sizes(model)[""]
1114+
buffer_size = compute_module_sizes(model, buffers_only=True)[""]
1115+
model_size = model_size - buffer_size
11081116
# We test several splits of sizes to make sure it works.
11091117
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
11101118
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1133,6 +1141,8 @@ def test_sharded_checkpoints(self):
11331141
base_output = model(**inputs_dict)
11341142

11351143
model_size = compute_module_sizes(model)[""]
1144+
buffer_size = compute_module_sizes(model, buffers_only=True)[""]
1145+
model_size = model_size - buffer_size
11361146
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
11371147
with tempfile.TemporaryDirectory() as tmp_dir:
11381148
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
@@ -1165,6 +1175,8 @@ def test_sharded_checkpoints_with_variant(self):
11651175
base_output = model(**inputs_dict)
11661176

11671177
model_size = compute_module_sizes(model)[""]
1178+
buffer_size = compute_module_sizes(model, buffers_only=True)[""]
1179+
model_size = model_size - buffer_size
11681180
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
11691181
variant = "fp16"
11701182
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1205,6 +1217,8 @@ def test_sharded_checkpoints_device_map(self):
12051217
base_output = model(**inputs_dict)
12061218

12071219
model_size = compute_module_sizes(model)[""]
1220+
buffer_size = compute_module_sizes(model, buffers_only=True)[""]
1221+
model_size = model_size - buffer_size
12081222
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
12091223
with tempfile.TemporaryDirectory() as tmp_dir:
12101224
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
@@ -1234,6 +1248,8 @@ def test_variant_sharded_ckpt_right_format(self):
12341248
model = self.model_class(**config).eval()
12351249

12361250
model_size = compute_module_sizes(model)[""]
1251+
buffer_size = compute_module_sizes(model, buffers_only=True)[""]
1252+
model_size = model_size - buffer_size
12371253
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
12381254
variant = "fp16"
12391255
with tempfile.TemporaryDirectory() as tmp_dir:

0 commit comments

Comments
 (0)