-
Notifications
You must be signed in to change notification settings - Fork 107
Expand file tree
/
Copy pathModels.swift
More file actions
186 lines (166 loc) · 6.98 KB
/
Models.swift
File metadata and controls
186 lines (166 loc) · 6.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
// Copyright © 2024 Apple Inc.
import Foundation
import MLXLMCommon
/// Describes an embedding model and where to find it, together with a shared registry of known models.
///
/// A configuration points either at a Hugging Face Hub repository (optionally pinned to a Git
/// revision) or at a model directory on disk. It can also redirect tokenizer loading to a
/// different source. The static ``registry`` is keyed by ``name``. Its predefined entries
/// (BGE, E5, Snowflake Arctic, ...) are added on first access through ``register(configurations:)``,
/// ``configuration(id:)`` or ``models``.
///
/// ### Example
/// ```swift
/// // A predefined model
/// let config = ModelConfiguration.bge_small
///
/// // A model stored in a local directory
/// let customConfig = ModelConfiguration(directory: myURL)
/// ```
public struct ModelConfiguration: Sendable {

    /// Where the model's files come from.
    public enum Identifier: Sendable {
        /// A Hugging Face Hub repository (for example `"BAAI/bge-small-en-v1.5"`)
        /// at the given Git `revision`, which defaults to `"main"`.
        case id(String, revision: String = "main")
        /// A model directory on the local file system.
        case directory(URL)
    }

    /// The location this model is loaded from.
    public var id: Identifier

    /// A human-readable label, also used as the key in ``registry``.
    ///
    /// Hub models use their repository ID. Local directories use the last two
    /// path components of the URL, i.e. `"<parent>/<directory>"`.
    public var name: String {
        switch id {
        case let .id(repository, _):
            return repository
        case let .directory(url):
            let parentName = url.deletingLastPathComponent().lastPathComponent
            return [parentName, url.lastPathComponent].joined(separator: "/")
        }
    }

    /// An alternate location for the tokenizer, if it should not come from the model itself.
    ///
    /// - `.id`: fetched from a remote provider (a ``Downloader`` is needed)
    /// - `.directory`: read from a local path
    /// - `nil`: read from the model's own directory
    public let tokenizerSource: TokenizerSource?

    /// Creates a configuration for a Hugging Face Hub repository.
    /// - Parameters:
    ///   - id: The repository ID on the Hub.
    ///   - revision: The Git revision to check out; `"main"` unless specified.
    ///   - tokenizerSource: Where to load the tokenizer from, or `nil` to use the model's files.
    public init(
        id: String,
        revision: String = "main",
        tokenizerSource: TokenizerSource? = nil
    ) {
        self.tokenizerSource = tokenizerSource
        self.id = .id(id, revision: revision)
    }

    /// Creates a configuration for a model that already lives on disk.
    /// - Parameters:
    ///   - directory: File URL of the model directory.
    ///   - tokenizerSource: Where to load the tokenizer from, or `nil` to use the model's files.
    public init(
        directory: URL,
        tokenizerSource: TokenizerSource? = nil
    ) {
        self.tokenizerSource = tokenizerSource
        self.id = .directory(directory)
    }

    // MARK: - Registry Management

    /// Known configurations, keyed by ``name``. Access is isolated to the main actor.
    @MainActor
    public static var registry: [String: ModelConfiguration] = [:]

    /// Adds configurations to ``registry`` under their ``name``.
    ///
    /// An entry replaces any existing entry with the same name, and later items in
    /// `configurations` win over earlier ones. The predefined models are registered
    /// first if that has not happened yet.
    /// - Parameter configurations: The configurations to add.
    @MainActor
    public static func register(configurations: [ModelConfiguration]) {
        bootstrap()
        let entries = configurations.map { ($0.name, $0) }
        registry.merge(entries) { _, incoming in incoming }
    }

    /// Looks up a configuration by name.
    ///
    /// A value that is not registered is treated as a Hub repository ID, and a fresh
    /// configuration on the default revision is returned for it.
    ///
    /// - Parameter id: A registered name or a Hub repository ID.
    /// - Returns: The registered configuration, or a new Hub-backed one.
    @MainActor
    public static func configuration(id: String) -> ModelConfiguration {
        bootstrap()
        return registry[id] ?? ModelConfiguration(id: id)
    }

    /// Every configuration currently in ``registry``, with the predefined models included.
    @MainActor
    public static var models: some Collection<ModelConfiguration> & Sendable {
        bootstrap()
        let registered = registry.values
        return registered
    }
}
// MARK: - Predefined Models

extension ModelConfiguration {

    /// BGE Micro v2, published by TaylorAI.
    public static let bge_micro = ModelConfiguration(id: "TaylorAI/bge-micro-v2")

    /// GTE Tiny, published by TaylorAI.
    public static let gte_tiny = ModelConfiguration(id: "TaylorAI/gte-tiny")

    /// all-MiniLM-L6-v2 from sentence-transformers.
    public static let minilm_l6 = ModelConfiguration(id: "sentence-transformers/all-MiniLM-L6-v2")

    /// Snowflake Arctic Embed, extra-small variant.
    public static let snowflake_xs = ModelConfiguration(id: "Snowflake/snowflake-arctic-embed-xs")

    /// all-MiniLM-L12-v2 from sentence-transformers (12-layer MiniLM).
    public static let minilm_l12 = ModelConfiguration(id: "sentence-transformers/all-MiniLM-L12-v2")

    /// BGE Small (English, v1.5) from BAAI.
    public static let bge_small = ModelConfiguration(id: "BAAI/bge-small-en-v1.5")

    /// Multilingual E5, small variant, from intfloat.
    public static let multilingual_e5_small = ModelConfiguration(
        id: "intfloat/multilingual-e5-small")

    /// BGE Base (English, v1.5) from BAAI.
    public static let bge_base = ModelConfiguration(id: "BAAI/bge-base-en-v1.5")

    /// Nomic Embed Text, v1.
    public static let nomic_text_v1 = ModelConfiguration(id: "nomic-ai/nomic-embed-text-v1")

    /// Nomic Embed Text, v1.5 (Matryoshka-style embeddings).
    public static let nomic_text_v1_5 = ModelConfiguration(id: "nomic-ai/nomic-embed-text-v1.5")

    /// BGE Large (English, v1.5) from BAAI.
    public static let bge_large = ModelConfiguration(id: "BAAI/bge-large-en-v1.5")

    /// Snowflake Arctic Embed, large variant.
    public static let snowflake_lg = ModelConfiguration(id: "Snowflake/snowflake-arctic-embed-l")

    /// BGE-M3 from BAAI (multi-lingual, multi-functional, multi-granularity).
    public static let bge_m3 = ModelConfiguration(id: "BAAI/bge-m3")

    /// mxbai-embed-large v1 from Mixedbread AI.
    public static let mixedbread_large = ModelConfiguration(
        id: "mixedbread-ai/mxbai-embed-large-v1")

    /// Qwen3 Embedding 0.6B, 4-bit DWQ quantization, from mlx-community.
    public static let qwen3_embedding = ModelConfiguration(
        id: "mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ")

    /// Progress of the one-time registry seeding.
    ///
    /// The intermediate `bootstrapping` state exists because seeding calls
    /// `register(configurations:)`, which calls `bootstrap()` again. That nested
    /// call must see a non-idle state and return immediately.
    private enum BootstrapState: Sendable {
        case idle, bootstrapping, bootstrapped
    }

    /// Whether the predefined models have been added to the registry yet.
    @MainActor
    private static var bootstrapState = BootstrapState.idle

    /// Seeds the registry with the predefined models exactly once; later calls do nothing.
    @MainActor
    static func bootstrap() {
        guard bootstrapState == .idle else { return }

        // Mark in-progress before registering so the re-entrant call is a no-op.
        bootstrapState = .bootstrapping
        register(configurations: [
            bge_micro,
            gte_tiny,
            minilm_l6,
            snowflake_xs,
            minilm_l12,
            bge_small,
            multilingual_e5_small,
            bge_base,
            nomic_text_v1,
            nomic_text_v1_5,
            bge_large,
            snowflake_lg,
            bge_m3,
            mixedbread_large,
            qwen3_embedding,
        ])
        bootstrapState = .bootstrapped
    }
}