Skip to content

Commit 0667c18

Browse files
committed
Add download cache hit benchmarks
1 parent 615c2b7 commit 0667c18

File tree

1 file changed

+68
-7
lines changed

1 file changed

+68
-7
lines changed

Libraries/BenchmarkHelpers/BenchmarkHelpers.swift

Lines changed: 68 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Shared benchmark logic for measuring model loading performance.
1+
// Shared benchmark logic for measuring model loading and download performance.
22
// Integration packages inject their own Downloader and TokenizerLoader.
33

44
import Foundation
@@ -8,6 +8,37 @@ import MLXLLM
88
import MLXLMCommon
99
import MLXVLM
1010

11+
// MARK: - No-Op Tokenizer
12+
13+
/// A tokenizer loader that returns a stub tokenizer. Useful for benchmarking
14+
/// model loading in downloader integration packages that don't provide a
15+
/// real tokenizer.
16+
public struct NoOpTokenizerLoader: TokenizerLoader {
17+
public init() {}
18+
19+
public func load(from directory: URL) async throws -> any Tokenizer {
20+
NoOpTokenizer()
21+
}
22+
}
23+
24+
private struct NoOpTokenizer: Tokenizer {
25+
func encode(text: String, addSpecialTokens: Bool) -> [Int] { [] }
26+
func decode(tokenIds: [Int], skipSpecialTokens: Bool) -> String { "" }
27+
func convertTokenToId(_ token: String) -> Int? { nil }
28+
func convertIdToToken(_ id: Int) -> String? { nil }
29+
var bosToken: String? { nil }
30+
var eosToken: String? { nil }
31+
var unknownToken: String? { nil }
32+
33+
func applyChatTemplate(
34+
messages: [[String: any Sendable]],
35+
tools: [[String: any Sendable]]?,
36+
additionalContext: [String: any Sendable]?
37+
) throws -> [Int] {
38+
throw MLXLMCommon.TokenizerError.missingChatTemplate
39+
}
40+
}
41+
1142
// MARK: - Stats
1243

1344
public struct BenchmarkStats: Sendable {
@@ -32,10 +63,10 @@ public struct BenchmarkStats: Sendable {
3263

3364
public func printSummary(label: String) {
3465
print("\(label) results:")
35-
print(" Mean: \(String(format: "%.0f", mean))ms")
36-
print(" Median: \(String(format: "%.0f", median))ms")
66+
print(" Mean: \(String(format: "%.1f", mean))ms")
67+
print(" Median: \(String(format: "%.1f", median))ms")
3768
print(" StdDev: \(String(format: "%.1f", stdDev))ms")
38-
print(" Range: \(String(format: "%.0f", min))-\(String(format: "%.0f", max))ms")
69+
print(" Range: \(String(format: "%.1f", min))-\(String(format: "%.1f", max))ms")
3970
}
4071
}
4172

@@ -63,7 +94,7 @@ public func benchmarkLLMLoading(
6394
) { _ in }
6495
let elapsed = (CFAbsoluteTimeGetCurrent() - start) * 1000
6596
times.append(elapsed)
66-
print("LLM load run \(i): \(String(format: "%.0f", elapsed))ms")
97+
print("LLM load run \(i): \(String(format: "%.1f", elapsed))ms")
6798
Memory.clearCache()
6899
}
69100

@@ -92,7 +123,7 @@ public func benchmarkVLMLoading(
92123
) { _ in }
93124
let elapsed = (CFAbsoluteTimeGetCurrent() - start) * 1000
94125
times.append(elapsed)
95-
print("VLM load run \(i): \(String(format: "%.0f", elapsed))ms")
126+
print("VLM load run \(i): \(String(format: "%.1f", elapsed))ms")
96127
Memory.clearCache()
97128
}
98129

@@ -119,9 +150,39 @@ public func benchmarkEmbeddingLoading(
119150
) { _ in }
120151
let elapsed = (CFAbsoluteTimeGetCurrent() - start) * 1000
121152
times.append(elapsed)
122-
print("Embedding load run \(i): \(String(format: "%.0f", elapsed))ms")
153+
print("Embedding load run \(i): \(String(format: "%.1f", elapsed))ms")
123154
Memory.clearCache()
124155
}
125156

126157
return BenchmarkStats(times: times)
127158
}
159+
160+
// MARK: - Download Benchmarks
161+
162+
/// Benchmark download cache hit performance. Ensures the model is cached with a warm-up
163+
/// download, then measures repeated cache lookups.
164+
public func benchmarkDownloadCacheHit(
165+
from downloader: any Downloader,
166+
modelId: String = "mlx-community/Qwen3-0.6B-4bit",
167+
runs: Int = 7
168+
) async throws -> BenchmarkStats {
169+
let patterns = ["*.safetensors", "*.json", "*.jinja"]
170+
171+
// Warm-up: ensure the model is cached
172+
_ = try await downloader.download(
173+
id: modelId, revision: "main", matching: patterns,
174+
useLatest: false, progressHandler: { _ in })
175+
176+
var times: [Double] = []
177+
for i in 1 ... runs {
178+
let start = CFAbsoluteTimeGetCurrent()
179+
_ = try await downloader.download(
180+
id: modelId, revision: "main", matching: patterns,
181+
useLatest: false, progressHandler: { _ in })
182+
let elapsed = (CFAbsoluteTimeGetCurrent() - start) * 1000
183+
times.append(elapsed)
184+
print("Download cache hit run \(i): \(String(format: "%.1f", elapsed))ms")
185+
}
186+
187+
return BenchmarkStats(times: times)
188+
}

0 commit comments

Comments
 (0)