@@ -6,8 +6,42 @@ import MLX
66import MLXNN
77import Tokenizers
88
9- struct EmbedderError : Error {
10- let message : String
9+ public enum EmbedderError : LocalizedError {
10+ case unsupportedModelType( String )
11+ case configurationDecodingError( String , String , DecodingError )
12+
13+ public var errorDescription : String ? {
14+ switch self {
15+ case . unsupportedModelType( let type) :
16+ return " Unsupported model type: \( type) "
17+ case . configurationDecodingError( let file, let modelName, let decodingError) :
18+ let errorDetail = extractDecodingErrorDetail ( decodingError)
19+ return " Failed to parse \( file) for model ' \( modelName) ': \( errorDetail) "
20+ }
21+ }
22+
23+ private func extractDecodingErrorDetail( _ error: DecodingError ) -> String {
24+ switch error {
25+ case . keyNotFound( let key, let context) :
26+ let path = ( context. codingPath + [ key] ) . map { $0. stringValue } . joined ( separator: " . " )
27+ return " Missing field ' \( path) ' "
28+ case . typeMismatch( _, let context) :
29+ let path = context. codingPath. map { $0. stringValue } . joined ( separator: " . " )
30+ return " Type mismatch at ' \( path) ' "
31+ case . valueNotFound( _, let context) :
32+ let path = context. codingPath. map { $0. stringValue } . joined ( separator: " . " )
33+ return " Missing value at ' \( path) ' "
34+ case . dataCorrupted( let context) :
35+ if context. codingPath. isEmpty {
36+ return " Invalid JSON "
37+ } else {
38+ let path = context. codingPath. map { $0. stringValue } . joined ( separator: " . " )
39+ return " Invalid data at ' \( path) ' "
40+ }
41+ @unknown default :
42+ return error. localizedDescription
43+ }
44+ }
1145}
1246
1347func prepareModelDirectory(
@@ -51,20 +85,33 @@ public func load(
5185 hub: hub, configuration: configuration, progressHandler: progressHandler)
5286
5387 async let tokenizerTask = loadTokenizer ( configuration: configuration, hub: hub)
54- let model = try loadSynchronous ( modelDirectory: modelDirectory)
88+ let model = try loadSynchronous ( modelDirectory: modelDirectory, modelName : configuration . name )
5589 let tokenizer = try await tokenizerTask
5690
5791 return ( model, tokenizer)
5892}
5993
60- func loadSynchronous( modelDirectory: URL ) throws -> EmbeddingModel {
94+ func loadSynchronous( modelDirectory: URL , modelName : String ) throws -> EmbeddingModel {
6195 // Load config.json once and decode for both base config and model-specific config
6296 let configurationURL = modelDirectory. appending ( component: " config.json " )
63- let configData = try Data ( contentsOf: configurationURL)
64- let baseConfig = try JSONDecoder ( ) . decode ( BaseConfiguration . self, from: configData)
97+ let configData : Data
98+ let baseConfig : BaseConfiguration
99+ do {
100+ configData = try Data ( contentsOf: configurationURL)
101+ baseConfig = try JSONDecoder ( ) . decode ( BaseConfiguration . self, from: configData)
102+ } catch let error as DecodingError {
103+ throw EmbedderError . configurationDecodingError (
104+ configurationURL. lastPathComponent, modelName, error)
105+ }
65106
66107 let modelType = ModelType ( rawValue: baseConfig. modelType)
67- let model = try modelType. createModel ( configuration: configData)
108+ let model : EmbeddingModel
109+ do {
110+ model = try modelType. createModel ( configuration: configData)
111+ } catch let error as DecodingError {
112+ throw EmbedderError . configurationDecodingError (
113+ configurationURL. lastPathComponent, modelName, error)
114+ }
68115
69116 // load the weights
70117 var weights = [ String: MLXArray] ( )
0 commit comments