Skip to content

Commit 3ea6266

Browse files
committed
Improve embedding model error handling
1 parent afb9dec commit 3ea6266

File tree

4 files changed: 61 additions and 10 deletions.

Libraries/Embedders/Configuration.swift

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -73,7 +73,7 @@ private class ModelTypeRegistry: @unchecked Sendable {
7373
creators[rawValue]
7474
}
7575
guard let creator else {
76-
throw EmbedderError(message: "Unsupported model type.")
76+
throw EmbedderError.unsupportedModelType(rawValue)
7777
}
7878
return try creator(configuration)
7979
}

Libraries/Embedders/EmbeddingModel.swift

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,8 @@ public actor ModelContainer {
4949
async let tokenizerConfigTask = loadTokenizerConfig(
5050
configuration: configuration, hub: hub)
5151

52-
self.model = try loadSynchronous(modelDirectory: modelDirectory)
52+
self.model = try loadSynchronous(
53+
modelDirectory: modelDirectory, modelName: configuration.name)
5354
self.pooler = loadPooling(modelDirectory: modelDirectory)
5455

5556
let (tokenizerConfig, tokenizerData) = try await tokenizerConfigTask

Libraries/Embedders/Load.swift

Lines changed: 57 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -6,8 +6,45 @@ import MLX
66
import MLXNN
77
import Tokenizers
88

9-
struct EmbedderError: Error {
10-
let message: String
9+
public enum EmbedderError: LocalizedError {
10+
case unsupportedModelType(String)
11+
case configurationDecodingError(String, String, DecodingError)
12+
case missingTokenizerConfig
13+
14+
public var errorDescription: String? {
15+
switch self {
16+
case .unsupportedModelType(let type):
17+
return "Unsupported model type: \(type)"
18+
case .configurationDecodingError(let file, let modelName, let decodingError):
19+
let errorDetail = extractDecodingErrorDetail(decodingError)
20+
return "Failed to parse \(file) for model '\(modelName)': \(errorDetail)"
21+
case .missingTokenizerConfig:
22+
return "Missing tokenizer configuration"
23+
}
24+
}
25+
26+
private func extractDecodingErrorDetail(_ error: DecodingError) -> String {
27+
switch error {
28+
case .keyNotFound(let key, let context):
29+
let path = (context.codingPath + [key]).map { $0.stringValue }.joined(separator: ".")
30+
return "Missing field '\(path)'"
31+
case .typeMismatch(_, let context):
32+
let path = context.codingPath.map { $0.stringValue }.joined(separator: ".")
33+
return "Type mismatch at '\(path)'"
34+
case .valueNotFound(_, let context):
35+
let path = context.codingPath.map { $0.stringValue }.joined(separator: ".")
36+
return "Missing value at '\(path)'"
37+
case .dataCorrupted(let context):
38+
if context.codingPath.isEmpty {
39+
return "Invalid JSON"
40+
} else {
41+
let path = context.codingPath.map { $0.stringValue }.joined(separator: ".")
42+
return "Invalid data at '\(path)'"
43+
}
44+
@unknown default:
45+
return error.localizedDescription
46+
}
47+
}
1148
}
1249

1350
func prepareModelDirectory(
@@ -51,20 +88,33 @@ public func load(
5188
hub: hub, configuration: configuration, progressHandler: progressHandler)
5289

5390
async let tokenizerTask = loadTokenizer(configuration: configuration, hub: hub)
54-
let model = try loadSynchronous(modelDirectory: modelDirectory)
91+
let model = try loadSynchronous(modelDirectory: modelDirectory, modelName: configuration.name)
5592
let tokenizer = try await tokenizerTask
5693

5794
return (model, tokenizer)
5895
}
5996

60-
func loadSynchronous(modelDirectory: URL) throws -> EmbeddingModel {
97+
func loadSynchronous(modelDirectory: URL, modelName: String) throws -> EmbeddingModel {
6198
// Load config.json once and decode for both base config and model-specific config
6299
let configurationURL = modelDirectory.appending(component: "config.json")
63-
let configData = try Data(contentsOf: configurationURL)
64-
let baseConfig = try JSONDecoder().decode(BaseConfiguration.self, from: configData)
100+
let configData: Data
101+
let baseConfig: BaseConfiguration
102+
do {
103+
configData = try Data(contentsOf: configurationURL)
104+
baseConfig = try JSONDecoder().decode(BaseConfiguration.self, from: configData)
105+
} catch let error as DecodingError {
106+
throw EmbedderError.configurationDecodingError(
107+
configurationURL.lastPathComponent, modelName, error)
108+
}
65109

66110
let modelType = ModelType(rawValue: baseConfig.modelType)
67-
let model = try modelType.createModel(configuration: configData)
111+
let model: EmbeddingModel
112+
do {
113+
model = try modelType.createModel(configuration: configData)
114+
} catch let error as DecodingError {
115+
throw EmbedderError.configurationDecodingError(
116+
configurationURL.lastPathComponent, modelName, error)
117+
}
68118

69119
// load the weights
70120
var weights = [String: MLXArray]()

Libraries/Embedders/Tokenizer.swift

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,7 @@ func loadTokenizerConfig(configuration: ModelConfiguration, hub: HubApi) async t
4545
}
4646

4747
guard let tokenizerConfig = try await config.tokenizerConfig else {
48-
throw EmbedderError(message: "missing config")
48+
throw EmbedderError.missingTokenizerConfig
4949
}
5050
let tokenizerData = try await config.tokenizerData
5151
return (tokenizerConfig, tokenizerData)

0 commit comments

Comments (0)