Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions Libraries/MLXEmbedders/EmbeddingModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,11 @@ public actor ModelContainer {
modelDirectory: URL,
configuration: ModelConfiguration
) async throws {
// Load tokenizer config and model in parallel using async let.
async let tokenizerConfigTask = loadTokenizerConfig(
configuration: configuration, hub: hub)

async let tokenizerTask = loadTokenizer(configuration: configuration, hub: hub)
self.model = try loadSynchronous(
modelDirectory: modelDirectory, modelName: configuration.name)
self.pooler = loadPooling(modelDirectory: modelDirectory)

let (tokenizerConfig, tokenizerData) = try await tokenizerConfigTask
self.tokenizer = try PreTrainedTokenizer(
tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
self.tokenizer = try await tokenizerTask
}

/// Perform an action on the model and/or tokenizer. Callers _must_ eval any `MLXArray` before returning as
Expand Down
5 changes: 3 additions & 2 deletions Libraries/MLXEmbedders/Load.swift
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,12 @@ func prepareModelDirectory(
) async throws -> URL {
do {
switch configuration.id {
case .id(let id):
case .id(let id, let revision):
let repo = Hub.Repo(id: id)
let modelFiles = ["*.safetensors", "config.json", "*/config.json"]
return try await hub.snapshot(
from: repo, matching: modelFiles, progressHandler: progressHandler)
from: repo, revision: revision, matching: modelFiles,
progressHandler: progressHandler)

case .directory(let directory):
return directory
Expand Down
10 changes: 5 additions & 5 deletions Libraries/MLXEmbedders/Models.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public struct ModelConfiguration: Sendable {
/// The backing storage for the model's location.
public enum Identifier: Sendable {
/// A Hugging Face Hub repository identifier (e.g., "BAAI/bge-small-en-v1.5").
case id(String)
case id(String, revision: String = "main")
/// A file system URL pointing to a local model directory.
case directory(URL)
}
Expand All @@ -36,7 +36,7 @@ public struct ModelConfiguration: Sendable {
/// it returns a path-based name (e.g., "ParentDir/ModelDir").
public var name: String {
switch id {
case .id(let string):
case .id(let string, _):
string
case .directory(let url):
url.deletingLastPathComponent().lastPathComponent + "/" + url.lastPathComponent
Expand All @@ -60,11 +60,11 @@ public struct ModelConfiguration: Sendable {
/// - tokenizerId: Optional alternate repo for the tokenizer.
/// - overrideTokenizer: Optional specific tokenizer implementation name.
public init(
id: String,
id: String, revision: String = "main",
tokenizerId: String? = nil,
overrideTokenizer: String? = nil
) {
self.id = .id(id)
self.id = .id(id, revision: revision)
self.tokenizerId = tokenizerId
self.overrideTokenizer = overrideTokenizer
}
Expand All @@ -90,7 +90,7 @@ public struct ModelConfiguration: Sendable {
/// - Returns: A `URL` pointing to the local directory.
public func modelDirectory(hub: HubApi = HubApi()) -> URL {
switch id {
case .id(let id):
case .id(let id, _):
let repo = Hub.Repo(id: id)
return hub.localRepoLocation(repo)

Expand Down
66 changes: 7 additions & 59 deletions Libraries/MLXEmbedders/Tokenizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,66 +18,14 @@ import Tokenizers
/// or standard network/parsing errors.
public func loadTokenizer(configuration: ModelConfiguration, hub: HubApi) async throws -> Tokenizer
{
let (tokenizerConfig, tokenizerData) = try await loadTokenizerConfig(
configuration: configuration, hub: hub)

return try PreTrainedTokenizer(
tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
}

/// Retrieves the raw configuration and data files required to build a tokenizer.
///
/// This internal helper handles the logic of determining where to fetch files from.
/// It includes a robust fallback: if a network request fails due to lack of internet
/// connectivity, it attempts to load the files from the local model directory.
///
/// - Parameters:
/// - configuration: The model configuration providing the `tokenizerId` or `modelDirectory`.
/// - hub: The `HubApi` interface for remote or local file resolution.
/// - Returns: A tuple containing the `tokenizerConfig` and `tokenizerData` configurations.
/// - Throws: `NSURLError` for network issues (other than offline status).
/// - Throws: `EmbedderError.missingTokenizerConfig` if the configuration files are
/// successfully accessed but do not contain a valid `tokenizerConfig` payload.
/// This typically occurs when the model repository or directory is missing a
/// `tokenizer_config.json` file.
func loadTokenizerConfig(
configuration: ModelConfiguration,
hub: HubApi
) async throws -> (Config, Config) {
// from AutoTokenizer.from() -- this lets us override parts of the configuration
let config: LanguageModelConfigurationFromHub

switch configuration.id {
case .id(let id):
do {
// Attempt to load from the remote Hub or Hub cache
let loaded = LanguageModelConfigurationFromHub(
modelName: configuration.tokenizerId ?? id, hubApi: hub)

// Trigger an async fetch to verify the config exists
_ = try await loaded.tokenizerConfig
config = loaded
} catch {
let nserror = error as NSError
if nserror.domain == NSURLErrorDomain
&& nserror.code == NSURLErrorNotConnectedToInternet
{
// Fallback: Internet connection is offline, load from the local model directory
config = LanguageModelConfigurationFromHub(
modelFolder: configuration.modelDirectory(hub: hub), hubApi: hub)
} else {
// Re-throw if it's a critical error (e.g., 404, parsing error)
throw error
}
}
case .id(let id, let revision):
return try await AutoTokenizer.from(
pretrained: configuration.tokenizerId ?? id,
hubApi: hub,
revision: revision
)
case .directory(let directory):
// Load directly from a specified local directory
config = LanguageModelConfigurationFromHub(modelFolder: directory, hubApi: hub)
}

guard let tokenizerConfig = try await config.tokenizerConfig else {
throw EmbedderError.missingTokenizerConfig
return try await AutoTokenizer.from(modelFolder: directory, hubApi: hub)
}
let tokenizerData = try await config.tokenizerData
return (tokenizerConfig, tokenizerData)
}
10 changes: 1 addition & 9 deletions Libraries/MLXLLM/LLMModelFactory.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {

static public let codeLlama13b4bit = ModelConfiguration(
id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
overrideTokenizer: "PreTrainedTokenizer",
defaultPrompt: "func sortArray(_ array: [Int]) -> String { <FILL_ME> }"
)

Expand All @@ -129,28 +128,22 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {
id: "mlx-community/Phi-3.5-MoE-instruct-4bit",
defaultPrompt: "What is the gravity on Mars and the moon?",
extraEOSTokens: ["<|end|>"]
) {
prompt in
"<|user|>\n\(prompt)<|end|>\n<|assistant|>\n"
}
)

static public let gemma2bQuantized = ModelConfiguration(
id: "mlx-community/quantized-gemma-2b-it",
overrideTokenizer: "PreTrainedTokenizer",
// https://www.promptingguide.ai/models/gemma
defaultPrompt: "what is the difference between lettuce and cabbage?"
)

static public let gemma_2_9b_it_4bit = ModelConfiguration(
id: "mlx-community/gemma-2-9b-it-4bit",
overrideTokenizer: "PreTrainedTokenizer",
// https://www.promptingguide.ai/models/gemma
defaultPrompt: "What is the difference between lettuce and cabbage?"
)

static public let gemma_2_2b_it_4bit = ModelConfiguration(
id: "mlx-community/gemma-2-2b-it-4bit",
overrideTokenizer: "PreTrainedTokenizer",
// https://www.promptingguide.ai/models/gemma
defaultPrompt: "What is the difference between lettuce and cabbage?"
)
Expand Down Expand Up @@ -191,7 +184,6 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {

static public let qwen205b4bit = ModelConfiguration(
id: "mlx-community/Qwen1.5-0.5B-Chat-4bit",
overrideTokenizer: "PreTrainedTokenizer",
defaultPrompt: "why is the sky blue?"
)

Expand Down
16 changes: 5 additions & 11 deletions Libraries/MLXLMCommon/ModelConfiguration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,9 @@ public struct ModelConfiguration: Sendable {
}
}

/// pull the tokenizer from an alternate id
/// Alternate repo ID to use for the tokenizer
public let tokenizerId: String?

/// overrides for TokenizerModel/knownTokenizers -- useful before swift-transformers is updated
public let overrideTokenizer: String?

/// A reasonable default prompt for the model
public var defaultPrompt: String

Expand All @@ -44,31 +41,28 @@ public struct ModelConfiguration: Sendable {

public init(
id: String, revision: String = "main",
tokenizerId: String? = nil, overrideTokenizer: String? = nil,
tokenizerId: String? = nil,
defaultPrompt: String = "hello",
extraEOSTokens: Set<String> = [],
toolCallFormat: ToolCallFormat? = nil,
preparePrompt: (@Sendable (String) -> String)? = nil
toolCallFormat: ToolCallFormat? = nil
) {
self.id = .id(id, revision: revision)
self.tokenizerId = tokenizerId
self.overrideTokenizer = overrideTokenizer
self.defaultPrompt = defaultPrompt
self.extraEOSTokens = extraEOSTokens
self.toolCallFormat = toolCallFormat
}

public init(
directory: URL,
tokenizerId: String? = nil, overrideTokenizer: String? = nil,
tokenizerId: String? = nil,
defaultPrompt: String = "hello",
extraEOSTokens: Set<String> = [],
eosTokenIds: Set<Int> = [],
toolCallFormat: ToolCallFormat? = nil
) {
self.id = .directory(directory)
self.tokenizerId = tokenizerId
self.overrideTokenizer = overrideTokenizer
self.defaultPrompt = defaultPrompt
self.extraEOSTokens = extraEOSTokens
self.eosTokenIds = eosTokenIds
Expand All @@ -78,7 +72,7 @@ public struct ModelConfiguration: Sendable {
public func modelDirectory(hub: HubApi = HubApi()) -> URL {
switch id {
case .id(let id, _):
// download the model weights and config
// Download the model weights and config
let repo = Hub.Repo(id: id)
return hub.localRepoLocation(repo)

Expand Down
10 changes: 5 additions & 5 deletions Libraries/MLXLMCommon/ModelFactory.swift
Original file line number Diff line number Diff line change
Expand Up @@ -207,15 +207,15 @@ public func loadModelContainer(
}
}

/// Load a model given a huggingface identifier.
/// Load a model given a Hugging Face identifier.
///
/// This will load and return a ``ModelContext``. This holds the model and tokenzier without
/// This will load and return a ``ModelContext``. This holds the model and tokenizer without
/// an `actor` providing an isolation context. Use this call when you control the isolation context
/// and can hold the ``ModelContext`` directly.
///
/// - Parameters:
/// - hub: optional HubApi -- by default uses ``defaultHubApi``
/// - id: huggingface model identifier, e.g "mlx-community/Qwen3-4B-4bit"
/// - id: Hugging Face model identifier, e.g. "mlx-community/Qwen3-4B-4bit"
/// - progressHandler: optional callback for progress
/// - Returns: a ``ModelContext``
public func loadModel(
Expand All @@ -229,14 +229,14 @@ public func loadModel(
}
}

/// Load a model given a huggingface identifier.
/// Load a model given a Hugging Face identifier.
///
/// This will load and return a ``ModelContainer``. This holds a ``ModelContext``
/// inside an actor providing isolation control for the values.
///
/// - Parameters:
/// - hub: optional HubApi -- by default uses ``defaultHubApi``
/// - id: huggingface model identifier, e.g "mlx-community/Qwen3-4B-4bit"
/// - id: Hugging Face model identifier, e.g. "mlx-community/Qwen3-4B-4bit"
/// - progressHandler: optional callback for progress
/// - Returns: a ``ModelContainer``
public func loadModelContainer(
Expand Down
74 changes: 23 additions & 51 deletions Libraries/MLXLMCommon/Tokenizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,21 @@ struct TokenizerError: Error {

public func loadTokenizer(configuration: ModelConfiguration, hub: HubApi) async throws -> Tokenizer
{
let (tokenizerConfig, tokenizerData) = try await loadTokenizerConfig(
configuration: configuration, hub: hub)

return try PreTrainedTokenizer(
tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
switch configuration.id {
case .id(let id, let revision):
return try await AutoTokenizer.from(
pretrained: configuration.tokenizerId ?? id,
hubApi: hub,
revision: revision
)
case .directory(let directory):
return try await AutoTokenizer.from(modelFolder: directory, hubApi: hub)
}
}

@available(
*, deprecated, message: "Use LanguageModelConfigurationFromHub from swift-transformers directly"
)
public func loadTokenizerConfig(configuration: ModelConfiguration, hub: HubApi) async throws -> (
Config, Config
) {
Expand Down Expand Up @@ -48,60 +56,24 @@ public func loadTokenizerConfig(configuration: ModelConfiguration, hub: HubApi)
config = LanguageModelConfigurationFromHub(modelFolder: directory, hubApi: hub)
}

guard var tokenizerConfig = try await config.tokenizerConfig else {
guard let tokenizerConfig = try await config.tokenizerConfig else {
throw TokenizerError(message: "missing config")
}
let tokenizerData = try await config.tokenizerData

tokenizerConfig = updateTokenizerConfig(tokenizerConfig)

return (tokenizerConfig, tokenizerData)
}

private func updateTokenizerConfig(_ tokenizerConfig: Config) -> Config {
// Workaround: replacement tokenizers for unhandled values in swift-transformers
if let tokenizerClass = tokenizerConfig.tokenizerClass?.string(),
let replacement = replacementTokenizers[tokenizerClass]
{
if var dictionary = tokenizerConfig.dictionary() {
dictionary["tokenizer_class"] = .init(replacement)
return Config(dictionary)
}
}
return tokenizerConfig
}

public class TokenizerReplacementRegistry: @unchecked Sendable {

// Note: using NSLock as we have very small (just dictionary get/set)
// critical sections and expect no contention. this allows the methods
// to remain synchronous.
private let lock = NSLock()

/// overrides for TokenizerModel/knownTokenizers
private var replacementTokenizers = [
"InternLM2Tokenizer": "PreTrainedTokenizer",
"Qwen2Tokenizer": "PreTrainedTokenizer",
"Qwen3Tokenizer": "PreTrainedTokenizer",
"CohereTokenizer": "PreTrainedTokenizer",
"GPTNeoXTokenizer": "PreTrainedTokenizer",
"TokenizersBackend": "PreTrainedTokenizer",
]

public subscript(key: String) -> String? {
get {
lock.withLock {
replacementTokenizers[key]
}
}
set {
lock.withLock {
replacementTokenizers[key] = newValue
}
}
}
}
@available(
*, unavailable,
message: "Use AutoTokenizer.register(_:for:) from swift-transformers instead"
)
public class TokenizerReplacementRegistry: @unchecked Sendable {}

@available(
*, unavailable,
message: "Use AutoTokenizer.register(_:for:) from swift-transformers instead"
)
public let replacementTokenizers = TokenizerReplacementRegistry()

public protocol StreamingDetokenizer: IteratorProtocol<String> {
Expand Down
2 changes: 1 addition & 1 deletion Libraries/MLXVLM/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ media as needed. For example it might:
- modify the prompt by injecting `<image>` tokens that the model expects

In the Python implementations, much of this code typically lives in the `transformers`
package from huggingface -- inspection will be required to determine which code
package from Hugging Face -- inspection will be required to determine which code
is called and what it does. You can examine the processors in the `Models` directory:
they reference the files and functions that they are based on.

Expand Down
Loading