15
15
import Foundation
16
16
import FirebaseCore
17
17
18
+ /// Possible states of model downloading.
18
19
enum DownloadStatus {
19
20
case notStarted
20
21
case inProgress
@@ -37,22 +38,29 @@ class DownloadHandlers {
37
38
38
39
/// Manager to handle model downloading device and storing downloaded model info to persistent storage.
39
40
class ModelDownloadTask : NSObject {
41
+ /// Name of the app associated with this instance of ModelDownloadTask.
40
42
private let appName : String
41
- private( set) var modelInfo : ModelInfo
43
+ /// Model info downloaded from server.
44
+ private( set) var remoteModelInfo : RemoteModelInfo
45
+ /// User defaults to which local model info should ultimately be written.
46
+ private let defaults : UserDefaults
47
+ /// Task to handle model file download.
42
48
private var downloadTask : URLSessionDownloadTask ?
49
+ /// Progress and completion handlers associated with this model download task.
43
50
private let downloadHandlers : DownloadHandlers
44
-
51
+ /// Keeps track of download associated with this model download task.
45
52
private( set) var downloadStatus : DownloadStatus = . notStarted
46
-
53
+ /// URLSession to handle model downloads.
47
54
private lazy var downloadSession = URLSession ( configuration: . ephemeral,
48
55
delegate: self ,
49
56
delegateQueue: nil )
50
57
51
- init ( modelInfo : ModelInfo , appName: String ,
58
+ init ( remoteModelInfo : RemoteModelInfo , appName: String , defaults : UserDefaults ,
52
59
progressHandler: DownloadHandlers . ProgressHandler ? = nil ,
53
60
completion: @escaping DownloadHandlers . Completion ) {
54
- self . modelInfo = modelInfo
61
+ self . remoteModelInfo = remoteModelInfo
55
62
self . appName = appName
63
+ self . defaults = defaults
56
64
downloadHandlers = DownloadHandlers (
57
65
progressHandler: progressHandler,
58
66
completion: completion
@@ -62,7 +70,7 @@ class ModelDownloadTask: NSObject {
62
70
/// Asynchronously download model file to device.
63
71
func resumeModelDownload( ) {
64
72
guard downloadStatus == . notStarted else { return }
65
- let downloadTask = downloadSession. downloadTask ( with: modelInfo . downloadURL)
73
+ let downloadTask = downloadSession. downloadTask ( with: remoteModelInfo . downloadURL)
66
74
downloadTask. resume ( )
67
75
downloadStatus = . inProgress
68
76
self . downloadTask = downloadTask
@@ -71,6 +79,11 @@ class ModelDownloadTask: NSObject {
71
79
72
80
/// Extension to handle delegate methods.
73
81
extension ModelDownloadTask : URLSessionDownloadDelegate {
82
+ /// Name for model file stored on device.
83
+ var downloadedModelFileName : String {
84
+ return " fbml_model__ \( appName) __ \( remoteModelInfo. name) "
85
+ }
86
+
74
87
func urlSession( _ session: URLSession ,
75
88
downloadTask: URLSessionDownloadTask ,
76
89
didFinishDownloadingTo location: URL ) {
@@ -80,32 +93,27 @@ extension ModelDownloadTask: URLSessionDownloadDelegate {
80
93
. appendingPathComponent ( downloadedModelFileName)
81
94
do {
82
95
try ModelFileManager . moveFile ( at: location, to: savedURL)
96
+ } catch let error as DownloadError {
97
+ DispatchQueue . main. async {
98
+ self . downloadHandlers
99
+ . completion ( . failure( error) )
100
+ }
83
101
} catch {
84
- downloadHandlers
85
- . completion ( . failure( . internalError( description: error. localizedDescription) ) )
86
- return
102
+ DispatchQueue . main. async {
103
+ self . downloadHandlers
104
+ . completion ( . failure( . internalError( description: error. localizedDescription) ) )
105
+ }
87
106
}
88
107
89
- /// Set path to local model.
90
- modelInfo . path = savedURL. absoluteString
108
+ /// Generate local model info .
109
+ let localModelInfo = LocalModelInfo ( from : remoteModelInfo , path : savedURL. absoluteString)
91
110
/// Write model to user defaults.
92
- do {
93
- try modelInfo. writeToDefaults ( . firebaseMLDefaults, appName: appName)
94
- } catch {
95
- downloadHandlers
96
- . completion ( . failure( . internalError( description: error. localizedDescription) ) )
97
- }
111
+ localModelInfo. writeToDefaults ( defaults, appName: appName)
98
112
/// Build model from model info.
99
- guard let model = buildModel ( ) else {
100
- downloadHandlers
101
- . completion (
102
- . failure(
103
- . internalError( description: " Could not create model due to incomplete model info. " )
104
- )
105
- )
106
- return
113
+ let model = CustomModel ( localModelInfo: localModelInfo)
114
+ DispatchQueue . main. async {
115
+ self . downloadHandlers. completion ( . success( model) )
107
116
}
108
- downloadHandlers. completion ( . success( model) )
109
117
}
110
118
111
119
func urlSession( _ session: URLSession ,
@@ -114,50 +122,11 @@ extension ModelDownloadTask: URLSessionDownloadDelegate {
114
122
totalBytesWritten: Int64 ,
115
123
totalBytesExpectedToWrite: Int64 ) {
116
124
assert ( downloadTask == self . downloadTask)
125
+ /// Check if progress handler is unspecified.
117
126
guard let progressHandler = downloadHandlers. progressHandler else { return }
118
127
let calculatedProgress = Float ( totalBytesWritten) / Float( totalBytesExpectedToWrite)
119
- progressHandler ( calculatedProgress)
120
- }
121
- }
122
-
123
- /// Extension to handle post-download operations.
124
- extension ModelDownloadTask {
125
- var downloadedModelFileName : String {
126
- return " fbml_model__ \( appName) __ \( modelInfo. name) "
127
- }
128
-
129
- /// Build custom model object from model info.
130
- // TODO: Consider moving this to CustomModel as a convenience init
131
- func buildModel( ) -> CustomModel ? {
132
- /// Build custom model only if the model file is already on device.
133
- guard let path = modelInfo. path else { return nil }
134
- let model = CustomModel (
135
- name: modelInfo. name,
136
- size: modelInfo. size,
137
- path: path,
138
- hash: modelInfo. modelHash
139
- )
140
- return model
141
- }
142
-
143
- /// Get the local path to model on device.
144
- func getLocalModelPath( model: CustomModel ) -> URL ? {
145
- let fileURL : URL = ModelFileManager . modelsDirectory
146
- . appendingPathComponent ( downloadedModelFileName)
147
- if ModelFileManager . isFileReachable ( at: fileURL) {
148
- return fileURL
149
- } else {
150
- return nil
128
+ DispatchQueue . main. async {
129
+ progressHandler ( calculatedProgress)
151
130
}
152
131
}
153
132
}
154
-
155
- /// Named user defaults for FirebaseML.
156
- extension UserDefaults {
157
- static var firebaseMLDefaults : UserDefaults {
158
- let suiteName = " com.google.firebase.ml "
159
- // TODO: reconsider force unwrapping
160
- let defaults = UserDefaults ( suiteName: suiteName) !
161
- return defaults
162
- }
163
- }
0 commit comments