Skip to content

Commit ada52c9

Browse files
committed
Centralize file download patterns
1 parent ea87cd0 commit ada52c9

File tree

5 files changed

+14
-8
lines changed

5 files changed

+14
-8
lines changed

Libraries/BenchmarkHelpers/BenchmarkHelpers.swift

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,6 @@ public func loadDecodingBenchmarkText(
196196
)
197197
}
198198

199-
private let tokenizerDownloadPatterns = ["*.json", "*.jinja"]
200-
201199
private func resolveTokenizerDirectory(
202200
from downloader: any Downloader,
203201
configuration: MLXLMCommon.ModelConfiguration,
@@ -444,7 +442,7 @@ public func benchmarkDownloadCacheHit(
444442
modelId: String = "mlx-community/Qwen3-0.6B-4bit",
445443
runs: Int = BenchmarkDefaults.downloadRuns
446444
) async throws -> BenchmarkStats {
447-
let patterns = ["*.safetensors", "*.json", "*.jinja"]
445+
let patterns = modelDownloadPatterns
448446

449447
// Warm-up: ensure the model is cached
450448
_ = try await downloader.download(

Libraries/MLXEmbedders/Load.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ func resolveDirectories(
8383
case .id(let id, let revision):
8484
modelDirectory = try await downloader.download(
8585
id: id, revision: revision,
86-
matching: ["*.safetensors", "*.json"],
86+
matching: modelDownloadPatterns,
8787
useLatest: useLatest,
8888
progressHandler: progressHandler)
8989
case .directory(let directory):
@@ -95,7 +95,7 @@ func resolveDirectories(
9595
case .id(let id, let revision):
9696
tokenizerDirectory = try await downloader.download(
9797
id: id, revision: revision,
98-
matching: ["*.json"],
98+
matching: tokenizerDownloadPatterns,
9999
useLatest: useLatest,
100100
progressHandler: { _ in })
101101
case .directory(let directory):

Libraries/MLXLMCommon/Downloader.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ public protocol Downloader: Sendable {
2121
/// - revision: Optional revision (branch, tag, commit hash, version number).
2222
/// Providers without versioning receive `nil`.
2323
/// - patterns: Glob patterns to filter which files to download
24-
/// (e.g. `["*.safetensors", "*.json"]`)
24+
/// (e.g. `["*.safetensors", "*.json", "*.jinja"]`)
2525
/// - useLatest: When `true`, check the provider for updates even if a cached
2626
/// version exists. When `false`, return the cached version if available.
2727
/// - progressHandler: Callback for download progress

Libraries/MLXLMCommon/ModelFactory.swift

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
import Foundation
44

5+
/// File patterns required to resolve a tokenizer without downloading model weights.
6+
package let tokenizerDownloadPatterns = ["*.json", "*.jinja"]
7+
package let modelDownloadPatterns = ["*.safetensors"] + tokenizerDownloadPatterns
8+
59
public enum ModelFactoryError: LocalizedError {
610
case unsupportedModelType(String)
711
case unsupportedProcessorType(String)
@@ -197,7 +201,7 @@ public func resolve(
197201
case .id(let id, let revision):
198202
modelDirectory = try await downloader.download(
199203
id: id, revision: revision,
200-
matching: ["*.safetensors", "*.json", "*.jinja"],
204+
matching: modelDownloadPatterns,
201205
useLatest: useLatest,
202206
progressHandler: progressHandler)
203207
case .directory(let directory):
@@ -209,7 +213,7 @@ public func resolve(
209213
case .id(let id, let revision):
210214
tokenizerDirectory = try await downloader.download(
211215
id: id, revision: revision,
212-
matching: ["*.json", "*.jinja"],
216+
matching: tokenizerDownloadPatterns,
213217
useLatest: useLatest,
214218
progressHandler: { _ in })
215219
case .directory(let directory):

Tests/MLXLMTests/ResolveTests.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ private final class LockIsolated<Value: Sendable>: @unchecked Sendable {
6767
#expect(downloader.calls.value.count == 1)
6868
#expect(downloader.calls.value[0].id == "org/model")
6969
#expect(downloader.calls.value[0].revision == "abc123")
70+
#expect(downloader.calls.value[0].patterns.contains("*.jinja"))
7071

7172
// No separate tokenizer download, so both point to the model directory.
7273
#expect(resolved.modelDirectory == resolved.tokenizerDirectory)
@@ -91,6 +92,7 @@ private final class LockIsolated<Value: Sendable>: @unchecked Sendable {
9192
// Tokenizer download uses nil revision (provider default).
9293
#expect(downloader.calls.value[1].id == "org/tokenizer")
9394
#expect(downloader.calls.value[1].revision == nil)
95+
#expect(downloader.calls.value[1].patterns.contains("*.jinja"))
9496

9597
// Model and tokenizer come from different repos, so directories differ.
9698
#expect(resolved.modelDirectory != resolved.tokenizerDirectory)
@@ -113,6 +115,7 @@ private final class LockIsolated<Value: Sendable>: @unchecked Sendable {
113115

114116
#expect(downloader.calls.value[1].id == "org/tokenizer")
115117
#expect(downloader.calls.value[1].revision == "tok-v2")
118+
#expect(downloader.calls.value[1].patterns.contains("*.jinja"))
116119

117120
// Model and tokenizer come from different repos, so directories differ.
118121
#expect(resolved.modelDirectory != resolved.tokenizerDirectory)
@@ -150,6 +153,7 @@ private final class LockIsolated<Value: Sendable>: @unchecked Sendable {
150153
#expect(downloader.calls.value.count == 1)
151154
#expect(downloader.calls.value[0].id == "org/tokenizer")
152155
#expect(downloader.calls.value[0].revision == "v3")
156+
#expect(downloader.calls.value[0].patterns.contains("*.jinja"))
153157

154158
#expect(resolved.modelDirectory == localDir)
155159
#expect(resolved.tokenizerDirectory != localDir)

0 commit comments

Comments
 (0)