-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathmodel_tensor_sizes.py
More file actions
executable file
·181 lines (161 loc) · 7.45 KB
/
model_tensor_sizes.py
File metadata and controls
executable file
·181 lines (161 loc) · 7.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
#!/usr/bin/env python3
#***************************************************************#
#** This script is part of Thireus' GGUF Tool Suite. **#
#** model_tensor_sizes.py is a tool that helps identify which **#
#** tensors are the heaviest, thus to be benchmarked. **#
#** **#
#** ********************************************************* **#
#** --------------- Updated: Feb-11-2026 -------------------- **#
#** ********************************************************* **#
#** **#
#** Author: Thireus <gguf@thireus.com> **#
#** **#
#** https://gguf.thireus.com/ **#
#** Thireus' GGUF Tool Suite - Quantize LLMs Like a Chef **#
#** · · ·~° **#
#** Λ,,Λ ₚₚₗ ·° ᵍᵍᵐˡ · ɪᴋ_ʟʟᴀᴍᴀ.ᴄᴘᴘ° ᴮᶠ¹⁶ · **#
#** (:·ω·) 。··° · ɢɢᴜғ ·°· ₕᵤ𝓰𝓰ᵢₙ𝓰𝒻ₐ𝒸ₑ ·° **#
#** / o―ヽニニフ)) · · ɪǫ3_xxs ~·° **#
#** し―-J **#
#** **#
#** Copyright © 2025 - Thireus. ₖₗD ₐₗₗ ₜₕₑ 𝓌ₐᵧ! **#
#***************************************************************#
#**PLEASE REFER TO THE README FILE FOR ADDITIONAL INFORMATION!**#
#***************************************************************#
"""
model_tensor_sizes.py — compute total tensor sizes per regex from a recipe and map file.
Usage:
./model_tensor_sizes.py [--bytes] [--sort] RECIPE_FILE MAP_FILE
Options:
--bytes Show raw byte counts instead of human-readable units.
--sort Output only regex lines sorted by total size (heaviest -> lightest).
-h, --help Show this help and exit.
Example:
chmod +x model_tensor_sizes.py
./model_tensor_sizes.py my.recipe my.map > my.recipe.sized
./model_tensor_sizes.py --bytes --sort my.recipe my.map > my.recipe.sorted.sized
"""
from __future__ import annotations
import sys
import re
import argparse
from pathlib import Path
from typing import Dict, List, Tuple
def parse_args():
    """Build the command-line parser and return the parsed namespace.

    Boolean flags:
      --bytes  print raw byte counts instead of human-readable units
      --sort   print only regex lines, heaviest first
    Positional arguments: recipe_file, map_file (paths, as strings).
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description="Compute and prepend total tensor sizes per regex from a recipe and map file.",
    )
    flag_specs = (
        ("--bytes", "Show raw byte counts instead of human-readable units"),
        ("--sort", "Output only regex lines sorted by total size (heaviest -> lightest)"),
    )
    for flag, help_text in flag_specs:
        parser.add_argument(flag, action="store_true", help=help_text)
    parser.add_argument("recipe_file", help=".recipe file path")
    parser.add_argument("map_file", help=".map file path")
    return parser.parse_args()
def human_readable(nbytes: int) -> str:
    """Format a byte count using binary (1024-based) units B, KB, MB, GB, TB.

    Values below 1024 are printed as a bare integer with no space
    (``"512B"``). Larger values use two decimals and a space (``"1.50 KB"``).
    TB is the largest unit and absorbs everything above it (``"1024.00 TB"``).

    The unit is chosen on the value *as it will be displayed* (rounded to two
    decimals). A size that would print as ``1024.00`` in one unit is promoted
    to ``1.00`` of the next unit instead.
    """
    if nbytes < 1024:
        return f"{nbytes}B"
    size = float(nbytes)
    for unit in ("KB", "MB", "GB", "TB"):
        size /= 1024.0
        # Compare the rounded value so e.g. 1023.999 KB becomes "1.00 MB"
        # rather than "1024.00 KB". TB is the last unit, so it always stops here.
        if unit == "TB" or round(size, 2) < 1024.0:
            break
    return f"{size:.2f} {unit}"
def parse_map_file(map_path: Path) -> Dict[str, int]:
    """Read *map_path* and return a mapping of tensor name -> size in bytes.

    The file is decoded as UTF-8 with undecodable bytes dropped. These lines
    are skipped:
      * blank lines and lines starting with ``#`` (after stripping);
      * lines with fewer than three ``:``-separated fields;
      * lines without a ``bytes=<digits>`` token.
    The tensor name is the third ``:`` field, with surrounding whitespace
    removed. If a name appears more than once, its last entry wins.
    """
    result: Dict[str, int] = {}
    content = map_path.read_text(encoding="utf-8", errors="ignore")
    for raw in content.splitlines():
        trimmed = raw.strip()
        if trimmed == "" or trimmed.startswith("#"):
            continue
        fields = raw.split(":")
        if len(fields) < 3:
            continue
        match = re.search(r"bytes=(\d+)", raw)
        if match is None:
            continue
        try:
            value = int(match.group(1))
        except ValueError:
            # malformed number: ignore this entry
            continue
        result[fields[2].strip()] = value
    return result
def compile_regex(pattern: str):
    """Return ``re.compile(pattern)``.

    If the pattern is invalid, write a warning to stderr and return None
    instead of raising.
    """
    try:
        compiled = re.compile(pattern)
    except re.error as exc:
        sys.stderr.write(f"Warning: invalid regex pattern: {pattern!r} -> {exc}\n")
        compiled = None
    return compiled
def total_bytes_for_pattern(pattern: str, tensors: Dict[str, int]) -> int:
    """Return the summed byte size of every tensor whose name matches *pattern*.

    Matching uses ``re.search`` semantics, i.e. the pattern may match anywhere
    in the name; ``^``/``$`` anchors inside the pattern restrict it further.
    An invalid pattern (``compile_regex`` returns None) contributes 0.
    """
    compiled = compile_regex(pattern)
    if compiled is None:
        return 0
    return sum(nbytes for tname, nbytes in tensors.items() if compiled.search(tname))
def extract_regex_from_line(line: str) -> str:
    """Return the regex portion of a recipe line.

    The regex is the first whitespace-delimited token of the text before the
    first ``=`` (the whole line if it has no ``=``). Any whitespace (spaces,
    tabs, ...) ends the token. Returns ``""`` when there is no such token,
    e.g. for a blank line or a line that starts with ``=``.

    NOTE: a regex that itself contains ``=`` (such as a ``(?=...)``
    lookahead) is truncated at that ``=``.
    """
    left = line.partition("=")[0]
    # split() with no separator splits on any whitespace run and ignores
    # leading whitespace, so a tab terminates the pattern just like a space.
    tokens = left.split(None, 1)
    return tokens[0] if tokens else ""
def read_recipe_lines(recipe_path: Path) -> List[str]:
    """Return the lines of *recipe_path* without line terminators.

    The file is decoded as UTF-8, and undecodable bytes are silently dropped.
    """
    with recipe_path.open(encoding="utf-8", errors="ignore") as handle:
        content = handle.read()
    return content.splitlines()
def main():
    """Annotate recipe lines with the total size of the map tensors they match.

    A recipe line counts as a "regex line" when it meets all of these:
      * it is not blank;
      * it is not a ``#`` comment;
      * ``extract_regex_from_line`` returns a non-empty pattern for it.
    Its size is the sum of the bytes of every map tensor the pattern matches.

    Output goes to stdout. Sizes are right-aligned in a 10-character column:
      * default: every recipe line in file order; regex lines are prefixed
        with their size, all other lines are echoed unchanged;
      * ``--sort``: regex lines only, heaviest first (ties keep file order).
    Sizes are printed as ``"<n> B"`` with ``--bytes``, otherwise via
    ``human_readable``.

    Calls ``sys.exit`` with an error message if either path does not exist or
    is not a regular file.
    """
    args = parse_args()
    recipe_path = Path(args.recipe_file)
    map_path = Path(args.map_file)
    for label, path in (("recipe", recipe_path), ("map", map_path)):
        if not path.exists():
            sys.exit(f"Error: {label} file '{path}' not found.")
        # A directory (or other non-file) would otherwise crash in read_text().
        if not path.is_file():
            sys.exit(f"Error: {label} file '{path}' is not a regular file.")

    tensors = parse_map_file(map_path)
    if not tensors:
        sys.stderr.write("Warning: no tensors parsed from map file (or map file missing bytes= entries).\n")

    def render(line: str, total: int) -> str:
        """Prefix *line* with its size, right-aligned in a 10-character column."""
        size_str = f"{total} B" if args.bytes else human_readable(total)
        return f"{size_str:>10} {line}"

    # Totals are cached per pattern, so a pattern repeated in the recipe is
    # matched against the map (and warned about, if invalid) only once.
    totals: Dict[str, int] = {}
    # One (original_line, total) pair per recipe line, in file order.
    # total is None for lines that are echoed unchanged (blank, comment, no pattern).
    rows = []
    for line in read_recipe_lines(recipe_path):
        stripped = line.strip()
        is_passthrough = not stripped or stripped.startswith("#")
        pattern = "" if is_passthrough else extract_regex_from_line(line)
        if not pattern:
            rows.append((line, None))
            continue
        if pattern not in totals:
            totals[pattern] = total_bytes_for_pattern(pattern, tensors)
        rows.append((line, totals[pattern]))

    if args.sort:
        regex_rows: List[Tuple[str, int]] = [(line, total) for line, total in rows if total is not None]
        # list.sort is stable even with reverse=True, so equal sizes keep file order.
        regex_rows.sort(key=lambda row: row[1], reverse=True)
        for line, total in regex_rows:
            print(render(line, total))
        return

    for line, total in rows:
        print(line if total is None else render(line, total))


if __name__ == "__main__":
    main()