@@ -6,8 +6,48 @@ import MLX
66import MLXNN
77import Tokenizers
88
/// Errors surfaced while loading an embedding model and its configuration.
///
/// Each case provides a human-readable message through `errorDescription`.
public enum EmbedderError: LocalizedError {
    /// The model type is not supported. Payload: the model type string.
    case unsupportedModelType(String)
    /// A configuration file could not be read.
    /// Payload: file name, model name, underlying error.
    case configurationFileError(String, String, Error)
    /// A configuration file was read but could not be decoded.
    /// Payload: file name, model name, underlying decoding error.
    case configurationDecodingError(String, String, DecodingError)
    /// The tokenizer configuration is missing.
    case missingTokenizerConfig

    /// A user-facing description of the failure.
    public var errorDescription: String? {
        switch self {
        case .unsupportedModelType(let type):
            return "Unsupported model type: \(type)"
        case .configurationFileError(let file, let modelName, let error):
            return "Error reading '\(file)' for model '\(modelName)': \(error.localizedDescription)"
        case .configurationDecodingError(let file, let modelName, let decodingError):
            let errorDetail = extractDecodingErrorDetail(decodingError)
            return "Failed to parse \(file) for model '\(modelName)': \(errorDetail)"
        case .missingTokenizerConfig:
            return "Missing tokenizer configuration"
        }
    }

    /// Condenses a `DecodingError` into a short message naming the failing key path.
    ///
    /// - Parameter error: The decoding error to summarize.
    /// - Returns: A one-line description such as `Missing field 'a.b'`.
    private func extractDecodingErrorDetail(_ error: DecodingError) -> String {
        // Joins the coding path with dots; nil when the failure is at the
        // document root, so callers never print an empty quoted path ('').
        func dottedPath(_ codingPath: [CodingKey]) -> String? {
            guard !codingPath.isEmpty else { return nil }
            return codingPath.map { $0.stringValue }.joined(separator: ".")
        }

        // Quoted path when one exists, otherwise an explicit "top level".
        func location(_ codingPath: [CodingKey]) -> String {
            guard let path = dottedPath(codingPath) else { return "top level" }
            return "'\(path)'"
        }

        switch error {
        case .keyNotFound(let key, let context):
            // The missing key itself is appended, so this path is never empty.
            return "Missing field \(location(context.codingPath + [key]))"
        case .typeMismatch(_, let context):
            return "Type mismatch at \(location(context.codingPath))"
        case .valueNotFound(_, let context):
            return "Missing value at \(location(context.codingPath))"
        case .dataCorrupted(let context):
            // Corruption with no coding path is reported as "Invalid JSON".
            guard let path = dottedPath(context.codingPath) else {
                return "Invalid JSON"
            }
            return "Invalid data at '\(path)'"
        @unknown default:
            return error.localizedDescription
        }
    }
}
1252
1353func prepareModelDirectory(
@@ -49,20 +89,41 @@ public func load(
4989) async throws -> ( EmbeddingModel , Tokenizer ) {
5090 let modelDirectory = try await prepareModelDirectory (
5191 hub: hub, configuration: configuration, progressHandler: progressHandler)
52- let model = try loadSynchronous ( modelDirectory: modelDirectory)
53- let tokenizer = try await loadTokenizer ( configuration: configuration, hub: hub)
92+
93+ // Load tokenizer and model in parallel using async let.
94+ async let tokenizerTask = loadTokenizer ( configuration: configuration, hub: hub)
95+ let model = try loadSynchronous ( modelDirectory: modelDirectory, modelName: configuration. name)
96+ let tokenizer = try await tokenizerTask
5497
5598 return ( model, tokenizer)
5699}
57100
58- func loadSynchronous( modelDirectory: URL ) throws -> EmbeddingModel {
59- // create the model (no weights loaded)
101+ func loadSynchronous( modelDirectory: URL , modelName : String ) throws -> EmbeddingModel {
102+ // Load config.json once and decode for both base config and model-specific config
60103 let configurationURL = modelDirectory. appending ( component: " config.json " )
61- let baseConfig = try JSONDecoder ( ) . decode (
62- BaseConfiguration . self, from: Data ( contentsOf: configurationURL) )
104+ let configData : Data
105+ do {
106+ configData = try Data ( contentsOf: configurationURL)
107+ } catch {
108+ throw EmbedderError . configurationFileError (
109+ configurationURL. lastPathComponent, modelName, error)
110+ }
111+ let baseConfig : BaseConfiguration
112+ do {
113+ baseConfig = try JSONDecoder ( ) . decode ( BaseConfiguration . self, from: configData)
114+ } catch let error as DecodingError {
115+ throw EmbedderError . configurationDecodingError (
116+ configurationURL. lastPathComponent, modelName, error)
117+ }
63118
64119 let modelType = ModelType ( rawValue: baseConfig. modelType)
65- let model = try modelType. createModel ( configuration: configurationURL)
120+ let model : EmbeddingModel
121+ do {
122+ model = try modelType. createModel ( configuration: configData)
123+ } catch let error as DecodingError {
124+ throw EmbedderError . configurationDecodingError (
125+ configurationURL. lastPathComponent, modelName, error)
126+ }
66127
67128 // load the weights
68129 var weights = [ String: MLXArray] ( )
0 commit comments