// Shared benchmark logic for measuring model loading and download performance.
22// Integration packages inject their own Downloader and TokenizerLoader.
33
44import Foundation
@@ -8,6 +8,37 @@ import MLXLLM
88import MLXLMCommon
99import MLXVLM
1010
11+ // MARK: - No-Op Tokenizer
12+
/// A `TokenizerLoader` that always hands back a stub tokenizer.
///
/// Handy when benchmarking model loading from a downloader integration package
/// that does not supply a real tokenizer.
public struct NoOpTokenizerLoader: TokenizerLoader {
    /// Creates a loader. There is nothing to configure.
    public init() {}

    /// Produces a stub tokenizer without reading anything from `directory`.
    ///
    /// - Parameter directory: Not used by this implementation.
    /// - Returns: A freshly created `NoOpTokenizer`.
    public func load(from directory: URL) async throws -> any Tokenizer {
        let stub: any Tokenizer = NoOpTokenizer()
        return stub
    }
}
23+
/// Stub `Tokenizer` that performs no real tokenization.
///
/// Encoding yields no tokens, decoding yields an empty string, every lookup and
/// special token is `nil`, and applying a chat template always throws.
private struct NoOpTokenizer: Tokenizer {
    // No special tokens are defined for the stub.
    var bosToken: String? { return nil }
    var eosToken: String? { return nil }
    var unknownToken: String? { return nil }

    /// Returns an empty token list regardless of the input text.
    func encode(text: String, addSpecialTokens: Bool) -> [Int] {
        return []
    }

    /// Returns an empty string regardless of the input token ids.
    func decode(tokenIds: [Int], skipSpecialTokens: Bool) -> String {
        return ""
    }

    /// No vocabulary is available, so no token maps to an id.
    func convertTokenToId(_ token: String) -> Int? {
        return nil
    }

    /// No vocabulary is available, so no id maps to a token.
    func convertIdToToken(_ id: Int) -> String? {
        return nil
    }

    /// Always fails: the stub carries no chat template.
    ///
    /// - Throws: `MLXLMCommon.TokenizerError.missingChatTemplate`, unconditionally.
    func applyChatTemplate(
        messages: [[String: any Sendable]],
        tools: [[String: any Sendable]]?,
        additionalContext: [String: any Sendable]?
    ) throws -> [Int] {
        throw MLXLMCommon.TokenizerError.missingChatTemplate
    }
}
41+
1142// MARK: - Stats
1243
1344public struct BenchmarkStats : Sendable {
@@ -32,10 +63,10 @@ public struct BenchmarkStats: Sendable {
3263
3364 public func printSummary( label: String ) {
3465 print ( " \( label) results: " )
35- print ( " Mean: \( String ( format: " %.0f " , mean) ) ms " )
36- print ( " Median: \( String ( format: " %.0f " , median) ) ms " )
66+ print ( " Mean: \( String ( format: " %.1f " , mean) ) ms " )
67+ print ( " Median: \( String ( format: " %.1f " , median) ) ms " )
3768 print ( " StdDev: \( String ( format: " %.1f " , stdDev) ) ms " )
38- print ( " Range: \( String ( format: " %.0f " , min) ) - \( String ( format: " %.0f " , max) ) ms " )
69+ print ( " Range: \( String ( format: " %.1f " , min) ) - \( String ( format: " %.1f " , max) ) ms " )
3970 }
4071}
4172
@@ -63,7 +94,7 @@ public func benchmarkLLMLoading(
6394 ) { _ in }
6495 let elapsed = ( CFAbsoluteTimeGetCurrent ( ) - start) * 1000
6596 times. append ( elapsed)
66- print ( " LLM load run \( i) : \( String ( format: " %.0f " , elapsed) ) ms " )
97+ print ( " LLM load run \( i) : \( String ( format: " %.1f " , elapsed) ) ms " )
6798 Memory . clearCache ( )
6899 }
69100
@@ -92,7 +123,7 @@ public func benchmarkVLMLoading(
92123 ) { _ in }
93124 let elapsed = ( CFAbsoluteTimeGetCurrent ( ) - start) * 1000
94125 times. append ( elapsed)
95- print ( " VLM load run \( i) : \( String ( format: " %.0f " , elapsed) ) ms " )
126+ print ( " VLM load run \( i) : \( String ( format: " %.1f " , elapsed) ) ms " )
96127 Memory . clearCache ( )
97128 }
98129
@@ -119,9 +150,39 @@ public func benchmarkEmbeddingLoading(
119150 ) { _ in }
120151 let elapsed = ( CFAbsoluteTimeGetCurrent ( ) - start) * 1000
121152 times. append ( elapsed)
122- print ( " Embedding load run \( i) : \( String ( format: " %.0f " , elapsed) ) ms " )
153+ print ( " Embedding load run \( i) : \( String ( format: " %.1f " , elapsed) ) ms " )
123154 Memory . clearCache ( )
124155 }
125156
126157 return BenchmarkStats ( times: times)
127158}
159+
160+ // MARK: - Download Benchmarks
161+
/// Benchmarks how quickly `downloader` serves a model that should already be cached.
///
/// One untimed warm-up download is issued first so that the model is expected to be in
/// the cache; the identical request is then issued `runs` times and each call is timed.
///
/// - Parameters:
///   - downloader: The downloader under test.
///   - modelId: Identifier of the model to request.
///   - runs: Number of timed requests. Must be at least 1.
/// - Returns: Statistics built from the per-run elapsed times, in milliseconds.
/// - Throws: Any error thrown by `downloader.download`.
public func benchmarkDownloadCacheHit(
    from downloader: any Downloader,
    modelId: String = "mlx-community/Qwen3-0.6B-4bit",
    runs: Int = 7
) async throws -> BenchmarkStats {
    // `1...runs` traps with an opaque range error when `runs < 1`, and it would do so
    // only after the warm-up download had already run. Fail fast with a clear message.
    precondition(runs >= 1, "benchmarkDownloadCacheHit requires runs >= 1")

    let patterns = ["*.safetensors", "*.json", "*.jinja"]

    // Single definition of the request so the warm-up and the timed runs cannot diverge.
    func fetch() async throws {
        _ = try await downloader.download(
            id: modelId, revision: "main", matching: patterns,
            useLatest: false, progressHandler: { _ in })
    }

    // Warm-up: not timed.
    try await fetch()

    var times: [Double] = []
    times.reserveCapacity(runs)
    for i in 1...runs {
        let start = CFAbsoluteTimeGetCurrent()
        try await fetch()
        // Seconds -> milliseconds.
        let elapsed = (CFAbsoluteTimeGetCurrent() - start) * 1000
        times.append(elapsed)
        print("Download cache hit run \(i): \(String(format: "%.1f", elapsed))ms")
    }

    return BenchmarkStats(times: times)
}
0 commit comments