Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 24 additions & 42 deletions Libraries/Embedders/Configuration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -33,65 +33,47 @@ private class ModelTypeRegistry: @unchecked Sendable {
// to remain synchronous.
private let lock = NSLock()

private var creators: [String: @Sendable (URL) throws -> EmbeddingModel] = [
"bert": {
url in
let configuration = try JSONDecoder().decode(
BertConfiguration.self, from: Data(contentsOf: url))
let model = BertModel(configuration)
return model
private var creators: [String: @Sendable (Data) throws -> EmbeddingModel] = [
"bert": { data in
let configuration = try JSONDecoder().decode(BertConfiguration.self, from: data)
return BertModel(configuration)
},
"roberta": {
url in
let configuration = try JSONDecoder().decode(
BertConfiguration.self, from: Data(contentsOf: url))
let model = BertModel(configuration)
return model
"roberta": { data in
let configuration = try JSONDecoder().decode(BertConfiguration.self, from: data)
return BertModel(configuration)
},
"xlm-roberta": {
url in
let configuration = try JSONDecoder().decode(
BertConfiguration.self, from: Data(contentsOf: url))
let model = BertModel(configuration)
return model
"xlm-roberta": { data in
let configuration = try JSONDecoder().decode(BertConfiguration.self, from: data)
return BertModel(configuration)
},
"distilbert": {
url in
let configuration = try JSONDecoder().decode(
BertConfiguration.self, from: Data(contentsOf: url))
let model = BertModel(configuration)
return model
"distilbert": { data in
let configuration = try JSONDecoder().decode(BertConfiguration.self, from: data)
return BertModel(configuration)
},
"nomic_bert": {
url in
let configuration = try JSONDecoder().decode(
NomicBertConfiguration.self, from: Data(contentsOf: url))
let model = NomicBertModel(configuration, pooler: false)
return model
"nomic_bert": { data in
let configuration = try JSONDecoder().decode(NomicBertConfiguration.self, from: data)
return NomicBertModel(configuration, pooler: false)
},
"qwen3": {
url in
let configuration = try JSONDecoder().decode(
Qwen3Configuration.self, from: Data(contentsOf: url))
let model = Qwen3Model(configuration)
return model
"qwen3": { data in
let configuration = try JSONDecoder().decode(Qwen3Configuration.self, from: data)
return Qwen3Model(configuration)
},
]

public func registerModelType(
_ type: String, creator: @Sendable @escaping (URL) throws -> EmbeddingModel
_ type: String, creator: @Sendable @escaping (Data) throws -> EmbeddingModel
) {
lock.withLock {
creators[type] = creator
}
}

public func createModel(configuration: URL, rawValue: String) throws -> EmbeddingModel {
public func createModel(configuration: Data, rawValue: String) throws -> EmbeddingModel {
let creator = lock.withLock {
creators[rawValue]
}
guard let creator else {
throw EmbedderError(message: "Unsupported model type.")
throw EmbedderError.unsupportedModelType(rawValue)
}
return try creator(configuration)
}
Expand All @@ -108,12 +90,12 @@ public struct ModelType: RawRepresentable, Codable, Sendable {
}

public static func registerModelType(
_ type: String, creator: @Sendable @escaping (URL) throws -> EmbeddingModel
_ type: String, creator: @Sendable @escaping (Data) throws -> EmbeddingModel
) {
modelTypeRegistry.registerModelType(type, creator: creator)
}

public func createModel(configuration: URL) throws -> EmbeddingModel {
public func createModel(configuration: Data) throws -> EmbeddingModel {
try modelTypeRegistry.createModel(configuration: configuration, rawValue: rawValue)
}
}
12 changes: 8 additions & 4 deletions Libraries/Embedders/EmbeddingModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,17 @@ public actor ModelContainer {
public init(
hub: HubApi, modelDirectory: URL, configuration: ModelConfiguration
) async throws {
self.model = try loadSynchronous(modelDirectory: modelDirectory)

let (tokenizerConfig, tokenizerData) = try await loadTokenizerConfig(
// Load tokenizer config and model in parallel using async let.
async let tokenizerConfigTask = loadTokenizerConfig(
configuration: configuration, hub: hub)

self.model = try loadSynchronous(
modelDirectory: modelDirectory, modelName: configuration.name)
self.pooler = loadPooling(modelDirectory: modelDirectory)

let (tokenizerConfig, tokenizerData) = try await tokenizerConfigTask
self.tokenizer = try PreTrainedTokenizer(
tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
self.pooler = loadPooling(modelDirectory: modelDirectory) //?? Pooling(strategy: .none)
}

/// Perform an action on the model and/or tokenizer. Callers _must_ eval any `MLXArray` before returning as
Expand Down
79 changes: 70 additions & 9 deletions Libraries/Embedders/Load.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,48 @@ import MLX
import MLXNN
import Tokenizers

struct EmbedderError: Error {
let message: String
/// Errors thrown while loading an embedding model or its tokenizer.
///
/// Associated values, in declaration order:
/// - `unsupportedModelType`: the model type string that has no registered creator.
/// - `configurationFileError`: file name, model name, and the underlying read error.
/// - `configurationDecodingError`: file name, model name, and the `DecodingError` raised
///   while parsing.
/// - `missingTokenizerConfig`: carries no payload.
public enum EmbedderError: LocalizedError {
    case unsupportedModelType(String)
    case configurationFileError(String, String, Error)
    case configurationDecodingError(String, String, DecodingError)
    case missingTokenizerConfig

    /// Human-readable message describing the failure.
    public var errorDescription: String? {
        switch self {
        case .missingTokenizerConfig:
            return "Missing tokenizer configuration"
        case .unsupportedModelType(let modelType):
            return "Unsupported model type: \(modelType)"
        case .configurationFileError(let fileName, let model, let underlying):
            let reason = underlying.localizedDescription
            return "Error reading '\(fileName)' for model '\(model)': \(reason)"
        case .configurationDecodingError(let fileName, let model, let underlying):
            let reason = extractDecodingErrorDetail(underlying)
            return "Failed to parse \(fileName) for model '\(model)': \(reason)"
        }
    }

    /// Condenses a `DecodingError` into a short phrase that names the offending coding path.
    private func extractDecodingErrorDetail(_ error: DecodingError) -> String {
        // Joins the coding keys into a dotted path such as `outer.inner.field`.
        func dotted(_ keys: [CodingKey]) -> String {
            keys.map { key in key.stringValue }.joined(separator: ".")
        }

        switch error {
        case .keyNotFound(let missingKey, let context):
            // The missing key is not part of the context path, so append it.
            return "Missing field '\(dotted(context.codingPath + [missingKey]))'"
        case .typeMismatch(_, let context):
            return "Type mismatch at '\(dotted(context.codingPath))'"
        case .valueNotFound(_, let context):
            return "Missing value at '\(dotted(context.codingPath))'"
        case .dataCorrupted(let context):
            // An empty path means the failure was reported at the top level of the document.
            guard !context.codingPath.isEmpty else {
                return "Invalid JSON"
            }
            return "Invalid data at '\(dotted(context.codingPath))'"
        @unknown default:
            return error.localizedDescription
        }
    }
}

func prepareModelDirectory(
Expand Down Expand Up @@ -49,20 +89,41 @@ public func load(
) async throws -> (EmbeddingModel, Tokenizer) {
let modelDirectory = try await prepareModelDirectory(
hub: hub, configuration: configuration, progressHandler: progressHandler)
let model = try loadSynchronous(modelDirectory: modelDirectory)
let tokenizer = try await loadTokenizer(configuration: configuration, hub: hub)

// Load tokenizer and model in parallel using async let.
async let tokenizerTask = loadTokenizer(configuration: configuration, hub: hub)
let model = try loadSynchronous(modelDirectory: modelDirectory, modelName: configuration.name)
let tokenizer = try await tokenizerTask

return (model, tokenizer)
}

func loadSynchronous(modelDirectory: URL) throws -> EmbeddingModel {
// create the model (no weights loaded)
func loadSynchronous(modelDirectory: URL, modelName: String) throws -> EmbeddingModel {
// Load config.json once and decode for both base config and model-specific config
let configurationURL = modelDirectory.appending(component: "config.json")
let baseConfig = try JSONDecoder().decode(
BaseConfiguration.self, from: Data(contentsOf: configurationURL))
let configData: Data
do {
configData = try Data(contentsOf: configurationURL)
} catch {
throw EmbedderError.configurationFileError(
configurationURL.lastPathComponent, modelName, error)
}
let baseConfig: BaseConfiguration
do {
baseConfig = try JSONDecoder().decode(BaseConfiguration.self, from: configData)
} catch let error as DecodingError {
throw EmbedderError.configurationDecodingError(
configurationURL.lastPathComponent, modelName, error)
}

let modelType = ModelType(rawValue: baseConfig.modelType)
let model = try modelType.createModel(configuration: configurationURL)
let model: EmbeddingModel
do {
model = try modelType.createModel(configuration: configData)
} catch let error as DecodingError {
throw EmbedderError.configurationDecodingError(
configurationURL.lastPathComponent, modelName, error)
}

// load the weights
var weights = [String: MLXArray]()
Expand Down
2 changes: 1 addition & 1 deletion Libraries/Embedders/Tokenizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func loadTokenizerConfig(configuration: ModelConfiguration, hub: HubApi) async t
}

guard let tokenizerConfig = try await config.tokenizerConfig else {
throw EmbedderError(message: "missing config")
throw EmbedderError.missingTokenizerConfig
}
let tokenizerData = try await config.tokenizerData
return (tokenizerConfig, tokenizerData)
Expand Down
30 changes: 18 additions & 12 deletions Libraries/MLXLLM/LLMModelFactory.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@ import MLX
import MLXLMCommon
import Tokenizers

/// Creates a function that loads a configuration file and instantiates a model with the proper configuration
/// Creates a function that decodes configuration data and instantiates a model with the proper configuration
private func create<C: Codable, M>(
_ configurationType: C.Type, _ modelInit: @escaping (C) -> M
) -> (URL) throws -> M {
{ url in
let configuration = try JSONDecoder().decode(
C.self, from: Data(contentsOf: url))
) -> (Data) throws -> M {
{ data in
let configuration = try JSONDecoder().decode(C.self, from: data)
return modelInit(configuration)
}
}
Expand Down Expand Up @@ -478,13 +477,18 @@ public final class LLMModelFactory: ModelFactory {
let modelDirectory = try await downloadModel(
hub: hub, configuration: configuration, progressHandler: progressHandler)

// Load the generic config to understand which model and how to load the weights
// Load config.json once and decode for both base config and model-specific config
let configurationURL = modelDirectory.appending(component: "config.json")

let configData: Data
do {
configData = try Data(contentsOf: configurationURL)
} catch {
throw ModelFactoryError.configurationFileError(
configurationURL.lastPathComponent, configuration.name, error)
}
let baseConfig: BaseConfiguration
do {
baseConfig = try JSONDecoder().decode(
BaseConfiguration.self, from: Data(contentsOf: configurationURL))
baseConfig = try JSONDecoder().decode(BaseConfiguration.self, from: configData)
} catch let error as DecodingError {
throw ModelFactoryError.configurationDecodingError(
configurationURL.lastPathComponent, configuration.name, error)
Expand All @@ -493,18 +497,20 @@ public final class LLMModelFactory: ModelFactory {
let model: LanguageModel
do {
model = try await typeRegistry.createModel(
configuration: configurationURL, modelType: baseConfig.modelType)
configuration: configData, modelType: baseConfig.modelType)
} catch let error as DecodingError {
throw ModelFactoryError.configurationDecodingError(
configurationURL.lastPathComponent, configuration.name, error)
}

// apply the weights to the bare model
// Load tokenizer and weights in parallel using async let.
async let tokenizerTask = loadTokenizer(configuration: configuration, hub: hub)

try loadWeights(
modelDirectory: modelDirectory, model: model,
perLayerQuantization: baseConfig.perLayerQuantization)

let tokenizer = try await loadTokenizer(configuration: configuration, hub: hub)
let tokenizer = try await tokenizerTask

let messageGenerator =
if let model = model as? LLMModel {
Expand Down
3 changes: 3 additions & 0 deletions Libraries/MLXLMCommon/ModelFactory.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import Tokenizers
public enum ModelFactoryError: LocalizedError {
case unsupportedModelType(String)
case unsupportedProcessorType(String)
case configurationFileError(String, String, Error)
case configurationDecodingError(String, String, DecodingError)
case noModelFactoryAvailable

Expand All @@ -16,6 +17,8 @@ public enum ModelFactoryError: LocalizedError {
return "Unsupported model type: \(type)"
case .unsupportedProcessorType(let type):
return "Unsupported processor type: \(type)"
case .configurationFileError(let file, let modelName, let error):
return "Error reading '\(file)' for model '\(modelName)': \(error.localizedDescription)"
case .noModelFactoryAvailable:
return "No model factory available via ModelFactoryRegistry"
case .configurationDecodingError(let file, let modelName, let decodingError):
Expand Down
11 changes: 6 additions & 5 deletions Libraries/MLXLMCommon/Registries/ModelTypeRegistry.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,22 @@ public actor ModelTypeRegistry {
}

/// Creates a registry with given creators.
public init(creators: [String: (URL) throws -> any LanguageModel]) {
public init(creators: [String: (Data) throws -> any LanguageModel]) {
self.creators = creators
}

private var creators: [String: (URL) throws -> any LanguageModel]
private var creators: [String: (Data) throws -> any LanguageModel]

/// Add a new model to the type registry.
public func registerModelType(
_ type: String, creator: @escaping (URL) throws -> any LanguageModel
_ type: String, creator: @escaping (Data) throws -> any LanguageModel
) {
creators[type] = creator
}

/// Given a `modelType` and configuration file instantiate a new `LanguageModel`.
public func createModel(configuration: URL, modelType: String) throws -> sending LanguageModel {
/// Given a `modelType` and configuration data instantiate a new `LanguageModel`.
public func createModel(configuration: Data, modelType: String) throws -> sending LanguageModel
{
guard let creator = creators[modelType] else {
throw ModelFactoryError.unsupportedModelType(modelType)
}
Expand Down
10 changes: 5 additions & 5 deletions Libraries/MLXLMCommon/Registries/ProcessorTypeRegistry.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,26 @@ public actor ProcessorTypeRegistry {
}

/// Creates a registry with given creators.
public init(creators: [String: (URL, any Tokenizer) throws -> any UserInputProcessor]) {
public init(creators: [String: (Data, any Tokenizer) throws -> any UserInputProcessor]) {
self.creators = creators
}

private var creators: [String: (URL, any Tokenizer) throws -> any UserInputProcessor]
private var creators: [String: (Data, any Tokenizer) throws -> any UserInputProcessor]

/// Add a new processor type to the registry.
public func registerProcessorType(
_ type: String,
creator:
@escaping (
URL,
Data,
any Tokenizer
) throws -> any UserInputProcessor
) {
creators[type] = creator
}

/// Given a `processorType` and configuration file instantiate a new `UserInputProcessor`.
public func createModel(configuration: URL, processorType: String, tokenizer: any Tokenizer)
/// Given a `processorType` and configuration data instantiate a new `UserInputProcessor`.
public func createModel(configuration: Data, processorType: String, tokenizer: any Tokenizer)
throws -> sending any UserInputProcessor
{
guard let creator = creators[processorType] else {
Expand Down
Loading