Skip to content

Commit ebd10d3

Browse files
committed
Remove nn.Embedding layer from model size
1 parent 2bd12cb commit ebd10d3

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

generate.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,18 @@ def _load_model(checkpoint_path, device, precision, use_tp):
245245
model = model.to(device=device, dtype=precision)
246246
return model.eval()
247247

248+
def _get_model_size(model):
249+
model_size = 0
250+
for name, child in model.named_children():
251+
if not isinstance(child, torch.nn.Embedding):
252+
model_size += sum(
253+
[
254+
p.numel() * p.dtype.itemsize
255+
for p in itertools.chain(child.parameters(), child.buffers())
256+
]
257+
)
258+
return model_size
259+
248260
# Literal marker strings: B_INST is the opening tag, E_INST the matching closing tag.
B_INST = "[INST]"
E_INST = "[/INST]"
249261

250262
def main(
@@ -301,7 +313,7 @@ def main(
301313
prompt_length = encoded.size(0)
302314

303315
torch.manual_seed(1234)
304-
model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
316+
model_size = _get_model_size(model)
305317
if compile:
306318
if is_speculative and use_tp: # and ("cuda" in device):
307319
torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case

0 commit comments

Comments
 (0)