@@ -6,8 +6,45 @@ import MLX
66import MLXNN
77import Tokenizers
88
9- struct EmbedderError : Error {
10- let message : String
9+ public enum EmbedderError : LocalizedError {
10+ case unsupportedModelType( String )
11+ case configurationDecodingError( String , String , DecodingError )
12+ case missingTokenizerConfig
13+
14+ public var errorDescription : String ? {
15+ switch self {
16+ case . unsupportedModelType( let type) :
17+ return " Unsupported model type: \( type) "
18+ case . configurationDecodingError( let file, let modelName, let decodingError) :
19+ let errorDetail = extractDecodingErrorDetail ( decodingError)
20+ return " Failed to parse \( file) for model ' \( modelName) ': \( errorDetail) "
21+ case . missingTokenizerConfig:
22+ return " Missing tokenizer configuration "
23+ }
24+ }
25+
26+ private func extractDecodingErrorDetail( _ error: DecodingError ) -> String {
27+ switch error {
28+ case . keyNotFound( let key, let context) :
29+ let path = ( context. codingPath + [ key] ) . map { $0. stringValue } . joined ( separator: " . " )
30+ return " Missing field ' \( path) ' "
31+ case . typeMismatch( _, let context) :
32+ let path = context. codingPath. map { $0. stringValue } . joined ( separator: " . " )
33+ return " Type mismatch at ' \( path) ' "
34+ case . valueNotFound( _, let context) :
35+ let path = context. codingPath. map { $0. stringValue } . joined ( separator: " . " )
36+ return " Missing value at ' \( path) ' "
37+ case . dataCorrupted( let context) :
38+ if context. codingPath. isEmpty {
39+ return " Invalid JSON "
40+ } else {
41+ let path = context. codingPath. map { $0. stringValue } . joined ( separator: " . " )
42+ return " Invalid data at ' \( path) ' "
43+ }
44+ @unknown default :
45+ return error. localizedDescription
46+ }
47+ }
1148}
1249
1350func prepareModelDirectory(
@@ -51,20 +88,33 @@ public func load(
5188 hub: hub, configuration: configuration, progressHandler: progressHandler)
5289
5390 async let tokenizerTask = loadTokenizer ( configuration: configuration, hub: hub)
54- let model = try loadSynchronous ( modelDirectory: modelDirectory)
91+ let model = try loadSynchronous ( modelDirectory: modelDirectory, modelName : configuration . name )
5592 let tokenizer = try await tokenizerTask
5693
5794 return ( model, tokenizer)
5895}
5996
60- func loadSynchronous( modelDirectory: URL ) throws -> EmbeddingModel {
97+ func loadSynchronous( modelDirectory: URL , modelName : String ) throws -> EmbeddingModel {
6198 // Load config.json once and decode for both base config and model-specific config
6299 let configurationURL = modelDirectory. appending ( component: " config.json " )
63- let configData = try Data ( contentsOf: configurationURL)
64- let baseConfig = try JSONDecoder ( ) . decode ( BaseConfiguration . self, from: configData)
100+ let configData : Data
101+ let baseConfig : BaseConfiguration
102+ do {
103+ configData = try Data ( contentsOf: configurationURL)
104+ baseConfig = try JSONDecoder ( ) . decode ( BaseConfiguration . self, from: configData)
105+ } catch let error as DecodingError {
106+ throw EmbedderError . configurationDecodingError (
107+ configurationURL. lastPathComponent, modelName, error)
108+ }
65109
66110 let modelType = ModelType ( rawValue: baseConfig. modelType)
67- let model = try modelType. createModel ( configuration: configData)
111+ let model : EmbeddingModel
112+ do {
113+ model = try modelType. createModel ( configuration: configData)
114+ } catch let error as DecodingError {
115+ throw EmbedderError . configurationDecodingError (
116+ configurationURL. lastPathComponent, modelName, error)
117+ }
68118
69119 // load the weights
70120 var weights = [ String: MLXArray] ( )
0 commit comments