Skip to content

Commit 49d2e79

Browse files
committed
Improve embedding model error handling
1 parent afb9dec commit 49d2e79

File tree

3 files changed

+57
-9
lines changed

3 files changed

+57
-9
lines changed

Libraries/Embedders/Configuration.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ private class ModelTypeRegistry: @unchecked Sendable {
7373
creators[rawValue]
7474
}
7575
guard let creator else {
76-
throw EmbedderError(message: "Unsupported model type.")
76+
throw EmbedderError.unsupportedModelType(rawValue)
7777
}
7878
return try creator(configuration)
7979
}

Libraries/Embedders/EmbeddingModel.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ public actor ModelContainer {
4949
async let tokenizerConfigTask = loadTokenizerConfig(
5050
configuration: configuration, hub: hub)
5151

52-
self.model = try loadSynchronous(modelDirectory: modelDirectory)
52+
self.model = try loadSynchronous(
53+
modelDirectory: modelDirectory, modelName: configuration.name)
5354
self.pooler = loadPooling(modelDirectory: modelDirectory)
5455

5556
let (tokenizerConfig, tokenizerData) = try await tokenizerConfigTask

Libraries/Embedders/Load.swift

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,42 @@ import MLX
66
import MLXNN
77
import Tokenizers
88

9-
struct EmbedderError: Error {
10-
let message: String
9+
/// Errors surfaced while resolving and loading an embedding model.
///
/// Conforms to `LocalizedError` so that `localizedDescription` yields the
/// human-readable text built in ``errorDescription``.
public enum EmbedderError: LocalizedError {
    /// No model creator is registered for the given model type string.
    case unsupportedModelType(String)

    /// A configuration file could not be decoded.
    ///
    /// Associated values, in order: the file name, the model name, and the
    /// underlying `DecodingError`.
    case configurationDecodingError(String, String, DecodingError)

    public var errorDescription: String? {
        switch self {
        case .unsupportedModelType(let type):
            return "Unsupported model type: \(type)"
        case .configurationDecodingError(let file, let modelName, let decodingError):
            let errorDetail = extractDecodingErrorDetail(decodingError)
            return "Failed to parse \(file) for model '\(modelName)': \(errorDetail)"
        }
    }

    /// Condenses a `DecodingError` into a short phrase naming the failing field.
    ///
    /// - Parameter error: The decoding failure to summarize.
    /// - Returns: A one-line description such as `Missing field 'a.b'`.
    private func extractDecodingErrorDetail(_ error: DecodingError) -> String {
        switch error {
        case .keyNotFound(let key, let context):
            // The missing key is appended, so this path is never empty.
            return "Missing field '\(Self.dottedPath(context.codingPath + [key]))'"
        case .typeMismatch(_, let context):
            return "Type mismatch \(Self.location(context.codingPath))"
        case .valueNotFound(_, let context):
            return "Missing value \(Self.location(context.codingPath))"
        case .dataCorrupted(let context):
            // An empty path here is reported as the document itself being invalid.
            if context.codingPath.isEmpty {
                return "Invalid JSON"
            }
            return "Invalid data \(Self.location(context.codingPath))"
        @unknown default:
            return error.localizedDescription
        }
    }

    /// Joins the string values of a coding path with `.` separators.
    private static func dottedPath(_ codingPath: [any CodingKey]) -> String {
        codingPath.map { $0.stringValue }.joined(separator: ".")
    }

    /// Renders a coding path as a location phrase.
    ///
    /// An empty path means the failure happened on the root value; report that
    /// explicitly instead of printing an empty quoted path (`at ''`).
    private static func location(_ codingPath: [any CodingKey]) -> String {
        if codingPath.isEmpty {
            return "at top level"
        }
        return "at '\(dottedPath(codingPath))'"
    }
}
1246

1347
func prepareModelDirectory(
@@ -51,20 +85,33 @@ public func load(
5185
hub: hub, configuration: configuration, progressHandler: progressHandler)
5286

5387
async let tokenizerTask = loadTokenizer(configuration: configuration, hub: hub)
54-
let model = try loadSynchronous(modelDirectory: modelDirectory)
88+
let model = try loadSynchronous(modelDirectory: modelDirectory, modelName: configuration.name)
5589
let tokenizer = try await tokenizerTask
5690

5791
return (model, tokenizer)
5892
}
5993

60-
func loadSynchronous(modelDirectory: URL) throws -> EmbeddingModel {
94+
func loadSynchronous(modelDirectory: URL, modelName: String) throws -> EmbeddingModel {
6195
// Load config.json once and decode for both base config and model-specific config
6296
let configurationURL = modelDirectory.appending(component: "config.json")
63-
let configData = try Data(contentsOf: configurationURL)
64-
let baseConfig = try JSONDecoder().decode(BaseConfiguration.self, from: configData)
97+
let configData: Data
98+
let baseConfig: BaseConfiguration
99+
do {
100+
configData = try Data(contentsOf: configurationURL)
101+
baseConfig = try JSONDecoder().decode(BaseConfiguration.self, from: configData)
102+
} catch let error as DecodingError {
103+
throw EmbedderError.configurationDecodingError(
104+
configurationURL.lastPathComponent, modelName, error)
105+
}
65106

66107
let modelType = ModelType(rawValue: baseConfig.modelType)
67-
let model = try modelType.createModel(configuration: configData)
108+
let model: EmbeddingModel
109+
do {
110+
model = try modelType.createModel(configuration: configData)
111+
} catch let error as DecodingError {
112+
throw EmbedderError.configurationDecodingError(
113+
configurationURL.lastPathComponent, modelName, error)
114+
}
68115

69116
// load the weights
70117
var weights = [String: MLXArray]()

0 commit comments

Comments
 (0)