Skip to content

Commit 1df32d7

Browse files
committed
wip: add HW requirements calculator
1 parent 23ffa83 commit 1df32d7

File tree

2 files changed

+91
-0
lines changed

2 files changed

+91
-0
lines changed

packages/hub/index.ts

Lines changed: 19 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -1 +1,20 @@
11
export * from "./src";
2+
3+
// TODO: remove this before merging
4+
// Run with: npx ts-node index.ts
5+
import { getHardwareRequirements } from "./src/lib/hardware-requirements";
6+
(async () => {
7+
const models = [
8+
"hexgrad/Kokoro-82M",
9+
"microsoft/OmniParser-v2.0",
10+
"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
11+
"deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
12+
"NousResearch/DeepHermes-3-Llama-3-8B-Preview",
13+
"unsloth/DeepSeek-R1-Distill-Llama-8B-unsloth-bnb-4bit",
14+
];
15+
16+
for (const name of models) {
17+
const mem = await getHardwareRequirements({ name });
18+
console.log('mem', JSON.stringify(mem, null, 2));
19+
}
20+
})();
Lines changed: 72 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -0,0 +1,72 @@
1+
import { ListFileEntry, listFiles } from "./list-files";
2+
3+
export interface MemoryRequirements {
4+
minimumGigabytes: number;
5+
recommendedGigabytes: number;
6+
};
7+
8+
export interface HardwareRequirements {
9+
name: string;
10+
memory: MemoryRequirements;
11+
};
12+
13+
export async function getHardwareRequirements(params: {
14+
/**
15+
* The model name in the format of `namespace/repo`.
16+
*/
17+
name: string;
18+
/**
19+
* The context size in tokens, default to 2048.
20+
*/
21+
contextSize?: number;
22+
}) {
23+
const files = await getFiles(params.name);
24+
const hasSafetensors = files.some((file) => file.path.endsWith(".safetensors"));
25+
const hasPytorch = files.some((file) => file.path.endsWith(".pth"));
26+
27+
// Get the total size of the model weight in bytes (we don't care about quantization scheme)
28+
let totalWeightBytes = 0;
29+
if (hasSafetensors) {
30+
totalWeightBytes = sumFileSize(files.filter((file) => file.path.endsWith(".safetensors")));
31+
} else if (hasPytorch) {
32+
totalWeightBytes = sumFileSize(files.filter((file) => file.path.endsWith(".pth")));
33+
}
34+
35+
// Calculate the memory for context window
36+
// TODO: this also scales in function of weight, to be implemented later
37+
const contextWindow = params.contextSize ?? 2048;
38+
const batchSize = 256; // a bit overhead for batching
39+
const contextMemoryBytes = (contextWindow + batchSize) * 0.5 * 1e6;
40+
41+
// Calculate the memory overhead
42+
const osOverheadBytes = Math.max(512 * 1e6, 0.2 * totalWeightBytes);
43+
44+
// Calculate the total memory requirements
45+
const totalMemoryGb = (totalWeightBytes + contextMemoryBytes + osOverheadBytes) / 1e9;
46+
47+
return {
48+
name: params.name,
49+
memory: {
50+
minimumGigabytes: totalMemoryGb,
51+
recommendedGigabytes: totalMemoryGb * 1.1,
52+
},
53+
} satisfies HardwareRequirements;
54+
}
55+
56+
async function getFiles(name: string): Promise<ListFileEntry[]> {
57+
const files: ListFileEntry[] = [];
58+
const cursor = listFiles({
59+
repo: {
60+
name,
61+
type: "model",
62+
},
63+
});
64+
for await (const entry of cursor) {
65+
files.push(entry);
66+
}
67+
return files;
68+
};
69+
70+
function sumFileSize(files: ListFileEntry[]): number {
71+
return files.reduce((total, file) => total + file.size, 0);
72+
}

0 commit comments

Comments
 (0)