Skip to content

Commit 0417a58

Browse files
committed
Parallelize loading of weights, tokenizer, and processor config
1 parent 54cf942 commit 0417a58

File tree

4 files changed

+48
-59
lines changed

4 files changed

+48
-59
lines changed

Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,12 @@ import MLX
66
import MLXLMCommon
77
import Tokenizers
88

9-
/// Creates a function that loads a configuration file and instantiates a model with the proper configuration
9+
/// Creates a function that decodes configuration data and instantiates a model with the proper configuration
1010
private func create<C: Codable, M>(
1111
_ configurationType: C.Type, _ modelInit: @escaping (C) -> M
12-
) -> (URL) throws -> M {
13-
{ url in
14-
let configuration = try JSONDecoder().decode(
15-
C.self, from: Data(contentsOf: url))
12+
) -> (Data) throws -> M {
13+
{ data in
14+
let configuration = try JSONDecoder().decode(C.self, from: data)
1615
return modelInit(configuration)
1716
}
1817
}
@@ -478,13 +477,13 @@ public final class LLMModelFactory: ModelFactory {
478477
let modelDirectory = try await downloadModel(
479478
hub: hub, configuration: configuration, progressHandler: progressHandler)
480479

481-
// Load the generic config to understand which model and how to load the weights
480+
// Load config.json once and decode for both base config and model-specific config
482481
let configurationURL = modelDirectory.appending(component: "config.json")
483-
482+
let configData: Data
484483
let baseConfig: BaseConfiguration
485484
do {
486-
baseConfig = try JSONDecoder().decode(
487-
BaseConfiguration.self, from: Data(contentsOf: configurationURL))
485+
configData = try Data(contentsOf: configurationURL)
486+
baseConfig = try JSONDecoder().decode(BaseConfiguration.self, from: configData)
488487
} catch let error as DecodingError {
489488
throw ModelFactoryError.configurationDecodingError(
490489
configurationURL.lastPathComponent, configuration.name, error)
@@ -493,18 +492,20 @@ public final class LLMModelFactory: ModelFactory {
493492
let model: LanguageModel
494493
do {
495494
model = try await typeRegistry.createModel(
496-
configuration: configurationURL, modelType: baseConfig.modelType)
495+
configuration: configData, modelType: baseConfig.modelType)
497496
} catch let error as DecodingError {
498497
throw ModelFactoryError.configurationDecodingError(
499498
configurationURL.lastPathComponent, configuration.name, error)
500499
}
501500

502-
// apply the weights to the bare model
501+
// Load weights and tokenizer in parallel
502+
async let tokenizerTask = loadTokenizer(configuration: configuration, hub: hub)
503+
503504
try loadWeights(
504505
modelDirectory: modelDirectory, model: model,
505506
perLayerQuantization: baseConfig.perLayerQuantization)
506507

507-
let tokenizer = try await loadTokenizer(configuration: configuration, hub: hub)
508+
let tokenizer = try await tokenizerTask
508509

509510
let messageGenerator =
510511
if let model = model as? LLMModel {

Libraries/MLXLMCommon/Registries/ModelTypeRegistry.swift

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,22 @@ public actor ModelTypeRegistry {
1010
}
1111

1212
/// Creates a registry with given creators.
13-
public init(creators: [String: (URL) throws -> any LanguageModel]) {
13+
public init(creators: [String: (Data) throws -> any LanguageModel]) {
1414
self.creators = creators
1515
}
1616

17-
private var creators: [String: (URL) throws -> any LanguageModel]
17+
private var creators: [String: (Data) throws -> any LanguageModel]
1818

1919
/// Add a new model to the type registry.
2020
public func registerModelType(
21-
_ type: String, creator: @escaping (URL) throws -> any LanguageModel
21+
_ type: String, creator: @escaping (Data) throws -> any LanguageModel
2222
) {
2323
creators[type] = creator
2424
}
2525

26-
/// Given a `modelType` and configuration file instantiate a new `LanguageModel`.
27-
public func createModel(configuration: URL, modelType: String) throws -> sending LanguageModel {
26+
/// Given a `modelType` and configuration data instantiate a new `LanguageModel`.
27+
public func createModel(configuration: Data, modelType: String) throws -> sending LanguageModel
28+
{
2829
guard let creator = creators[modelType] else {
2930
throw ModelFactoryError.unsupportedModelType(modelType)
3031
}

Libraries/MLXLMCommon/Registries/ProcessorTypeRegistry.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,26 +11,26 @@ public actor ProcessorTypeRegistry {
1111
}
1212

1313
/// Creates a registry with given creators.
14-
public init(creators: [String: (URL, any Tokenizer) throws -> any UserInputProcessor]) {
14+
public init(creators: [String: (Data, any Tokenizer) throws -> any UserInputProcessor]) {
1515
self.creators = creators
1616
}
1717

18-
private var creators: [String: (URL, any Tokenizer) throws -> any UserInputProcessor]
18+
private var creators: [String: (Data, any Tokenizer) throws -> any UserInputProcessor]
1919

2020
/// Add a new processor type to the registry.
2121
public func registerProcessorType(
2222
_ type: String,
2323
creator:
2424
@escaping (
25-
URL,
25+
Data,
2626
any Tokenizer
2727
) throws -> any UserInputProcessor
2828
) {
2929
creators[type] = creator
3030
}
3131

32-
/// Given a `processorType` and configuration file instantiate a new `UserInputProcessor`.
33-
public func createModel(configuration: URL, processorType: String, tokenizer: any Tokenizer)
32+
/// Given a `processorType` and configuration data instantiate a new `UserInputProcessor`.
33+
public func createModel(configuration: Data, processorType: String, tokenizer: any Tokenizer)
3434
throws -> sending any UserInputProcessor
3535
{
3636
guard let creator = creators[processorType] else {

Libraries/MLXVLM/VLMModelFactory.swift

Lines changed: 24 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,9 @@ public struct BaseProcessorConfiguration: Codable, Sendable {
4848
/// Creates a function that decodes configuration data and instantiates a model with the proper configuration
4949
private func create<C: Codable, M>(
5050
_ configurationType: C.Type, _ modelInit: @escaping (C) -> M
51-
) -> (URL) throws -> M {
52-
{ url in
53-
let configuration = try JSONDecoder().decode(
54-
C.self, from: Data(contentsOf: url))
51+
) -> (Data) throws -> M {
52+
{ data in
53+
let configuration = try JSONDecoder().decode(C.self, from: data)
5554
return modelInit(configuration)
5655
}
5756
}
@@ -63,10 +62,9 @@ private func create<C: Codable, P>(
6362
C,
6463
any Tokenizer
6564
) -> P
66-
) -> (URL, any Tokenizer) throws -> P {
67-
{ url, tokenizer in
68-
let configuration = try JSONDecoder().decode(
69-
C.self, from: Data(contentsOf: url))
65+
) -> (Data, any Tokenizer) throws -> P {
66+
{ data, tokenizer in
67+
let configuration = try JSONDecoder().decode(C.self, from: data)
7068
return processorInit(configuration, tokenizer)
7169
}
7270
}
@@ -247,15 +245,13 @@ public final class VLMModelFactory: ModelFactory {
247245
let modelDirectory = try await downloadModel(
248246
hub: hub, configuration: configuration, progressHandler: progressHandler)
249247

250-
// load the generic config to understand which model and how to load the weights
251-
let configurationURL = modelDirectory.appending(
252-
component: "config.json"
253-
)
254-
248+
// Load config.json once and decode for both base config and model-specific config
249+
let configurationURL = modelDirectory.appending(component: "config.json")
250+
let configData: Data
255251
let baseConfig: BaseConfiguration
256252
do {
257-
baseConfig = try JSONDecoder().decode(
258-
BaseConfiguration.self, from: Data(contentsOf: configurationURL))
253+
configData = try Data(contentsOf: configurationURL)
254+
baseConfig = try JSONDecoder().decode(BaseConfiguration.self, from: configData)
259255
} catch let error as DecodingError {
260256
throw ModelFactoryError.configurationDecodingError(
261257
configurationURL.lastPathComponent, configuration.name, error)
@@ -264,39 +260,30 @@ public final class VLMModelFactory: ModelFactory {
264260
let model: LanguageModel
265261
do {
266262
model = try await typeRegistry.createModel(
267-
configuration: configurationURL, modelType: baseConfig.modelType)
263+
configuration: configData, modelType: baseConfig.modelType)
268264
} catch let error as DecodingError {
269265
throw ModelFactoryError.configurationDecodingError(
270266
configurationURL.lastPathComponent, configuration.name, error)
271267
}
272268

273-
// apply the weights to the bare model
269+
// Load weights, tokenizer, and processor config in parallel
270+
async let tokenizerTask = loadTokenizer(configuration: configuration, hub: hub)
271+
async let processorConfigTask: (Data, BaseProcessorConfiguration) = {
272+
let url = modelDirectory.appending(component: "preprocessor_config.json")
273+
let data = try Data(contentsOf: url)
274+
let config = try JSONDecoder().decode(BaseProcessorConfiguration.self, from: data)
275+
return (data, config)
276+
}()
277+
274278
try loadWeights(
275279
modelDirectory: modelDirectory, model: model,
276280
perLayerQuantization: baseConfig.perLayerQuantization)
277281

278-
let tokenizer = try await loadTokenizer(
279-
configuration: configuration,
280-
hub: hub
281-
)
282-
283-
let processorConfigurationURL = modelDirectory.appending(
284-
component: "preprocessor_config.json"
285-
)
286-
287-
let baseProcessorConfig: BaseProcessorConfiguration
288-
do {
289-
baseProcessorConfig = try JSONDecoder().decode(
290-
BaseProcessorConfiguration.self,
291-
from: Data(contentsOf: processorConfigurationURL)
292-
)
293-
} catch let error as DecodingError {
294-
throw ModelFactoryError.configurationDecodingError(
295-
processorConfigurationURL.lastPathComponent, configuration.name, error)
296-
}
282+
let tokenizer = try await tokenizerTask
283+
let (processorConfigData, baseProcessorConfig) = try await processorConfigTask
297284

298285
let processor = try await processorRegistry.createModel(
299-
configuration: processorConfigurationURL,
286+
configuration: processorConfigData,
300287
processorType: baseProcessorConfig.processorClass, tokenizer: tokenizer)
301288

302289
return .init(

0 commit comments

Comments
 (0)