@@ -2,24 +2,23 @@ use std::io::{Cursor, Write};
22
33use anyhow:: anyhow;
44use bytes:: Buf ;
5- use cas:: key:: Key ;
6- use cas_types:: { QueryChunkResponse , QueryReconstructionResponse , UploadXorbResponse } ;
7- use reqwest:: {
8- header:: { HeaderMap , HeaderValue } ,
9- StatusCode , Url ,
10- } ;
5+ use bytes:: Bytes ;
6+ use reqwest:: { StatusCode , Url } ;
7+ use reqwest_middleware:: { ClientBuilder , ClientWithMiddleware , Middleware } ;
118use serde:: { de:: DeserializeOwned , Serialize } ;
9+ use tracing:: { debug, warn} ;
1210
13- use bytes:: Bytes ;
11+ use cas:: auth:: AuthConfig ;
12+ use cas:: key:: Key ;
1413use cas_object:: CasObject ;
1514use cas_types:: CASReconstructionTerm ;
16- use tracing:: { debug, warn} ;
17-
18- use crate :: { error:: Result , CasClientError } ;
19-
15+ use cas_types:: { QueryChunkResponse , QueryReconstructionResponse , UploadXorbResponse } ;
16+ use error_printer:: OptionPrinter ;
2017use merklehash:: MerkleHash ;
2118
2219use crate :: Client ;
20+ use crate :: { error:: Result , AuthMiddleware , CasClientError } ;
21+
2322pub const CAS_ENDPOINT : & str = "http://localhost:8080" ;
2423pub const PREFIX_DEFAULT : & str = "default" ;
2524
@@ -84,44 +83,37 @@ impl Client for RemoteClient {
8483}
8584
8685impl RemoteClient {
87- pub async fn from_config ( endpoint : String , token : Option < String > ) -> Self {
86+ pub async fn from_config ( endpoint : String , auth_config : & Option < AuthConfig > ) -> Self {
8887 Self {
89- client : CASAPIClient :: new ( & endpoint, token ) ,
88+ client : CASAPIClient :: new ( & endpoint, auth_config ) ,
9089 }
9190 }
9291}
9392
9493#[ derive( Debug ) ]
9594pub struct CASAPIClient {
96- client : reqwest :: Client ,
95+ client : ClientWithMiddleware ,
9796 endpoint : String ,
98- token : Option < String > ,
9997}
10098
10199impl Default for CASAPIClient {
102100 fn default ( ) -> Self {
103- Self :: new ( CAS_ENDPOINT , None )
101+ Self :: new ( CAS_ENDPOINT , & None )
104102 }
105103}
106104
107105impl CASAPIClient {
108- pub fn new ( endpoint : & str , token : Option < String > ) -> Self {
109- let client = reqwest :: Client :: builder ( ) . build ( ) . unwrap ( ) ;
106+ pub fn new ( endpoint : & str , auth_config : & Option < AuthConfig > ) -> Self {
107+ let client = build_reqwest_client ( auth_config ) . unwrap ( ) ;
110108 Self {
111109 client,
112110 endpoint : endpoint. to_string ( ) ,
113- token,
114111 }
115112 }
116113
117114 pub async fn exists ( & self , key : & Key ) -> Result < bool > {
118115 let url = Url :: parse ( & format ! ( "{}/xorb/{key}" , self . endpoint) ) ?;
119- let response = self
120- . client
121- . head ( url)
122- . headers ( self . request_headers ( ) )
123- . send ( )
124- . await ?;
116+ let response = self . client . head ( url) . send ( ) . await ?;
125117 match response. status ( ) {
126118 StatusCode :: OK => Ok ( true ) ,
127119 StatusCode :: NOT_FOUND => Ok ( false ) ,
@@ -133,12 +125,7 @@ impl CASAPIClient {
133125
134126 pub async fn get_length ( & self , key : & Key ) -> Result < Option < u64 > > {
135127 let url = Url :: parse ( & format ! ( "{}/xorb/{key}" , self . endpoint) ) ?;
136- let response = self
137- . client
138- . head ( url)
139- . headers ( self . request_headers ( ) )
140- . send ( )
141- . await ?;
128+ let response = self . client . head ( url) . send ( ) . await ?;
142129 let status = response. status ( ) ;
143130 if status == StatusCode :: NOT_FOUND {
144131 return Ok ( None ) ;
@@ -189,13 +176,7 @@ impl CASAPIClient {
189176 writer. set_position ( 0 ) ;
190177 let data = writer. into_inner ( ) ;
191178
192- let response = self
193- . client
194- . post ( url)
195- . headers ( self . request_headers ( ) )
196- . body ( data)
197- . send ( )
198- . await ?;
179+ let response = self . client . post ( url) . body ( data) . send ( ) . await ?;
199180 let response_body = response. bytes ( ) . await ?;
200181 let response_parsed: UploadXorbResponse = serde_json:: from_reader ( response_body. reader ( ) ) ?;
201182
@@ -247,12 +228,7 @@ impl CASAPIClient {
247228 file_id. hex( )
248229 ) ) ?;
249230
250- let response = self
251- . client
252- . get ( url)
253- . headers ( self . request_headers ( ) )
254- . send ( )
255- . await ?;
231+ let response = self . client . get ( url) . send ( ) . await ?;
256232 let response_body = response. bytes ( ) . await ?;
257233 let response_parsed: QueryReconstructionResponse =
258234 serde_json:: from_reader ( response_body. reader ( ) ) ?;
@@ -262,29 +238,13 @@ impl CASAPIClient {
262238
263239 pub async fn shard_query_chunk ( & self , key : & Key ) -> Result < QueryChunkResponse > {
264240 let url = Url :: parse ( & format ! ( "{}/chunk/{key}" , self . endpoint) ) ?;
265- let response = self
266- . client
267- . get ( url)
268- . headers ( self . request_headers ( ) )
269- . send ( )
270- . await ?;
241+ let response = self . client . get ( url) . send ( ) . await ?;
271242 let response_body = response. bytes ( ) . await ?;
272243 let response_parsed: QueryChunkResponse = serde_json:: from_reader ( response_body. reader ( ) ) ?;
273244
274245 Ok ( response_parsed)
275246 }
276247
277- fn request_headers ( & self ) -> HeaderMap {
278- let mut headers = HeaderMap :: new ( ) ;
279- if let Some ( tok) = & self . token {
280- headers. insert (
281- "Authorization" ,
282- HeaderValue :: from_str ( & format ! ( "Bearer {}" , tok) ) . unwrap ( ) ,
283- ) ;
284- }
285- headers
286- }
287-
288248 async fn post_json < ReqT , RespT > ( & self , url : Url , request_body : & ReqT ) -> Result < RespT >
289249 where
290250 ReqT : Serialize ,
@@ -330,22 +290,49 @@ async fn get_one(term: &CASReconstructionTerm) -> Result<Bytes> {
330290 Ok ( Bytes :: from ( sliced) )
331291}
332292
293+ /// builds the client to talk to CAS.
294+ pub fn build_reqwest_client (
295+ auth_config : & Option < AuthConfig > ,
296+ ) -> std:: result:: Result < ClientWithMiddleware , reqwest:: Error > {
297+ let auth_middleware = auth_config
298+ . as_ref ( )
299+ . map ( AuthMiddleware :: from)
300+ . info_none ( "CAS auth disabled" ) ;
301+ let reqwest_client = reqwest:: Client :: builder ( ) . build ( ) ?;
302+ Ok ( ClientBuilder :: new ( reqwest_client)
303+ . maybe_with ( auth_middleware)
304+ . build ( ) )
305+ }
306+
307+ /// Helper trait to allow the reqwest_middleware client to optionally add a middleware.
308+ trait OptionalMiddleware {
309+ fn maybe_with < M : Middleware > ( self , middleware : Option < M > ) -> Self ;
310+ }
311+
312+ impl OptionalMiddleware for ClientBuilder {
313+ fn maybe_with < M : Middleware > ( self , middleware : Option < M > ) -> Self {
314+ match middleware {
315+ Some ( m) => self . with ( m) ,
316+ None => self ,
317+ }
318+ }
319+ }
320+
333321#[ cfg( test) ]
334322mod tests {
335-
336- use merkledb:: { prelude:: MerkleDBHighLevelMethodsV1 , Chunk , MerkleMemDB } ;
337- use merklehash:: DataHash ;
338323 use rand:: Rng ;
339324 use tracing_test:: traced_test;
340325
341326 use super :: * ;
327+ use merkledb:: { prelude:: MerkleDBHighLevelMethodsV1 , Chunk , MerkleMemDB } ;
328+ use merklehash:: DataHash ;
342329
343330 #[ ignore]
344331 #[ traced_test]
345332 #[ tokio:: test]
346333 async fn test_basic_put ( ) {
347334 // Arrange
348- let rc = RemoteClient :: from_config ( CAS_ENDPOINT . to_string ( ) , None ) . await ;
335+ let rc = RemoteClient :: from_config ( CAS_ENDPOINT . to_string ( ) , & None ) . await ;
349336 let prefix = PREFIX_DEFAULT ;
350337 let ( hash, data, chunk_boundaries) = gen_dummy_xorb ( 3 , 10248 , true ) ;
351338
0 commit comments