Skip to content

Commit afb9dec

Browse files
committed
Parallelize loading of weights, tokenizer, and processor config
1 parent 54cf942 commit afb9dec

File tree

7 files changed

+83
-108
lines changed

7 files changed

+83
-108
lines changed

Libraries/Embedders/Configuration.swift

Lines changed: 23 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -33,60 +33,42 @@ private class ModelTypeRegistry: @unchecked Sendable {
3333
// to remain synchronous.
3434
private let lock = NSLock()
3535

36-
private var creators: [String: @Sendable (URL) throws -> EmbeddingModel] = [
37-
"bert": {
38-
url in
39-
let configuration = try JSONDecoder().decode(
40-
BertConfiguration.self, from: Data(contentsOf: url))
41-
let model = BertModel(configuration)
42-
return model
36+
private var creators: [String: @Sendable (Data) throws -> EmbeddingModel] = [
37+
"bert": { data in
38+
let configuration = try JSONDecoder().decode(BertConfiguration.self, from: data)
39+
return BertModel(configuration)
4340
},
44-
"roberta": {
45-
url in
46-
let configuration = try JSONDecoder().decode(
47-
BertConfiguration.self, from: Data(contentsOf: url))
48-
let model = BertModel(configuration)
49-
return model
41+
"roberta": { data in
42+
let configuration = try JSONDecoder().decode(BertConfiguration.self, from: data)
43+
return BertModel(configuration)
5044
},
51-
"xlm-roberta": {
52-
url in
53-
let configuration = try JSONDecoder().decode(
54-
BertConfiguration.self, from: Data(contentsOf: url))
55-
let model = BertModel(configuration)
56-
return model
45+
"xlm-roberta": { data in
46+
let configuration = try JSONDecoder().decode(BertConfiguration.self, from: data)
47+
return BertModel(configuration)
5748
},
58-
"distilbert": {
59-
url in
60-
let configuration = try JSONDecoder().decode(
61-
BertConfiguration.self, from: Data(contentsOf: url))
62-
let model = BertModel(configuration)
63-
return model
49+
"distilbert": { data in
50+
let configuration = try JSONDecoder().decode(BertConfiguration.self, from: data)
51+
return BertModel(configuration)
6452
},
65-
"nomic_bert": {
66-
url in
67-
let configuration = try JSONDecoder().decode(
68-
NomicBertConfiguration.self, from: Data(contentsOf: url))
69-
let model = NomicBertModel(configuration, pooler: false)
70-
return model
53+
"nomic_bert": { data in
54+
let configuration = try JSONDecoder().decode(NomicBertConfiguration.self, from: data)
55+
return NomicBertModel(configuration, pooler: false)
7156
},
72-
"qwen3": {
73-
url in
74-
let configuration = try JSONDecoder().decode(
75-
Qwen3Configuration.self, from: Data(contentsOf: url))
76-
let model = Qwen3Model(configuration)
77-
return model
57+
"qwen3": { data in
58+
let configuration = try JSONDecoder().decode(Qwen3Configuration.self, from: data)
59+
return Qwen3Model(configuration)
7860
},
7961
]
8062

8163
public func registerModelType(
82-
_ type: String, creator: @Sendable @escaping (URL) throws -> EmbeddingModel
64+
_ type: String, creator: @Sendable @escaping (Data) throws -> EmbeddingModel
8365
) {
8466
lock.withLock {
8567
creators[type] = creator
8668
}
8769
}
8870

89-
public func createModel(configuration: URL, rawValue: String) throws -> EmbeddingModel {
71+
public func createModel(configuration: Data, rawValue: String) throws -> EmbeddingModel {
9072
let creator = lock.withLock {
9173
creators[rawValue]
9274
}
@@ -108,12 +90,12 @@ public struct ModelType: RawRepresentable, Codable, Sendable {
10890
}
10991

11092
public static func registerModelType(
111-
_ type: String, creator: @Sendable @escaping (URL) throws -> EmbeddingModel
93+
_ type: String, creator: @Sendable @escaping (Data) throws -> EmbeddingModel
11294
) {
11395
modelTypeRegistry.registerModelType(type, creator: creator)
11496
}
11597

116-
public func createModel(configuration: URL) throws -> EmbeddingModel {
98+
public func createModel(configuration: Data) throws -> EmbeddingModel {
11799
try modelTypeRegistry.createModel(configuration: configuration, rawValue: rawValue)
118100
}
119101
}

Libraries/Embedders/EmbeddingModel.swift

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,15 @@ public actor ModelContainer {
4646
public init(
4747
hub: HubApi, modelDirectory: URL, configuration: ModelConfiguration
4848
) async throws {
49+
async let tokenizerConfigTask = loadTokenizerConfig(
50+
configuration: configuration, hub: hub)
51+
4952
self.model = try loadSynchronous(modelDirectory: modelDirectory)
53+
self.pooler = loadPooling(modelDirectory: modelDirectory)
5054

51-
let (tokenizerConfig, tokenizerData) = try await loadTokenizerConfig(
52-
configuration: configuration, hub: hub)
55+
let (tokenizerConfig, tokenizerData) = try await tokenizerConfigTask
5356
self.tokenizer = try PreTrainedTokenizer(
5457
tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
55-
self.pooler = loadPooling(modelDirectory: modelDirectory) //?? Pooling(strategy: .none)
5658
}
5759

5860
/// Perform an action on the model and/or tokenizer. Callers _must_ eval any `MLXArray` before returning as

Libraries/Embedders/Load.swift

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,20 +49,22 @@ public func load(
4949
) async throws -> (EmbeddingModel, Tokenizer) {
5050
let modelDirectory = try await prepareModelDirectory(
5151
hub: hub, configuration: configuration, progressHandler: progressHandler)
52+
53+
async let tokenizerTask = loadTokenizer(configuration: configuration, hub: hub)
5254
let model = try loadSynchronous(modelDirectory: modelDirectory)
53-
let tokenizer = try await loadTokenizer(configuration: configuration, hub: hub)
55+
let tokenizer = try await tokenizerTask
5456

5557
return (model, tokenizer)
5658
}
5759

5860
func loadSynchronous(modelDirectory: URL) throws -> EmbeddingModel {
59-
// create the model (no weights loaded)
61+
// Load config.json once and decode for both base config and model-specific config
6062
let configurationURL = modelDirectory.appending(component: "config.json")
61-
let baseConfig = try JSONDecoder().decode(
62-
BaseConfiguration.self, from: Data(contentsOf: configurationURL))
63+
let configData = try Data(contentsOf: configurationURL)
64+
let baseConfig = try JSONDecoder().decode(BaseConfiguration.self, from: configData)
6365

6466
let modelType = ModelType(rawValue: baseConfig.modelType)
65-
let model = try modelType.createModel(configuration: configurationURL)
67+
let model = try modelType.createModel(configuration: configData)
6668

6769
// load the weights
6870
var weights = [String: MLXArray]()

Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,12 @@ import MLX
66
import MLXLMCommon
77
import Tokenizers
88

9-
/// Creates a function that loads a configuration file and instantiates a model with the proper configuration
9+
/// Creates a function that decodes configuration data and instantiates a model with the proper configuration
1010
private func create<C: Codable, M>(
1111
_ configurationType: C.Type, _ modelInit: @escaping (C) -> M
12-
) -> (URL) throws -> M {
13-
{ url in
14-
let configuration = try JSONDecoder().decode(
15-
C.self, from: Data(contentsOf: url))
12+
) -> (Data) throws -> M {
13+
{ data in
14+
let configuration = try JSONDecoder().decode(C.self, from: data)
1615
return modelInit(configuration)
1716
}
1817
}
@@ -478,13 +477,13 @@ public final class LLMModelFactory: ModelFactory {
478477
let modelDirectory = try await downloadModel(
479478
hub: hub, configuration: configuration, progressHandler: progressHandler)
480479

481-
// Load the generic config to understand which model and how to load the weights
480+
// Load config.json once and decode for both base config and model-specific config
482481
let configurationURL = modelDirectory.appending(component: "config.json")
483-
482+
let configData: Data
484483
let baseConfig: BaseConfiguration
485484
do {
486-
baseConfig = try JSONDecoder().decode(
487-
BaseConfiguration.self, from: Data(contentsOf: configurationURL))
485+
configData = try Data(contentsOf: configurationURL)
486+
baseConfig = try JSONDecoder().decode(BaseConfiguration.self, from: configData)
488487
} catch let error as DecodingError {
489488
throw ModelFactoryError.configurationDecodingError(
490489
configurationURL.lastPathComponent, configuration.name, error)
@@ -493,18 +492,20 @@ public final class LLMModelFactory: ModelFactory {
493492
let model: LanguageModel
494493
do {
495494
model = try await typeRegistry.createModel(
496-
configuration: configurationURL, modelType: baseConfig.modelType)
495+
configuration: configData, modelType: baseConfig.modelType)
497496
} catch let error as DecodingError {
498497
throw ModelFactoryError.configurationDecodingError(
499498
configurationURL.lastPathComponent, configuration.name, error)
500499
}
501500

502-
// apply the weights to the bare model
501+
// Load weights and tokenizer in parallel
502+
async let tokenizerTask = loadTokenizer(configuration: configuration, hub: hub)
503+
503504
try loadWeights(
504505
modelDirectory: modelDirectory, model: model,
505506
perLayerQuantization: baseConfig.perLayerQuantization)
506507

507-
let tokenizer = try await loadTokenizer(configuration: configuration, hub: hub)
508+
let tokenizer = try await tokenizerTask
508509

509510
let messageGenerator =
510511
if let model = model as? LLMModel {

Libraries/MLXLMCommon/Registries/ModelTypeRegistry.swift

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,22 @@ public actor ModelTypeRegistry {
1010
}
1111

1212
/// Creates a registry with given creators.
13-
public init(creators: [String: (URL) throws -> any LanguageModel]) {
13+
public init(creators: [String: (Data) throws -> any LanguageModel]) {
1414
self.creators = creators
1515
}
1616

17-
private var creators: [String: (URL) throws -> any LanguageModel]
17+
private var creators: [String: (Data) throws -> any LanguageModel]
1818

1919
/// Add a new model to the type registry.
2020
public func registerModelType(
21-
_ type: String, creator: @escaping (URL) throws -> any LanguageModel
21+
_ type: String, creator: @escaping (Data) throws -> any LanguageModel
2222
) {
2323
creators[type] = creator
2424
}
2525

26-
/// Given a `modelType` and configuration file instantiate a new `LanguageModel`.
27-
public func createModel(configuration: URL, modelType: String) throws -> sending LanguageModel {
26+
/// Given a `modelType` and configuration data instantiate a new `LanguageModel`.
27+
public func createModel(configuration: Data, modelType: String) throws -> sending LanguageModel
28+
{
2829
guard let creator = creators[modelType] else {
2930
throw ModelFactoryError.unsupportedModelType(modelType)
3031
}

Libraries/MLXLMCommon/Registries/ProcessorTypeRegistry.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,26 +11,26 @@ public actor ProcessorTypeRegistry {
1111
}
1212

1313
/// Creates a registry with given creators.
14-
public init(creators: [String: (URL, any Tokenizer) throws -> any UserInputProcessor]) {
14+
public init(creators: [String: (Data, any Tokenizer) throws -> any UserInputProcessor]) {
1515
self.creators = creators
1616
}
1717

18-
private var creators: [String: (URL, any Tokenizer) throws -> any UserInputProcessor]
18+
private var creators: [String: (Data, any Tokenizer) throws -> any UserInputProcessor]
1919

2020
/// Add a new model to the type registry.
2121
public func registerProcessorType(
2222
_ type: String,
2323
creator:
2424
@escaping (
25-
URL,
25+
Data,
2626
any Tokenizer
2727
) throws -> any UserInputProcessor
2828
) {
2929
creators[type] = creator
3030
}
3131

32-
/// Given a `processorType` and configuration file instantiate a new `UserInputProcessor`.
33-
public func createModel(configuration: URL, processorType: String, tokenizer: any Tokenizer)
32+
/// Given a `processorType` and configuration data instantiate a new `UserInputProcessor`.
33+
public func createModel(configuration: Data, processorType: String, tokenizer: any Tokenizer)
3434
throws -> sending any UserInputProcessor
3535
{
3636
guard let creator = creators[processorType] else {

Libraries/MLXVLM/VLMModelFactory.swift

Lines changed: 24 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,9 @@ public struct BaseProcessorConfiguration: Codable, Sendable {
4848
/// Creates a function that loads a configuration file and instantiates a model with the proper configuration
4949
private func create<C: Codable, M>(
5050
_ configurationType: C.Type, _ modelInit: @escaping (C) -> M
51-
) -> (URL) throws -> M {
52-
{ url in
53-
let configuration = try JSONDecoder().decode(
54-
C.self, from: Data(contentsOf: url))
51+
) -> (Data) throws -> M {
52+
{ data in
53+
let configuration = try JSONDecoder().decode(C.self, from: data)
5554
return modelInit(configuration)
5655
}
5756
}
@@ -63,10 +62,9 @@ private func create<C: Codable, P>(
6362
C,
6463
any Tokenizer
6564
) -> P
66-
) -> (URL, any Tokenizer) throws -> P {
67-
{ url, tokenizer in
68-
let configuration = try JSONDecoder().decode(
69-
C.self, from: Data(contentsOf: url))
65+
) -> (Data, any Tokenizer) throws -> P {
66+
{ data, tokenizer in
67+
let configuration = try JSONDecoder().decode(C.self, from: data)
7068
return processorInit(configuration, tokenizer)
7169
}
7270
}
@@ -247,15 +245,13 @@ public final class VLMModelFactory: ModelFactory {
247245
let modelDirectory = try await downloadModel(
248246
hub: hub, configuration: configuration, progressHandler: progressHandler)
249247

250-
// load the generic config to understand which model and how to load the weights
251-
let configurationURL = modelDirectory.appending(
252-
component: "config.json"
253-
)
254-
248+
// Load config.json once and decode for both base config and model-specific config
249+
let configurationURL = modelDirectory.appending(component: "config.json")
250+
let configData: Data
255251
let baseConfig: BaseConfiguration
256252
do {
257-
baseConfig = try JSONDecoder().decode(
258-
BaseConfiguration.self, from: Data(contentsOf: configurationURL))
253+
configData = try Data(contentsOf: configurationURL)
254+
baseConfig = try JSONDecoder().decode(BaseConfiguration.self, from: configData)
259255
} catch let error as DecodingError {
260256
throw ModelFactoryError.configurationDecodingError(
261257
configurationURL.lastPathComponent, configuration.name, error)
@@ -264,39 +260,30 @@ public final class VLMModelFactory: ModelFactory {
264260
let model: LanguageModel
265261
do {
266262
model = try await typeRegistry.createModel(
267-
configuration: configurationURL, modelType: baseConfig.modelType)
263+
configuration: configData, modelType: baseConfig.modelType)
268264
} catch let error as DecodingError {
269265
throw ModelFactoryError.configurationDecodingError(
270266
configurationURL.lastPathComponent, configuration.name, error)
271267
}
272268

273-
// apply the weights to the bare model
269+
// Load weights, tokenizer, and processor config in parallel
270+
async let tokenizerTask = loadTokenizer(configuration: configuration, hub: hub)
271+
async let processorConfigTask: (Data, BaseProcessorConfiguration) = {
272+
let url = modelDirectory.appending(component: "preprocessor_config.json")
273+
let data = try Data(contentsOf: url)
274+
let config = try JSONDecoder().decode(BaseProcessorConfiguration.self, from: data)
275+
return (data, config)
276+
}()
277+
274278
try loadWeights(
275279
modelDirectory: modelDirectory, model: model,
276280
perLayerQuantization: baseConfig.perLayerQuantization)
277281

278-
let tokenizer = try await loadTokenizer(
279-
configuration: configuration,
280-
hub: hub
281-
)
282-
283-
let processorConfigurationURL = modelDirectory.appending(
284-
component: "preprocessor_config.json"
285-
)
286-
287-
let baseProcessorConfig: BaseProcessorConfiguration
288-
do {
289-
baseProcessorConfig = try JSONDecoder().decode(
290-
BaseProcessorConfiguration.self,
291-
from: Data(contentsOf: processorConfigurationURL)
292-
)
293-
} catch let error as DecodingError {
294-
throw ModelFactoryError.configurationDecodingError(
295-
processorConfigurationURL.lastPathComponent, configuration.name, error)
296-
}
282+
let tokenizer = try await tokenizerTask
283+
let (processorConfigData, baseProcessorConfig) = try await processorConfigTask
297284

298285
let processor = try await processorRegistry.createModel(
299-
configuration: processorConfigurationURL,
286+
configuration: processorConfigData,
300287
processorType: baseProcessorConfig.processorClass, tokenizer: tokenizer)
301288

302289
return .init(

0 commit comments

Comments
 (0)