@@ -469,6 +469,28 @@ impl<R: AsyncBufRead> Stream for CsvLineStream<R> {
469469 }
470470}
471471
472+ #[ async_trait]
473+ pub trait LocationsValidator : DIService + Send + Sync {
474+ async fn validate ( & self , locations : & Vec < String > ) -> Result < ( ) , CubeError > ;
475+ }
476+
/// Field-less [`LocationsValidator`] implementation; carries no state.
pub struct LocationsValidatorImpl;
478+
479+ #[ async_trait]
480+ impl LocationsValidator for LocationsValidatorImpl {
481+ async fn validate ( & self , _locations : & Vec < String > ) -> Result < ( ) , CubeError > {
482+ Ok ( ( ) )
483+ }
484+ }
485+
486+ impl LocationsValidatorImpl {
487+ pub fn new ( ) -> Arc < Self > {
488+ Arc :: new ( Self { } )
489+ }
490+ }
491+
492+ crate :: di_service!( LocationsValidatorImpl , [ LocationsValidator ] ) ;
493+
472494#[ automock]
473495#[ async_trait]
474496pub trait ImportService : DIService + Send + Sync {
@@ -482,6 +504,7 @@ pub trait ImportService: DIService + Send + Sync {
482504 async fn validate_table_location ( & self , table_id : u64 , location : & str )
483505 -> Result < ( ) , CubeError > ;
484506 async fn estimate_location_row_count ( & self , location : & str ) -> Result < u64 , CubeError > ;
507+ async fn validate_locations_size ( & self , locations : & Vec < String > ) -> Result < ( ) , CubeError > ;
485508}
486509
487510crate :: di_service!( MockImportService , [ ImportService ] ) ;
@@ -493,6 +516,7 @@ pub struct ImportServiceImpl {
493516 remote_fs : Arc < dyn RemoteFs > ,
494517 config_obj : Arc < dyn ConfigObj > ,
495518 limits : Arc < ConcurrencyLimits > ,
519+ validator : Arc < dyn LocationsValidator > ,
496520}
497521
498522crate :: di_service!( ImportServiceImpl , [ ImportService ] ) ;
@@ -505,6 +529,7 @@ impl ImportServiceImpl {
505529 remote_fs : Arc < dyn RemoteFs > ,
506530 config_obj : Arc < dyn ConfigObj > ,
507531 limits : Arc < ConcurrencyLimits > ,
532+ validator : Arc < dyn LocationsValidator > ,
508533 ) -> Arc < ImportServiceImpl > {
509534 Arc :: new ( ImportServiceImpl {
510535 meta_store,
@@ -513,6 +538,7 @@ impl ImportServiceImpl {
513538 remote_fs,
514539 config_obj,
515540 limits,
541+ validator,
516542 } )
517543 }
518544
@@ -618,23 +644,19 @@ impl ImportServiceImpl {
618644 }
619645
620646 async fn download_temp_file ( & self , location : & str ) -> Result < File , CubeError > {
621- let to_download = ImportServiceImpl :: temp_uploads_path ( location) ;
647+ let to_download = LocationHelper :: temp_uploads_path ( location) ;
622648 // TODO check file size
623649 let local_file = self . remote_fs . download_file ( & to_download, None ) . await ?;
624650 Ok ( File :: open ( local_file. clone ( ) )
625651 . await
626652 . map_err ( |e| CubeError :: internal ( format ! ( "Open temp_file {}: {}" , local_file, e) ) ) ?)
627653 }
628654
629- fn temp_uploads_path ( location : & str ) -> String {
630- location. replace ( "temp://" , "temp-uploads/" )
631- }
632-
633655 async fn drop_temp_uploads ( & self , location : & str ) -> Result < ( ) , CubeError > {
634656 // TODO There also should be a process which collects orphaned uploads due to failed imports
635657 if location. starts_with ( "temp://" ) {
636658 self . remote_fs
637- . delete_file ( & ImportServiceImpl :: temp_uploads_path ( location) )
659+ . delete_file ( & LocationHelper :: temp_uploads_path ( location) )
638660 . await ?;
639661 }
640662 Ok ( ( ) )
@@ -810,28 +832,57 @@ impl ImportService for ImportServiceImpl {
810832 }
811833
812834 async fn estimate_location_row_count ( & self , location : & str ) -> Result < u64 , CubeError > {
813- if location. starts_with ( "http" ) {
835+ let file_size =
836+ LocationHelper :: location_file_size ( location, self . remote_fs . clone ( ) ) . await ?;
837+ Ok ( ImportServiceImpl :: estimate_rows ( location, file_size) )
838+ }
839+
840+ async fn validate_locations_size ( & self , locations : & Vec < String > ) -> Result < ( ) , CubeError > {
841+ self . validator . validate ( locations) . await
842+ }
843+ }
844+
845+ pub struct LocationHelper ;
846+
847+ impl LocationHelper {
848+ pub async fn location_file_size (
849+ location : & str ,
850+ remote_fs : Arc < dyn RemoteFs > ,
851+ ) -> Result < Option < u64 > , CubeError > {
852+ let res = if location. starts_with ( "http" ) {
814853 let client = reqwest:: Client :: new ( ) ;
815854 let res = client. head ( location) . send ( ) . await ?;
816855 let length = res. headers ( ) . get ( reqwest:: header:: CONTENT_LENGTH ) ;
817856
818- let size = if let Some ( length) = length {
857+ if let Some ( length) = length {
819858 Some ( length. to_str ( ) ?. parse :: < u64 > ( ) ?)
820859 } else {
821860 None
822- } ;
823- Ok ( ImportServiceImpl :: estimate_rows ( location, size) )
861+ }
824862 } else if location. starts_with ( "temp://" ) {
825- // TODO do the actual estimation
826- Ok ( ImportServiceImpl :: estimate_rows ( location, None ) )
863+ let remote_path = Self :: temp_uploads_path ( location) ;
864+ match remote_fs. list_with_metadata ( & remote_path) . await {
865+ Ok ( list) => {
866+ let list_res = list. iter ( ) . next ( ) . ok_or ( CubeError :: internal ( format ! (
867+ "Location {} can't be listed in remote_fs" ,
868+ location
869+ ) ) ) ;
870+ match list_res {
871+ Ok ( file) => Ok ( Some ( file. file_size ) ) ,
872+ Err ( e) => Err ( e) ,
873+ }
874+ }
875+ Err ( e) => Err ( e) ,
876+ } ?
827877 } else if location. starts_with ( "stream://" ) {
828- Ok ( ImportServiceImpl :: estimate_rows ( location , None ) )
878+ None
829879 } else {
830- Ok ( ImportServiceImpl :: estimate_rows (
831- location,
832- Some ( tokio:: fs:: metadata ( location) . await ?. len ( ) ) ,
833- ) )
834- }
880+ Some ( tokio:: fs:: metadata ( location) . await ?. len ( ) )
881+ } ;
882+ Ok ( res)
883+ }
884+ pub fn temp_uploads_path ( location : & str ) -> String {
885+ location. replace ( "temp://" , "temp-uploads/" )
835886 }
836887}
837888
0 commit comments