10
10
use anyhow:: { anyhow, Context } ;
11
11
use async_trait:: async_trait;
12
12
use futures:: StreamExt ;
13
+ use reqwest:: Url ;
13
14
use reqwest:: { Response , StatusCode } ;
14
15
use slog:: { debug, Logger } ;
16
+ use std:: fs;
15
17
use std:: path:: Path ;
18
+ use tokio:: fs:: File ;
19
+ use tokio:: io:: AsyncReadExt ;
16
20
17
21
#[ cfg( test) ]
18
22
use mockall:: automock;
@@ -78,6 +82,74 @@ impl HttpSnapshotDownloader {
78
82
status_code => Err ( anyhow ! ( "Unhandled error {status_code}" ) ) ,
79
83
}
80
84
}
85
+
86
+ fn file_scheme_to_local_path ( file_url : & str ) -> Option < String > {
87
+ Url :: parse ( file_url)
88
+ . ok ( )
89
+ . filter ( |url| url. scheme ( ) == "file" )
90
+ . and_then ( |url| url. to_file_path ( ) . ok ( ) )
91
+ . map ( |path| path. to_string_lossy ( ) . into_owned ( ) )
92
+ }
93
+
94
+ async fn download_local_file < F , Fut > (
95
+ & self ,
96
+ local_path : & str ,
97
+ sender : & flume:: Sender < Vec < u8 > > ,
98
+ report_progress : F ,
99
+ ) -> MithrilResult < ( ) >
100
+ where
101
+ F : Fn ( u64 ) -> Fut ,
102
+ Fut : std:: future:: Future < Output = ( ) > ,
103
+ {
104
+ // Stream the `location` directly from the local filesystem
105
+ let mut downloaded_bytes: u64 = 0 ;
106
+ let mut file = File :: open ( local_path) . await ?;
107
+
108
+ loop {
109
+ // We can either allocate here each time, or clone a shared buffer into sender.
110
+ // A larger read buffer is faster, less context switches:
111
+ let mut buffer = vec ! [ 0 ; 16 * 1024 * 1024 ] ;
112
+ let bytes_read = file. read ( & mut buffer) . await ?;
113
+ if bytes_read == 0 {
114
+ break ;
115
+ }
116
+ buffer. truncate ( bytes_read) ;
117
+ sender. send_async ( buffer) . await . with_context ( || {
118
+ format ! (
119
+ "Local file read: could not write {} bytes to stream." ,
120
+ bytes_read
121
+ )
122
+ } ) ?;
123
+ downloaded_bytes += bytes_read as u64 ;
124
+ report_progress ( downloaded_bytes) . await
125
+ }
126
+ Ok ( ( ) )
127
+ }
128
+
129
+ async fn download_remote_file < F , Fut > (
130
+ & self ,
131
+ location : & str ,
132
+ sender : & flume:: Sender < Vec < u8 > > ,
133
+ report_progress : F ,
134
+ ) -> MithrilResult < ( ) >
135
+ where
136
+ F : Fn ( u64 ) -> Fut ,
137
+ Fut : std:: future:: Future < Output = ( ) > ,
138
+ {
139
+ let mut downloaded_bytes: u64 = 0 ;
140
+ let mut remote_stream = self . get ( location) . await ?. bytes_stream ( ) ;
141
+ while let Some ( item) = remote_stream. next ( ) . await {
142
+ let chunk = item. with_context ( || "Download: Could not read from byte stream" ) ?;
143
+
144
+ sender. send_async ( chunk. to_vec ( ) ) . await . with_context ( || {
145
+ format ! ( "Download: could not write {} bytes to stream." , chunk. len( ) )
146
+ } ) ?;
147
+
148
+ downloaded_bytes += chunk. len ( ) as u64 ;
149
+ report_progress ( downloaded_bytes) . await
150
+ }
151
+ Ok ( ( ) )
152
+ }
81
153
}
82
154
83
155
#[ cfg_attr( test, automock) ]
@@ -97,8 +169,6 @@ impl SnapshotDownloader for HttpSnapshotDownloader {
97
169
. context ( "Download-Unpack: prerequisite error" ) ,
98
170
) ?;
99
171
}
100
- let mut downloaded_bytes: u64 = 0 ;
101
- let mut remote_stream = self . get ( location) . await ?. bytes_stream ( ) ;
102
172
let ( sender, receiver) = flume:: bounded ( 5 ) ;
103
173
104
174
let dest_dir = target_dir. to_path_buf ( ) ;
@@ -107,21 +177,22 @@ impl SnapshotDownloader for HttpSnapshotDownloader {
107
177
unpacker. unpack_snapshot ( receiver, compression_algorithm, & dest_dir)
108
178
} ) ;
109
179
110
- while let Some ( item) = remote_stream. next ( ) . await {
111
- let chunk = item. with_context ( || "Download: Could not read from byte stream" ) ?;
112
-
113
- sender. send_async ( chunk. to_vec ( ) ) . await . with_context ( || {
114
- format ! ( "Download: could not write {} bytes to stream." , chunk. len( ) )
115
- } ) ?;
116
-
117
- downloaded_bytes += chunk. len ( ) as u64 ;
180
+ let report_progress = |downloaded_bytes : u64 | async move {
118
181
self . feedback_sender
119
182
. send_event ( MithrilEvent :: SnapshotDownloadProgress {
120
183
download_id : download_id. to_owned ( ) ,
121
184
downloaded_bytes,
122
185
size : snapshot_size,
123
186
} )
124
187
. await
188
+ } ;
189
+
190
+ if let Some ( local_path) = Self :: file_scheme_to_local_path ( location) {
191
+ self . download_local_file ( & local_path, & sender, report_progress)
192
+ . await ?;
193
+ } else {
194
+ self . download_remote_file ( location, & sender, report_progress)
195
+ . await ?;
125
196
}
126
197
127
198
drop ( sender) ; // Signal EOF
@@ -143,15 +214,21 @@ impl SnapshotDownloader for HttpSnapshotDownloader {
143
214
async fn probe ( & self , location : & str ) -> MithrilResult < ( ) > {
144
215
debug ! ( self . logger, "HEAD Snapshot location='{location}'." ) ;
145
216
146
- let request_builder = self . http_client . head ( location) ;
147
- let response = request_builder. send ( ) . await . with_context ( || {
148
- format ! ( "Cannot perform a HEAD for snapshot at location='{location}'" )
149
- } ) ?;
217
+ if let Some ( local_path) = Self :: file_scheme_to_local_path ( location) {
218
+ fs:: metadata ( local_path)
219
+ . with_context ( || format ! ( "Local snapshot location='{location}' not found" ) )
220
+ . map ( drop)
221
+ } else {
222
+ let request_builder = self . http_client . head ( location) ;
223
+ let response = request_builder. send ( ) . await . with_context ( || {
224
+ format ! ( "Cannot perform a HEAD for snapshot at location='{location}'" )
225
+ } ) ?;
150
226
151
- match response. status ( ) {
152
- StatusCode :: OK => Ok ( ( ) ) ,
153
- StatusCode :: NOT_FOUND => Err ( anyhow ! ( "Snapshot location='{location} not found" ) ) ,
154
- status_code => Err ( anyhow ! ( "Unhandled error {status_code}" ) ) ,
227
+ match response. status ( ) {
228
+ StatusCode :: OK => Ok ( ( ) ) ,
229
+ StatusCode :: NOT_FOUND => Err ( anyhow ! ( "Snapshot location='{location} not found" ) ) ,
230
+ status_code => Err ( anyhow ! ( "Unhandled error {status_code}" ) ) ,
231
+ }
155
232
}
156
233
}
157
234
}
0 commit comments