Skip to content

Commit 27a2f21

Browse files
Optimize model loading performance (#34)
* Add model loading benchmarks
* Parallelize loading of weights, tokenizer, and processor config
* Improve error handling
* Clarify parallelism in comments
1 parent 5d89cc9 commit 27a2f21

File tree

11 files changed

+322
-117
lines changed

Libraries/Embedders/Configuration.swift

Lines changed: 24 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -33,65 +33,47 @@ private class ModelTypeRegistry: @unchecked Sendable {
3333
// to remain synchronous.
3434
private let lock = NSLock()
3535

36-
private var creators: [String: @Sendable (URL) throws -> EmbeddingModel] = [
37-
"bert": {
38-
url in
39-
let configuration = try JSONDecoder().decode(
40-
BertConfiguration.self, from: Data(contentsOf: url))
41-
let model = BertModel(configuration)
42-
return model
36+
private var creators: [String: @Sendable (Data) throws -> EmbeddingModel] = [
37+
"bert": { data in
38+
let configuration = try JSONDecoder().decode(BertConfiguration.self, from: data)
39+
return BertModel(configuration)
4340
},
44-
"roberta": {
45-
url in
46-
let configuration = try JSONDecoder().decode(
47-
BertConfiguration.self, from: Data(contentsOf: url))
48-
let model = BertModel(configuration)
49-
return model
41+
"roberta": { data in
42+
let configuration = try JSONDecoder().decode(BertConfiguration.self, from: data)
43+
return BertModel(configuration)
5044
},
51-
"xlm-roberta": {
52-
url in
53-
let configuration = try JSONDecoder().decode(
54-
BertConfiguration.self, from: Data(contentsOf: url))
55-
let model = BertModel(configuration)
56-
return model
45+
"xlm-roberta": { data in
46+
let configuration = try JSONDecoder().decode(BertConfiguration.self, from: data)
47+
return BertModel(configuration)
5748
},
58-
"distilbert": {
59-
url in
60-
let configuration = try JSONDecoder().decode(
61-
BertConfiguration.self, from: Data(contentsOf: url))
62-
let model = BertModel(configuration)
63-
return model
49+
"distilbert": { data in
50+
let configuration = try JSONDecoder().decode(BertConfiguration.self, from: data)
51+
return BertModel(configuration)
6452
},
65-
"nomic_bert": {
66-
url in
67-
let configuration = try JSONDecoder().decode(
68-
NomicBertConfiguration.self, from: Data(contentsOf: url))
69-
let model = NomicBertModel(configuration, pooler: false)
70-
return model
53+
"nomic_bert": { data in
54+
let configuration = try JSONDecoder().decode(NomicBertConfiguration.self, from: data)
55+
return NomicBertModel(configuration, pooler: false)
7156
},
72-
"qwen3": {
73-
url in
74-
let configuration = try JSONDecoder().decode(
75-
Qwen3Configuration.self, from: Data(contentsOf: url))
76-
let model = Qwen3Model(configuration)
77-
return model
57+
"qwen3": { data in
58+
let configuration = try JSONDecoder().decode(Qwen3Configuration.self, from: data)
59+
return Qwen3Model(configuration)
7860
},
7961
]
8062

8163
public func registerModelType(
82-
_ type: String, creator: @Sendable @escaping (URL) throws -> EmbeddingModel
64+
_ type: String, creator: @Sendable @escaping (Data) throws -> EmbeddingModel
8365
) {
8466
lock.withLock {
8567
creators[type] = creator
8668
}
8769
}
8870

89-
public func createModel(configuration: URL, rawValue: String) throws -> EmbeddingModel {
71+
public func createModel(configuration: Data, rawValue: String) throws -> EmbeddingModel {
9072
let creator = lock.withLock {
9173
creators[rawValue]
9274
}
9375
guard let creator else {
94-
throw EmbedderError(message: "Unsupported model type.")
76+
throw EmbedderError.unsupportedModelType(rawValue)
9577
}
9678
return try creator(configuration)
9779
}
@@ -108,12 +90,12 @@ public struct ModelType: RawRepresentable, Codable, Sendable {
10890
}
10991

11092
public static func registerModelType(
111-
_ type: String, creator: @Sendable @escaping (URL) throws -> EmbeddingModel
93+
_ type: String, creator: @Sendable @escaping (Data) throws -> EmbeddingModel
11294
) {
11395
modelTypeRegistry.registerModelType(type, creator: creator)
11496
}
11597

116-
public func createModel(configuration: URL) throws -> EmbeddingModel {
98+
public func createModel(configuration: Data) throws -> EmbeddingModel {
11799
try modelTypeRegistry.createModel(configuration: configuration, rawValue: rawValue)
118100
}
119101
}

Libraries/Embedders/EmbeddingModel.swift

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -46,13 +46,17 @@ public actor ModelContainer {
4646
public init(
4747
hub: HubApi, modelDirectory: URL, configuration: ModelConfiguration
4848
) async throws {
49-
self.model = try loadSynchronous(modelDirectory: modelDirectory)
50-
51-
let (tokenizerConfig, tokenizerData) = try await loadTokenizerConfig(
49+
// Load tokenizer config and model in parallel using async let.
50+
async let tokenizerConfigTask = loadTokenizerConfig(
5251
configuration: configuration, hub: hub)
52+
53+
self.model = try loadSynchronous(
54+
modelDirectory: modelDirectory, modelName: configuration.name)
55+
self.pooler = loadPooling(modelDirectory: modelDirectory)
56+
57+
let (tokenizerConfig, tokenizerData) = try await tokenizerConfigTask
5358
self.tokenizer = try PreTrainedTokenizer(
5459
tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
55-
self.pooler = loadPooling(modelDirectory: modelDirectory) //?? Pooling(strategy: .none)
5660
}
5761

5862
/// Perform an action on the model and/or tokenizer. Callers _must_ eval any `MLXArray` before returning as

Libraries/Embedders/Load.swift

Lines changed: 70 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -6,8 +6,48 @@ import MLX
66
import MLXNN
77
import Tokenizers
88

9-
struct EmbedderError: Error {
10-
let message: String
9+
public enum EmbedderError: LocalizedError {
10+
case unsupportedModelType(String)
11+
case configurationFileError(String, String, Error)
12+
case configurationDecodingError(String, String, DecodingError)
13+
case missingTokenizerConfig
14+
15+
public var errorDescription: String? {
16+
switch self {
17+
case .unsupportedModelType(let type):
18+
return "Unsupported model type: \(type)"
19+
case .configurationFileError(let file, let modelName, let error):
20+
return "Error reading '\(file)' for model '\(modelName)': \(error.localizedDescription)"
21+
case .configurationDecodingError(let file, let modelName, let decodingError):
22+
let errorDetail = extractDecodingErrorDetail(decodingError)
23+
return "Failed to parse \(file) for model '\(modelName)': \(errorDetail)"
24+
case .missingTokenizerConfig:
25+
return "Missing tokenizer configuration"
26+
}
27+
}
28+
29+
private func extractDecodingErrorDetail(_ error: DecodingError) -> String {
30+
switch error {
31+
case .keyNotFound(let key, let context):
32+
let path = (context.codingPath + [key]).map { $0.stringValue }.joined(separator: ".")
33+
return "Missing field '\(path)'"
34+
case .typeMismatch(_, let context):
35+
let path = context.codingPath.map { $0.stringValue }.joined(separator: ".")
36+
return "Type mismatch at '\(path)'"
37+
case .valueNotFound(_, let context):
38+
let path = context.codingPath.map { $0.stringValue }.joined(separator: ".")
39+
return "Missing value at '\(path)'"
40+
case .dataCorrupted(let context):
41+
if context.codingPath.isEmpty {
42+
return "Invalid JSON"
43+
} else {
44+
let path = context.codingPath.map { $0.stringValue }.joined(separator: ".")
45+
return "Invalid data at '\(path)'"
46+
}
47+
@unknown default:
48+
return error.localizedDescription
49+
}
50+
}
1151
}
1252

1353
func prepareModelDirectory(
@@ -49,20 +89,41 @@ public func load(
4989
) async throws -> (EmbeddingModel, Tokenizer) {
5090
let modelDirectory = try await prepareModelDirectory(
5191
hub: hub, configuration: configuration, progressHandler: progressHandler)
52-
let model = try loadSynchronous(modelDirectory: modelDirectory)
53-
let tokenizer = try await loadTokenizer(configuration: configuration, hub: hub)
92+
93+
// Load tokenizer and model in parallel using async let.
94+
async let tokenizerTask = loadTokenizer(configuration: configuration, hub: hub)
95+
let model = try loadSynchronous(modelDirectory: modelDirectory, modelName: configuration.name)
96+
let tokenizer = try await tokenizerTask
5497

5598
return (model, tokenizer)
5699
}
57100

58-
func loadSynchronous(modelDirectory: URL) throws -> EmbeddingModel {
59-
// create the model (no weights loaded)
101+
func loadSynchronous(modelDirectory: URL, modelName: String) throws -> EmbeddingModel {
102+
// Load config.json once and decode for both base config and model-specific config
60103
let configurationURL = modelDirectory.appending(component: "config.json")
61-
let baseConfig = try JSONDecoder().decode(
62-
BaseConfiguration.self, from: Data(contentsOf: configurationURL))
104+
let configData: Data
105+
do {
106+
configData = try Data(contentsOf: configurationURL)
107+
} catch {
108+
throw EmbedderError.configurationFileError(
109+
configurationURL.lastPathComponent, modelName, error)
110+
}
111+
let baseConfig: BaseConfiguration
112+
do {
113+
baseConfig = try JSONDecoder().decode(BaseConfiguration.self, from: configData)
114+
} catch let error as DecodingError {
115+
throw EmbedderError.configurationDecodingError(
116+
configurationURL.lastPathComponent, modelName, error)
117+
}
63118

64119
let modelType = ModelType(rawValue: baseConfig.modelType)
65-
let model = try modelType.createModel(configuration: configurationURL)
120+
let model: EmbeddingModel
121+
do {
122+
model = try modelType.createModel(configuration: configData)
123+
} catch let error as DecodingError {
124+
throw EmbedderError.configurationDecodingError(
125+
configurationURL.lastPathComponent, modelName, error)
126+
}
66127

67128
// load the weights
68129
var weights = [String: MLXArray]()

Libraries/Embedders/Tokenizer.swift

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,7 @@ func loadTokenizerConfig(configuration: ModelConfiguration, hub: HubApi) async t
4545
}
4646

4747
guard let tokenizerConfig = try await config.tokenizerConfig else {
48-
throw EmbedderError(message: "missing config")
48+
throw EmbedderError.missingTokenizerConfig
4949
}
5050
let tokenizerData = try await config.tokenizerData
5151
return (tokenizerConfig, tokenizerData)

Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 18 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -6,13 +6,12 @@ import MLX
66
import MLXLMCommon
77
import Tokenizers
88

9-
/// Creates a function that loads a configuration file and instantiates a model with the proper configuration
9+
/// Creates a function that decodes configuration data and instantiates a model with the proper configuration
1010
private func create<C: Codable, M>(
1111
_ configurationType: C.Type, _ modelInit: @escaping (C) -> M
12-
) -> (URL) throws -> M {
13-
{ url in
14-
let configuration = try JSONDecoder().decode(
15-
C.self, from: Data(contentsOf: url))
12+
) -> (Data) throws -> M {
13+
{ data in
14+
let configuration = try JSONDecoder().decode(C.self, from: data)
1615
return modelInit(configuration)
1716
}
1817
}
@@ -479,13 +478,18 @@ public final class LLMModelFactory: ModelFactory {
479478
let modelDirectory = try await downloadModel(
480479
hub: hub, configuration: configuration, progressHandler: progressHandler)
481480

482-
// Load the generic config to understand which model and how to load the weights
481+
// Load config.json once and decode for both base config and model-specific config
483482
let configurationURL = modelDirectory.appending(component: "config.json")
484-
483+
let configData: Data
484+
do {
485+
configData = try Data(contentsOf: configurationURL)
486+
} catch {
487+
throw ModelFactoryError.configurationFileError(
488+
configurationURL.lastPathComponent, configuration.name, error)
489+
}
485490
let baseConfig: BaseConfiguration
486491
do {
487-
baseConfig = try JSONDecoder().decode(
488-
BaseConfiguration.self, from: Data(contentsOf: configurationURL))
492+
baseConfig = try JSONDecoder().decode(BaseConfiguration.self, from: configData)
489493
} catch let error as DecodingError {
490494
throw ModelFactoryError.configurationDecodingError(
491495
configurationURL.lastPathComponent, configuration.name, error)
@@ -494,18 +498,20 @@ public final class LLMModelFactory: ModelFactory {
494498
let model: LanguageModel
495499
do {
496500
model = try await typeRegistry.createModel(
497-
configuration: configurationURL, modelType: baseConfig.modelType)
501+
configuration: configData, modelType: baseConfig.modelType)
498502
} catch let error as DecodingError {
499503
throw ModelFactoryError.configurationDecodingError(
500504
configurationURL.lastPathComponent, configuration.name, error)
501505
}
502506

503-
// apply the weights to the bare model
507+
// Load tokenizer and weights in parallel using async let.
508+
async let tokenizerTask = loadTokenizer(configuration: configuration, hub: hub)
509+
504510
try loadWeights(
505511
modelDirectory: modelDirectory, model: model,
506512
perLayerQuantization: baseConfig.perLayerQuantization)
507513

508-
let tokenizer = try await loadTokenizer(configuration: configuration, hub: hub)
514+
let tokenizer = try await tokenizerTask
509515

510516
let messageGenerator =
511517
if let model = model as? LLMModel {

Libraries/MLXLMCommon/ModelFactory.swift

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@ import Tokenizers
77
public enum ModelFactoryError: LocalizedError {
88
case unsupportedModelType(String)
99
case unsupportedProcessorType(String)
10+
case configurationFileError(String, String, Error)
1011
case configurationDecodingError(String, String, DecodingError)
1112
case noModelFactoryAvailable
1213

@@ -16,6 +17,8 @@ public enum ModelFactoryError: LocalizedError {
1617
return "Unsupported model type: \(type)"
1718
case .unsupportedProcessorType(let type):
1819
return "Unsupported processor type: \(type)"
20+
case .configurationFileError(let file, let modelName, let error):
21+
return "Error reading '\(file)' for model '\(modelName)': \(error.localizedDescription)"
1922
case .noModelFactoryAvailable:
2023
return "No model factory available via ModelFactoryRegistry"
2124
case .configurationDecodingError(let file, let modelName, let decodingError):

Libraries/MLXLMCommon/Registries/ModelTypeRegistry.swift

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -10,21 +10,22 @@ public actor ModelTypeRegistry {
1010
}
1111

1212
/// Creates a registry with given creators.
13-
public init(creators: [String: (URL) throws -> any LanguageModel]) {
13+
public init(creators: [String: (Data) throws -> any LanguageModel]) {
1414
self.creators = creators
1515
}
1616

17-
private var creators: [String: (URL) throws -> any LanguageModel]
17+
private var creators: [String: (Data) throws -> any LanguageModel]
1818

1919
/// Add a new model to the type registry.
2020
public func registerModelType(
21-
_ type: String, creator: @escaping (URL) throws -> any LanguageModel
21+
_ type: String, creator: @escaping (Data) throws -> any LanguageModel
2222
) {
2323
creators[type] = creator
2424
}
2525

26-
/// Given a `modelType` and configuration file instantiate a new `LanguageModel`.
27-
public func createModel(configuration: URL, modelType: String) throws -> sending LanguageModel {
26+
/// Given a `modelType` and configuration data instantiate a new `LanguageModel`.
27+
public func createModel(configuration: Data, modelType: String) throws -> sending LanguageModel
28+
{
2829
guard let creator = creators[modelType] else {
2930
throw ModelFactoryError.unsupportedModelType(modelType)
3031
}

Libraries/MLXLMCommon/Registries/ProcessorTypeRegistry.swift

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -11,26 +11,26 @@ public actor ProcessorTypeRegistry {
1111
}
1212

1313
/// Creates a registry with given creators.
14-
public init(creators: [String: (URL, any Tokenizer) throws -> any UserInputProcessor]) {
14+
public init(creators: [String: (Data, any Tokenizer) throws -> any UserInputProcessor]) {
1515
self.creators = creators
1616
}
1717

18-
private var creators: [String: (URL, any Tokenizer) throws -> any UserInputProcessor]
18+
private var creators: [String: (Data, any Tokenizer) throws -> any UserInputProcessor]
1919

2020
/// Add a new model to the type registry.
2121
public func registerProcessorType(
2222
_ type: String,
2323
creator:
2424
@escaping (
25-
URL,
25+
Data,
2626
any Tokenizer
2727
) throws -> any UserInputProcessor
2828
) {
2929
creators[type] = creator
3030
}
3131

32-
/// Given a `processorType` and configuration file instantiate a new `UserInputProcessor`.
33-
public func createModel(configuration: URL, processorType: String, tokenizer: any Tokenizer)
32+
/// Given a `processorType` and configuration data instantiate a new `UserInputProcessor`.
33+
public func createModel(configuration: Data, processorType: String, tokenizer: any Tokenizer)
3434
throws -> sending any UserInputProcessor
3535
{
3636
guard let creator = creators[processorType] else {

Comments (0)