Skip to content

Commit 91c0941

Browse files
committed
gguf: add memory calculator to CLI
1 parent 9abb7f5 commit 91c0941

File tree

2 files changed

+150
-22
lines changed

2 files changed

+150
-22
lines changed

packages/gguf/src/cli.ts

Lines changed: 110 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env node
22

3-
import { GGMLQuantizationType, gguf } from ".";
3+
import { GGMLQuantizationType, gguf, GGUFParseOutput } from ".";
4+
import { GGML_QUANT_SIZES } from "./quant-descriptions";
45

56
interface PrintColumnHeader {
67
name: string;
@@ -11,7 +12,21 @@ interface PrintColumnHeader {
1112
const mapDtypeToName = Object.fromEntries(Object.entries(GGMLQuantizationType).map(([name, value]) => [value, name]));
1213

1314
async function main() {
14-
const ggufPath = process.argv[2];
15+
let ggufPath = "";
16+
let showTensors = false;
17+
for (let i = 2; i < process.argv.length; i++) {
18+
if (process.argv[i] === "--show-tensor") {
19+
showTensors = true;
20+
} else {
21+
ggufPath = process.argv[i];
22+
}
23+
}
24+
25+
if (!ggufPath.length) {
26+
console.error("Usage: gguf-view [--show-tensor] <path/to/gguf>");
27+
process.exit(1);
28+
}
29+
1530
const { metadata, tensorInfos } = await gguf(ggufPath, {
1631
allowLocalFile: true,
1732
});
@@ -43,29 +58,102 @@ async function main() {
4358
);
4459

4560
console.log();
46-
console.log(`* Dumping ${tensorInfos.length} tensor(s)`);
61+
console.log(`* Memory usage estimation`);
62+
const kvUsage = calcMemoryUsage(metadata as GGUFParseOutput<{ strict: false }>["metadata"], 4096);
63+
let modelWeightInBytes = 0;
64+
for (const tensorInfo of tensorInfos) {
65+
const nElem = Number(tensorInfo.shape.reduce((a, b) => a * b, 1n));
66+
const tensorSizeInBytes = nElem * (GGML_QUANT_SIZES[tensorInfo.dtype] / 8);
67+
modelWeightInBytes += tensorSizeInBytes;
68+
}
69+
const overhead =
70+
calcMemoryUsage(metadata as GGUFParseOutput<{ strict: false }>["metadata"], 256).totalBytes +
71+
modelWeightInBytes * 0.05;
72+
const totalMemoryUsage = kvUsage.totalBytes + overhead + modelWeightInBytes;
4773
printTable(
74+
[{ name: "Item" }, { name: "Memory usage", alignRight: true }],
4875
[
49-
{ name: "Idx", alignRight: true },
50-
{ name: "Num Elements", alignRight: true },
51-
{ name: "Shape" },
52-
{ name: "Data Type" },
53-
{ name: "Name" },
54-
],
55-
tensorInfos.map((tensorInfo, i) => {
56-
const shape = [1n, 1n, 1n, 1n];
57-
tensorInfo.shape.forEach((dim, i) => {
58-
shape[i] = dim;
59-
});
60-
return [
61-
(i + 1).toString(),
62-
shape.reduce((acc, n) => acc * n, 1n).toString(),
63-
shape.map((n) => n.toString().padStart(6)).join(", "),
64-
mapDtypeToName[tensorInfo.dtype],
65-
tensorInfo.name,
66-
];
67-
})
76+
["K cache", (kvUsage.totalBytesK / 1e9).toFixed(2) + " GB"],
77+
["V cache", (kvUsage.totalBytesV / 1e9).toFixed(2) + " GB"],
78+
["Weight", (modelWeightInBytes / 1e9).toFixed(2) + " GB"],
79+
["Overhead", (overhead / 1e9).toFixed(2) + " GB"],
80+
["", "---"],
81+
["TOTAL", (totalMemoryUsage / 1e9).toFixed(2) + " GB"],
82+
]
6883
);
84+
85+
if (showTensors) {
86+
console.log();
87+
console.log(`* Dumping ${tensorInfos.length} tensor(s)`);
88+
printTable(
89+
[
90+
{ name: "Idx", alignRight: true },
91+
{ name: "Num Elements", alignRight: true },
92+
{ name: "Shape" },
93+
{ name: "Data Type" },
94+
{ name: "Name" },
95+
],
96+
tensorInfos.map((tensorInfo, i) => {
97+
const shape = [1n, 1n, 1n, 1n];
98+
tensorInfo.shape.forEach((dim, i) => {
99+
shape[i] = dim;
100+
});
101+
return [
102+
(i + 1).toString(),
103+
shape.reduce((acc, n) => acc * n, 1n).toString(),
104+
shape.map((n) => n.toString().padStart(6)).join(", "),
105+
mapDtypeToName[tensorInfo.dtype],
106+
tensorInfo.name,
107+
];
108+
})
109+
);
110+
} else {
111+
console.log();
112+
console.log(`* Use --show-tensor to display tensor information`);
113+
}
114+
}
115+
116+
function calcMemoryUsage(
117+
metadata: GGUFParseOutput<{ strict: false }>["metadata"],
118+
kvSize: number,
119+
kvTypeK: GGMLQuantizationType = GGMLQuantizationType.F16,
120+
kvTypeV: GGMLQuantizationType = GGMLQuantizationType.F16
121+
) {
122+
const arch = metadata["general.architecture"] ?? "unknown";
123+
const n_embd = (metadata[`${arch}.embedding_length`] as number) ?? 0;
124+
const n_head = (metadata[`${arch}.attention.head_count`] as number) ?? 0;
125+
const n_embd_head_k = (metadata[`${arch}.attention.key_length`] as number) ?? n_embd / n_head;
126+
const n_embd_head_v = (metadata[`${arch}.attention.value_length`] as number) ?? n_embd / n_head;
127+
const n_head_kv = (metadata[`${arch}.attention.head_count_kv`] as number[] | number) ?? [];
128+
const n_layer = (metadata[`${arch}.block_count`] as number) ?? 0;
129+
130+
const n_head_kv_arr = Array(n_layer).fill(n_head);
131+
if (Array.isArray(n_head_kv)) {
132+
for (let i = 0; i < n_layer; i++) {
133+
if (n_head_kv[i]) {
134+
n_head_kv_arr[i] = n_head_kv[i];
135+
}
136+
}
137+
} else {
138+
for (let i = 0; i < n_layer; i++) {
139+
n_head_kv_arr[i] = n_head_kv;
140+
}
141+
}
142+
143+
let totalElemsK = 0;
144+
let totalElemsV = 0;
145+
for (let i = 0; i < n_layer; i++) {
146+
const n_embd_k_gqa = n_embd_head_k * n_head_kv_arr[i];
147+
const n_embd_v_gqa = n_embd_head_v * n_head_kv_arr[i];
148+
totalElemsK += n_embd_k_gqa * kvSize;
149+
totalElemsV += n_embd_v_gqa * kvSize;
150+
}
151+
152+
return {
153+
totalBytesK: totalElemsK * (GGML_QUANT_SIZES[kvTypeK] / 8),
154+
totalBytesV: totalElemsV * (GGML_QUANT_SIZES[kvTypeV] / 8),
155+
totalBytes: (totalElemsK + totalElemsV) * (GGML_QUANT_SIZES[kvTypeV] / 8),
156+
};
69157
}
70158

71159
function printTable(header: PrintColumnHeader[], rows: string[][], leftPad = 2) {

packages/gguf/src/quant-descriptions.ts

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,43 @@ export const GGUF_QUANT_DESCRIPTIONS: Record<GGMLQuantizationType, { txt: string
125125
src_url: "https://en.wikipedia.org/wiki/Bfloat16_floating-point_format",
126126
},
127127
};
128+
129+
const QK_K = 256;
130+
const calcBPW = (blockSize: number, typeSize: number) => {
131+
return (typeSize * 8) / blockSize;
132+
};
133+
134+
// map quantization type to element size in bits per weight (example: Q4_K -> 4.5 bpw)
135+
export const GGML_QUANT_SIZES = {
136+
[GGMLQuantizationType.F32]: calcBPW(1, 4),
137+
[GGMLQuantizationType.F16]: calcBPW(1, 2),
138+
[GGMLQuantizationType.Q4_0]: calcBPW(32, 2 + 16),
139+
[GGMLQuantizationType.Q4_1]: calcBPW(32, 2 + 2 + 16),
140+
[GGMLQuantizationType.Q5_0]: calcBPW(32, 2 + 4 + 16),
141+
[GGMLQuantizationType.Q5_1]: calcBPW(32, 2 + 2 + 4 + 16),
142+
[GGMLQuantizationType.Q8_0]: calcBPW(32, 2 + 32),
143+
[GGMLQuantizationType.Q8_1]: calcBPW(32, 4 + 4 + 32),
144+
[GGMLQuantizationType.Q2_K]: calcBPW(256, 2 + 2 + QK_K / 16 + QK_K / 4),
145+
[GGMLQuantizationType.Q3_K]: calcBPW(256, 2 + QK_K / 4 + QK_K / 8 + 12),
146+
[GGMLQuantizationType.Q4_K]: calcBPW(256, 2 + 2 + QK_K / 2 + 12),
147+
[GGMLQuantizationType.Q5_K]: calcBPW(256, 2 + 2 + QK_K / 2 + QK_K / 8 + 12),
148+
[GGMLQuantizationType.Q6_K]: calcBPW(256, 2 + QK_K / 2 + QK_K / 4 + QK_K / 16),
149+
[GGMLQuantizationType.Q8_K]: calcBPW(256, 4 + QK_K + QK_K / 8),
150+
[GGMLQuantizationType.IQ2_XXS]: calcBPW(256, 2 + QK_K / 4),
151+
[GGMLQuantizationType.IQ2_XS]: calcBPW(256, 2 + QK_K / 4 + QK_K / 32),
152+
[GGMLQuantizationType.IQ3_XXS]: calcBPW(256, 2 + QK_K / 4 + QK_K / 8),
153+
[GGMLQuantizationType.IQ1_S]: calcBPW(256, 2 + QK_K / 8 + QK_K / 16),
154+
[GGMLQuantizationType.IQ4_NL]: calcBPW(32, 2 + 16),
155+
[GGMLQuantizationType.IQ3_S]: calcBPW(256, 2 + QK_K / 4 + QK_K / 8 + QK_K / 32 + 4),
156+
[GGMLQuantizationType.IQ2_S]: calcBPW(256, 2 + QK_K / 4 + QK_K / 16),
157+
[GGMLQuantizationType.IQ4_XS]: calcBPW(256, 2 + 2 + QK_K / 2 + QK_K / 64),
158+
[GGMLQuantizationType.I8]: calcBPW(1, 1),
159+
[GGMLQuantizationType.I16]: calcBPW(1, 2),
160+
[GGMLQuantizationType.I32]: calcBPW(1, 4),
161+
[GGMLQuantizationType.I64]: calcBPW(1, 8),
162+
[GGMLQuantizationType.F64]: calcBPW(1, 8),
163+
[GGMLQuantizationType.IQ1_M]: calcBPW(256, QK_K / 8 + QK_K / 16 + QK_K / 32),
164+
[GGMLQuantizationType.BF16]: calcBPW(1, 2),
165+
// [GGMLQuantizationType.TQ1_0]: calcBPW(256, 2 + 4 * 13),
166+
// [GGMLQuantizationType.TQ2_0]: calcBPW(256, 2 + 64),
167+
};

0 commit comments

Comments (0)