Skip to content

Commit 3a78ec1

Browse files
authored
chore(cubestore): Small refactoring of import process (#7189)
1 parent 0b6755c commit 3a78ec1

File tree

3 files changed

+82
-19
lines changed

3 files changed

+82
-19
lines changed

rust/cubestore/cubestore/src/config/mod.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use crate::config::injection::{DIService, Injector};
1515
use crate::config::processing_loop::ProcessingLoop;
1616
use crate::http::HttpServer;
1717
use crate::import::limits::ConcurrencyLimits;
18-
use crate::import::{ImportService, ImportServiceImpl};
18+
use crate::import::{ImportService, ImportServiceImpl, LocationsValidator, LocationsValidatorImpl};
1919
use crate::metastore::{
2020
BaseRocksStoreFs, MetaStore, MetaStoreRpcClient, RocksMetaStore, RocksStoreConfig,
2121
};
@@ -1960,6 +1960,12 @@ impl Config {
19601960
})
19611961
.await;
19621962

1963+
self.injector
1964+
.register_typed::<dyn LocationsValidator, _, _, _>(async move |_| {
1965+
LocationsValidatorImpl::new()
1966+
})
1967+
.await;
1968+
19631969
self.injector
19641970
.register_typed::<dyn ImportService, _, _, _>(async move |i| {
19651971
ImportServiceImpl::new(
@@ -1969,6 +1975,7 @@ impl Config {
19691975
i.get_service_typed().await,
19701976
i.get_service_typed().await,
19711977
i.get_service_typed().await,
1978+
i.get_service_typed().await,
19721979
)
19731980
})
19741981
.await;

rust/cubestore/cubestore/src/import/mod.rs

Lines changed: 69 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,28 @@ impl<R: AsyncBufRead> Stream for CsvLineStream<R> {
469469
}
470470
}
471471

472+
#[async_trait]
473+
pub trait LocationsValidator: DIService + Send + Sync {
474+
async fn validate(&self, locations: &Vec<String>) -> Result<(), CubeError>;
475+
}
476+
477+
pub struct LocationsValidatorImpl;
478+
479+
#[async_trait]
480+
impl LocationsValidator for LocationsValidatorImpl {
481+
async fn validate(&self, _locations: &Vec<String>) -> Result<(), CubeError> {
482+
Ok(())
483+
}
484+
}
485+
486+
impl LocationsValidatorImpl {
487+
pub fn new() -> Arc<Self> {
488+
Arc::new(Self {})
489+
}
490+
}
491+
492+
crate::di_service!(LocationsValidatorImpl, [LocationsValidator]);
493+
472494
#[automock]
473495
#[async_trait]
474496
pub trait ImportService: DIService + Send + Sync {
@@ -482,6 +504,7 @@ pub trait ImportService: DIService + Send + Sync {
482504
async fn validate_table_location(&self, table_id: u64, location: &str)
483505
-> Result<(), CubeError>;
484506
async fn estimate_location_row_count(&self, location: &str) -> Result<u64, CubeError>;
507+
async fn validate_locations_size(&self, locations: &Vec<String>) -> Result<(), CubeError>;
485508
}
486509

487510
crate::di_service!(MockImportService, [ImportService]);
@@ -493,6 +516,7 @@ pub struct ImportServiceImpl {
493516
remote_fs: Arc<dyn RemoteFs>,
494517
config_obj: Arc<dyn ConfigObj>,
495518
limits: Arc<ConcurrencyLimits>,
519+
validator: Arc<dyn LocationsValidator>,
496520
}
497521

498522
crate::di_service!(ImportServiceImpl, [ImportService]);
@@ -505,6 +529,7 @@ impl ImportServiceImpl {
505529
remote_fs: Arc<dyn RemoteFs>,
506530
config_obj: Arc<dyn ConfigObj>,
507531
limits: Arc<ConcurrencyLimits>,
532+
validator: Arc<dyn LocationsValidator>,
508533
) -> Arc<ImportServiceImpl> {
509534
Arc::new(ImportServiceImpl {
510535
meta_store,
@@ -513,6 +538,7 @@ impl ImportServiceImpl {
513538
remote_fs,
514539
config_obj,
515540
limits,
541+
validator,
516542
})
517543
}
518544

@@ -618,23 +644,19 @@ impl ImportServiceImpl {
618644
}
619645

620646
async fn download_temp_file(&self, location: &str) -> Result<File, CubeError> {
621-
let to_download = ImportServiceImpl::temp_uploads_path(location);
647+
let to_download = LocationHelper::temp_uploads_path(location);
622648
// TODO check file size
623649
let local_file = self.remote_fs.download_file(&to_download, None).await?;
624650
Ok(File::open(local_file.clone())
625651
.await
626652
.map_err(|e| CubeError::internal(format!("Open temp_file {}: {}", local_file, e)))?)
627653
}
628654

629-
fn temp_uploads_path(location: &str) -> String {
630-
location.replace("temp://", "temp-uploads/")
631-
}
632-
633655
async fn drop_temp_uploads(&self, location: &str) -> Result<(), CubeError> {
634656
// TODO There also should be a process which collects orphaned uploads due to failed imports
635657
if location.starts_with("temp://") {
636658
self.remote_fs
637-
.delete_file(&ImportServiceImpl::temp_uploads_path(location))
659+
.delete_file(&LocationHelper::temp_uploads_path(location))
638660
.await?;
639661
}
640662
Ok(())
@@ -810,28 +832,57 @@ impl ImportService for ImportServiceImpl {
810832
}
811833

812834
async fn estimate_location_row_count(&self, location: &str) -> Result<u64, CubeError> {
813-
if location.starts_with("http") {
835+
let file_size =
836+
LocationHelper::location_file_size(location, self.remote_fs.clone()).await?;
837+
Ok(ImportServiceImpl::estimate_rows(location, file_size))
838+
}
839+
840+
async fn validate_locations_size(&self, locations: &Vec<String>) -> Result<(), CubeError> {
841+
self.validator.validate(locations).await
842+
}
843+
}
844+
845+
pub struct LocationHelper;
846+
847+
impl LocationHelper {
848+
pub async fn location_file_size(
849+
location: &str,
850+
remote_fs: Arc<dyn RemoteFs>,
851+
) -> Result<Option<u64>, CubeError> {
852+
let res = if location.starts_with("http") {
814853
let client = reqwest::Client::new();
815854
let res = client.head(location).send().await?;
816855
let length = res.headers().get(reqwest::header::CONTENT_LENGTH);
817856

818-
let size = if let Some(length) = length {
857+
if let Some(length) = length {
819858
Some(length.to_str()?.parse::<u64>()?)
820859
} else {
821860
None
822-
};
823-
Ok(ImportServiceImpl::estimate_rows(location, size))
861+
}
824862
} else if location.starts_with("temp://") {
825-
// TODO do the actual estimation
826-
Ok(ImportServiceImpl::estimate_rows(location, None))
863+
let remote_path = Self::temp_uploads_path(location);
864+
match remote_fs.list_with_metadata(&remote_path).await {
865+
Ok(list) => {
866+
let list_res = list.iter().next().ok_or(CubeError::internal(format!(
867+
"Location {} can't be listed in remote_fs",
868+
location
869+
)));
870+
match list_res {
871+
Ok(file) => Ok(Some(file.file_size)),
872+
Err(e) => Err(e),
873+
}
874+
}
875+
Err(e) => Err(e),
876+
}?
827877
} else if location.starts_with("stream://") {
828-
Ok(ImportServiceImpl::estimate_rows(location, None))
878+
None
829879
} else {
830-
Ok(ImportServiceImpl::estimate_rows(
831-
location,
832-
Some(tokio::fs::metadata(location).await?.len()),
833-
))
834-
}
880+
Some(tokio::fs::metadata(location).await?.len())
881+
};
882+
Ok(res)
883+
}
884+
pub fn temp_uploads_path(location: &str) -> String {
885+
location.replace("temp://", "temp-uploads/")
835886
}
836887
}
837888

rust/cubestore/cubestore/src/sql/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,11 @@ impl SqlServiceImpl {
363363
}
364364

365365
let listener = self.cluster.job_result_listener();
366+
if let Some(locations) = locations.as_ref() {
367+
self.import_service
368+
.validate_locations_size(locations)
369+
.await?;
370+
}
366371

367372
let partition_split_threshold = if let Some(locations) = locations.as_ref() {
368373
let size = join_all(

0 commit comments

Comments
 (0)