1
- // Copyright 2020 Google LLC
1
+ // Copyright 2021 Google LLC
2
2
//
3
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
4
// you may not use this file except in compliance with the License.
14
14
15
15
import Foundation
16
16
import FirebaseCore
17
+ import FirebaseInstallations
17
18
18
19
/// Possible errors with model downloading.
19
20
public enum DownloadError : Error , Equatable {
@@ -37,6 +38,8 @@ public enum DownloadedModelError: Error, Equatable {
37
38
case fileIOError( description: String )
38
39
/// Model not found on device.
39
40
case notFound
41
+ /// Other errors with description.
42
+ case internalError( description: String )
40
43
}
41
44
42
45
/// Possible ways to get a custom model.
@@ -53,14 +56,24 @@ public enum ModelDownloadType {
53
56
public class ModelDownloader {
54
57
/// Name of the app associated with this instance of ModelDownloader.
55
58
private let appName : String
59
+ /// Current Firebase app options.
60
+ private let options : FirebaseOptions
61
+ /// Installations instance for current Firebase app.
62
+ private let installations : Installations
63
+ /// User defaults for model info.
64
+ private let userDefaults : UserDefaults
56
65
57
66
/// Shared dictionary mapping app name to a specific instance of model downloader.
58
67
// TODO: Switch to using Firebase components.
59
68
private static var modelDownloaderDictionary : [ String : ModelDownloader ] = [ : ]
60
69
61
70
/// Private init for downloader.
62
- private init ( app: FirebaseApp ) {
71
+ private init ( app: FirebaseApp , defaults : UserDefaults = . firebaseMLDefaults ) {
63
72
appName = app. name
73
+ options = app. options
74
+ installations = Installations . installations ( app: app)
75
+ userDefaults = defaults
76
+
64
77
NotificationCenter . default. addObserver (
65
78
self ,
66
79
selector: #selector( deleteModelDownloader) ,
@@ -103,36 +116,202 @@ public class ModelDownloader {
103
116
conditions: ModelDownloadConditions ,
104
117
progressHandler: ( ( Float ) -> Void ) ? = nil ,
105
118
completion: @escaping ( Result < CustomModel , DownloadError > ) -> Void ) {
106
- // TODO: Model download
107
- let modelSize = Int ( )
108
- let modelPath = String ( )
109
- let modelHash = String ( )
119
+ switch downloadType {
120
+ case . localModel:
121
+ if let localModel = getLocalModel ( modelName: modelName) {
122
+ DispatchQueue . main. async {
123
+ completion ( . success( localModel) )
124
+ }
125
+ } else {
126
+ getRemoteModel (
127
+ modelName: modelName,
128
+ progressHandler: progressHandler,
129
+ completion: completion
130
+ )
131
+ }
132
+ case . localModelUpdateInBackground:
133
+ if let localModel = getLocalModel ( modelName: modelName) {
134
+ DispatchQueue . main. async {
135
+ completion ( . success( localModel) )
136
+ }
137
+ DispatchQueue . global ( qos: . utility) . async { [ weak self] in
138
+ self ? . getRemoteModel (
139
+ modelName: modelName,
140
+ progressHandler: nil ,
141
+ completion: { result in
142
+ switch result {
143
+ // TODO: Log outcome of background download
144
+ case . success: break
145
+ case . failure: break
146
+ }
147
+ }
148
+ )
149
+ }
150
+ } else {
151
+ getRemoteModel (
152
+ modelName: modelName,
153
+ progressHandler: progressHandler,
154
+ completion: completion
155
+ )
156
+ }
110
157
111
- let customModel = CustomModel (
112
- name: modelName,
113
- size: modelSize,
114
- path: modelPath,
115
- hash: modelHash
116
- )
117
- completion ( . success( customModel) )
118
- completion ( . failure( . notFound) )
158
+ case . latestModel:
159
+ getRemoteModel (
160
+ modelName: modelName,
161
+ progressHandler: progressHandler,
162
+ completion: completion
163
+ )
164
+ }
119
165
}
120
166
121
167
/// Gets all downloaded models.
122
168
public func listDownloadedModels( completion: @escaping ( Result < Set < CustomModel > ,
123
169
DownloadedModelError > ) -> Void ) {
124
- let customModels = Set < CustomModel > ( )
125
- // TODO: List downloaded models
126
- completion ( . success( customModels) )
127
- completion ( . failure( . notFound) )
170
+ do {
171
+ let modelPaths = try ModelFileManager . contentsOfModelsDirectory ( )
172
+ var customModels = Set < CustomModel > ( )
173
+ for path in modelPaths {
174
+ guard let modelName = ModelFileManager . getModelNameFromFilePath ( path) else {
175
+ completion ( . failure( . internalError( description: " Invalid model file name. " ) ) )
176
+ return
177
+ }
178
+ guard let modelInfo = getLocalModelInfo ( modelName: modelName) else {
179
+ completion (
180
+ . failure( . internalError( description: " Failed to get model info for model file. " ) )
181
+ )
182
+ return
183
+ }
184
+ guard modelInfo. path == path. absoluteString else {
185
+ completion (
186
+ . failure( . internalError( description: " Outdated model paths in local storage. " ) )
187
+ )
188
+ return
189
+ }
190
+ let model = CustomModel ( localModelInfo: modelInfo)
191
+ customModels. insert ( model)
192
+ }
193
+ completion ( . success( customModels) )
194
+ } catch let error as DownloadedModelError {
195
+ completion ( . failure( error) )
196
+ } catch {
197
+ completion ( . failure( . internalError( description: error. localizedDescription) ) )
198
+ }
128
199
}
129
200
130
201
/// Deletes a custom model from device.
131
202
public func deleteDownloadedModel( name modelName: String ,
132
203
completion: @escaping ( Result < Void , DownloadedModelError > )
133
204
-> Void ) {
134
205
// TODO: Delete previously downloaded model
135
- completion ( . success( ( ) ) )
136
- completion ( . failure( . notFound) )
206
+ guard let localModelInfo = getLocalModelInfo ( modelName: modelName) ,
207
+ let localPath = URL ( string: localModelInfo. path)
208
+ else {
209
+ completion ( . failure( . notFound) )
210
+ return
211
+ }
212
+ do {
213
+ try ModelFileManager . removeFile ( at: localPath)
214
+ completion ( . success( ( ) ) )
215
+ } catch let error as DownloadedModelError {
216
+ completion ( . failure( error) )
217
+ } catch {
218
+ completion ( . failure( . internalError( description: error. localizedDescription) ) )
219
+ }
220
+ }
221
+ }
222
+
223
+ extension ModelDownloader {
224
+ /// Return local model info only if the model info is available and the corresponding model file is already on device.
225
+ private func getLocalModelInfo( modelName: String ) -> LocalModelInfo ? {
226
+ guard let localModelInfo = LocalModelInfo (
227
+ fromDefaults: userDefaults,
228
+ name: modelName,
229
+ appName: appName
230
+ ) else {
231
+ return nil
232
+ }
233
+ /// There is local model info on device, but no model file at the expected path.
234
+ guard let localPath = URL ( string: localModelInfo. path) ,
235
+ ModelFileManager . isFileReachable ( at: localPath) else {
236
+ // TODO: Consider deleting local model info in user defaults.
237
+ return nil
238
+ }
239
+ return localModelInfo
240
+ }
241
+
242
+ /// Get model saved on device if available.
243
+ private func getLocalModel( modelName: String ) -> CustomModel ? {
244
+ guard let localModelInfo = getLocalModelInfo ( modelName: modelName) else { return nil }
245
+ let model = CustomModel ( localModelInfo: localModelInfo)
246
+ return model
247
+ }
248
+
249
+ /// Download and get model from server, unless the latest model is already available on device.
250
+ private func getRemoteModel( modelName: String ,
251
+ progressHandler: ( ( Float ) -> Void ) ? = nil ,
252
+ completion: @escaping ( Result < CustomModel , DownloadError > ) -> Void ) {
253
+ let localModelInfo = getLocalModelInfo ( modelName: modelName)
254
+ let modelInfoRetriever = ModelInfoRetriever (
255
+ modelName: modelName,
256
+ options: options,
257
+ installations: installations,
258
+ appName: appName,
259
+ localModelInfo: localModelInfo
260
+ )
261
+ modelInfoRetriever. downloadModelInfo { result in
262
+ switch result {
263
+ case let . success( downloadModelInfoResult) :
264
+ switch downloadModelInfoResult {
265
+ /// New model info was downloaded from server.
266
+ case let . modelInfo( remoteModelInfo) :
267
+ let downloadTask = ModelDownloadTask (
268
+ remoteModelInfo: remoteModelInfo,
269
+ appName: self . appName,
270
+ defaults: self . userDefaults,
271
+ progressHandler: progressHandler,
272
+ completion: completion
273
+ )
274
+ downloadTask. resumeModelDownload ( )
275
+ /// Local model info is the latest model info.
276
+ case . notModified:
277
+ guard let localModel = self . getLocalModel ( modelName: modelName) else {
278
+ DispatchQueue . main. async {
279
+ /// This can only happen if either local model info or the model file was suddenly wiped out in the middle of model info request and server response
280
+ // TODO: Consider handling: if model file is deleted after local model info is retrieved but before model info network call
281
+ completion (
282
+ . failure(
283
+ . internalError( description: " Model unavailable due to deleted local model info. " )
284
+ )
285
+ )
286
+ }
287
+ return
288
+ }
289
+
290
+ DispatchQueue . main. async {
291
+ completion ( . success( localModel) )
292
+ }
293
+ }
294
+ case let . failure( downloadError) :
295
+ DispatchQueue . main. async {
296
+ completion ( . failure( downloadError) )
297
+ }
298
+ }
299
+ }
300
+ }
301
+ }
302
+
303
+ /// Model downloader extension for testing.
304
+ extension ModelDownloader {
305
+ /// Model downloader instance for testing.
306
+ // TODO: Consider using protocols
307
+ static func modelDownloaderWithDefaults( _ defaults: UserDefaults ,
308
+ app: FirebaseApp ) -> ModelDownloader {
309
+ if let downloader = modelDownloaderDictionary [ app. name] {
310
+ return downloader
311
+ } else {
312
+ let downloader = ModelDownloader ( app: app, defaults: defaults)
313
+ modelDownloaderDictionary [ app. name] = downloader
314
+ return downloader
315
+ }
137
316
}
138
317
}
0 commit comments