diff --git a/Cargo.toml b/Cargo.toml
index 12c4056..ff8da31 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -18,8 +18,10 @@ path = "src/bin.rs"
 [dependencies]
 futures-util = "0.3.30"
-arrow = "52.0.0"
-parquet = { version = "52.0.0", features = ["arrow", "async"] }
+
+chrono = "=0.4.38"
+arrow = "=53.2.0"
+parquet = { version = "=53.2.0", features = ["arrow", "async"] }
 axum = "0.7.5"
 tokio = { version = "1.37.0", features = ["full"] }
 hyper = { version="1.3.1", features = ["full"] }
diff --git a/src/loaders/parquet/helpers.rs b/src/loaders/parquet/helpers.rs
index 6770c10..7893cd8 100644
--- a/src/loaders/parquet/helpers.rs
+++ b/src/loaders/parquet/helpers.rs
@@ -15,43 +15,48 @@ use std::sync::Arc;
 /// # Returns
 ///
 /// This function returns an Arrow Result with the boolean mask.
-pub fn create_boolean_mask(batch: &RecordBatch, original_schema: &Arc<Schema>, filters: Vec<(&str, &str, &str)>) -> arrow::error::Result<Arc<BooleanArray>> {
+pub fn create_boolean_mask(batch: &RecordBatch, original_schema: &Arc<Schema>, filters: Vec<Vec<(&str, &str, &str)>>) -> arrow::error::Result<Arc<BooleanArray>> {
     let num_rows = batch.num_rows();
     let mut boolean_builder = BooleanBuilder::new();
-
-    // Initialize all rows as true
-    for _ in 0..num_rows {
-        boolean_builder.append_value(true);
-    }
+    boolean_builder.append_n(num_rows, false);
     let mut boolean_mask = boolean_builder.finish();
+    for conjunction in filters.iter() {
+        let mut conj_boolean_builder = BooleanBuilder::new();
+        conj_boolean_builder.append_n(num_rows, true);
+        let mut conj_boolean_mask = conj_boolean_builder.finish();
+        for filter in conjunction.iter() {
+            let column = batch.column(original_schema.index_of(filter.0).unwrap());
 
-    for filter in filters.iter() {
-        let column = batch.column(original_schema.index_of(filter.0).unwrap());
-
-        if column.data_type() == &arrow::datatypes::DataType::Float32 {
-            let column = column.as_any().downcast_ref::<Float32Array>().unwrap();
-            apply_filter(&mut boolean_mask, column, filter)?;
-        } else if column.data_type() == &arrow::datatypes::DataType::Float64 {
-            let column = column.as_any().downcast_ref::<Float64Array>().unwrap();
-            apply_filter(&mut boolean_mask, column, filter)?;
-        } else if column.data_type() == &arrow::datatypes::DataType::Int16 {
-            let column = column.as_any().downcast_ref::<Int16Array>().unwrap();
-            apply_filter(&mut boolean_mask, column, filter)?;
-        } else if column.data_type() == &arrow::datatypes::DataType::Int32 {
-            let column = column.as_any().downcast_ref::<Int32Array>().unwrap();
-            apply_filter(&mut boolean_mask, column, filter)?;
-        } else if column.data_type() == &arrow::datatypes::DataType::Int64 {
-            let column = column.as_any().downcast_ref::<Int64Array>().unwrap();
-            apply_filter(&mut boolean_mask, column, filter)?;
-        } else if column.data_type() == &arrow::datatypes::DataType::Int8 {
-            let column = column.as_any().downcast_ref::<Int8Array>().unwrap();
-            apply_filter(&mut boolean_mask, column, filter)?;
-        } else if column.data_type() == &arrow::datatypes::DataType::Boolean {
-            let column = column.as_any().downcast_ref::<BooleanArray>().unwrap();
-            apply_filter(&mut boolean_mask, column, filter)?;
-        } else {
-            return Err(arrow::error::ArrowError::NotYetImplemented(format!("Data type {:?} not yet implemented", column.data_type())));
+            if column.data_type() == &arrow::datatypes::DataType::Float32 {
+                let column = column.as_any().downcast_ref::<Float32Array>().unwrap();
+                apply_filter(&mut conj_boolean_mask, column, filter)?;
+            } else if column.data_type() == &arrow::datatypes::DataType::Float64 {
+                let column = column.as_any().downcast_ref::<Float64Array>().unwrap();
+                apply_filter(&mut conj_boolean_mask, column, filter)?;
+            } else if column.data_type() == &arrow::datatypes::DataType::Int16 {
+                let column = column.as_any().downcast_ref::<Int16Array>().unwrap();
+                apply_filter(&mut conj_boolean_mask, column, filter)?;
+            } else if column.data_type() == &arrow::datatypes::DataType::Int32 {
+                let column = column.as_any().downcast_ref::<Int32Array>().unwrap();
+                apply_filter(&mut conj_boolean_mask, column, filter)?;
+            } else if column.data_type() == &arrow::datatypes::DataType::Int64 {
+                let column = column.as_any().downcast_ref::<Int64Array>().unwrap();
+                apply_filter(&mut conj_boolean_mask, column, filter)?;
+            } else if column.data_type() == &arrow::datatypes::DataType::Int8 {
+                let column = column.as_any().downcast_ref::<Int8Array>().unwrap();
+                apply_filter(&mut conj_boolean_mask, column, filter)?;
+            } else if column.data_type() == &arrow::datatypes::DataType::Boolean {
+                let column = column.as_any().downcast_ref::<BooleanArray>().unwrap();
+                apply_filter(&mut conj_boolean_mask, column, filter)?;
+            } else {
+                return Err(arrow::error::ArrowError::NotYetImplemented(format!("Data type {:?} not yet implemented", column.data_type())));
+            }
+        }
+        let mut new_mask = BooleanBuilder::new();
+        for (index, val) in conj_boolean_mask.iter().enumerate(){
+            new_mask.append_value(boolean_mask.value(index) || val.unwrap());
         }
+        boolean_mask = new_mask.finish();
     }
     Ok(Arc::new(boolean_mask))
 }
diff --git a/src/loaders/parquet/parse_params.rs b/src/loaders/parquet/parse_params.rs
index 3b5c46a..ad4309c 100644
--- a/src/loaders/parquet/parse_params.rs
+++ b/src/loaders/parquet/parse_params.rs
@@ -1,70 +1,77 @@
-use std::collections::HashMap;
 use regex::Regex;
+use std::collections::HashMap;
 
 /// # Arguments
-/// 
+///
 /// * `params` - A reference to a HashMap of parameters containing 'columns' key.
-/// 
+///
 /// # Returns
-/// 
+///
 /// A vector of Polars with the selected columns.
 pub fn parse_columns_from_params_to_str(params: &HashMap<String, String>) -> Option<Vec<String>> {
     // Parse columns from params
     // Initialize a set of columns to return
     let mut select_cols = if let Some(cols) = params.get("columns") {
-        cols.split(",").map(|x| x.to_string()).collect::<Vec<String>>()
+        Some(cols.split(",").map(|x| x.to_string()).collect::<Vec<String>>())
     } else {
-        Vec::new()
+        None
     };
 
     // If filters exist, extract and add filter columns if not already present
-    if let Some(query) = params.get("filters") {
-        let re = Regex::new(r"([0-9a-zA-Z_]+)([!<>=]+)([-+]?[0-9]*\.?[0-9]*)").unwrap();
-
-        for filter in query.split(",") {
-            if let Some(captures) = re.captures(filter) {
-                let filter_col = captures.get(1).unwrap().as_str();
-
-                // Add filter column only if it's not already in select_cols
-                if !select_cols.contains(&filter_col.to_string()) {
-                    select_cols.push(filter_col.to_string());
+    if let Some(cols) = select_cols.as_mut() {
+        if let Some(query) = params.get("filters") {
+            let re = Regex::new(r"([0-9a-zA-Z_]+)([!<>=]+)([-+]?[0-9]*\.?[0-9]*)").unwrap();
+
+            for filter in query.split(",") {
+                if let Some(captures) = re.captures(filter) {
+                    let filter_col = captures.get(1).unwrap().as_str();
+
+                    // Add filter column only if it's not already in select_cols
+                    if !cols.contains(&filter_col.to_string()) {
+                        cols.push(filter_col.to_string());
+                    }
                 }
             }
         }
     }
 
-    // Return Some(select_cols) if not empty, otherwise None
-    if !select_cols.is_empty() {
-        Some(select_cols)
-    } else {
-        None
-    }
+    return select_cols;
 }
 
 /// # Arguments
-/// 
+///
 /// * `params` - A reference to a HashMap of parameters containing 'filters' key.
-/// 
+///
 /// # Returns
-/// 
+///
 /// A vector of tuples containing the column name, the comparison operator and the value to compare.
-pub fn parse_filters(params: &HashMap<String, String>) -> Option<Vec<(&str, &str, &str)>> {
-    let mut filters = Vec::new();
+pub fn parse_filters(params: &HashMap<String, String>) -> Option<Vec<Vec<(&str, &str, &str)>>> {
+    let mut outer_filters = Vec::new();
     if let Some(query) = params.get("filters") {
-        filters = query.split(",").collect::<Vec<&str>>();
+        outer_filters = query.split(";").collect::<Vec<&str>>();
     }
-    if filters.len() == 0 {
-        return None
+    if outer_filters.len() == 0 {
+        return None;
     }
-    let re = Regex::new(r"([0-9a-zA-Z_]+)([!<>=]+)([-+]?[0-9]*\.?[0-9]*)").unwrap();
     let mut filter_vec = Vec::new();
-    for filter in filters {
-        let f_vec = re.captures(filter).unwrap();
-        filter_vec.push((f_vec.get(1).unwrap().as_str(), f_vec.get(2).unwrap().as_str(), f_vec.get(3).unwrap().as_str()));
+    let re = Regex::new(r"([0-9a-zA-Z_]+)([!<>=]+)([-+]?[0-9]*\.?[0-9]*)").unwrap();
+    for inner_filters in outer_filters {
+        let mut filters = Vec::new();
+        filters = inner_filters.split(",").collect::<Vec<&str>>();
+        let mut inner_filter_vec = Vec::new();
+        for filter in filters {
+            let f_vec = re.captures(filter).unwrap();
+            inner_filter_vec.push((
+                f_vec.get(1).unwrap().as_str(),
+                f_vec.get(2).unwrap().as_str(),
+                f_vec.get(3).unwrap().as_str(),
+            ));
+        }
+        filter_vec.push(inner_filter_vec);
     }
     Some(filter_vec)
-}
\ No newline at end of file
+}
diff --git a/src/routes.rs b/src/routes.rs
index 7dd7241..abcaa78 100644
--- a/src/routes.rs
+++ b/src/routes.rs
@@ -62,7 +62,10 @@ pub async fn entry_route(
         // No Range header: Process and return the full Parquet file as before
         match loaders::parquet::parquet::process_and_return_parquet_file(&file_path.to_str().unwrap(), &params).await {
             Ok(bytes) => Bytes::from(bytes).into_response(),
-            Err(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Failed to load file").into_response(),
+            Err(e) => {
+                eprintln!("Application error: {e}");
+                return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to load file").into_response();
+            },
         }
     }
diff --git a/tests/parsers.rs b/tests/parsers.rs
index fba86a4..74b10b4 100644
--- a/tests/parsers.rs
+++ b/tests/parsers.rs
@@ -1,5 +1,16 @@
 #[cfg(test)]
 mod parser {
+    use lsdb_server::loaders::parquet;
+    use std::collections::HashMap;
 
-
+    #[tokio::test]
+    async fn test_parse_filters() {
+        let mut params = HashMap::new();
+
+        params.insert("filters".to_string(), "RA>=30.1241,DEC<=-30.3,RA>30,DEC<=30;RA==1;RA=1,RA!=0".to_string());
+
+        let filters = parquet::parse_params::parse_filters(&params);
+        println!("{:#?}", filters);
+        // TODO: Add assertions here to verify the result
+    }
 }
\ No newline at end of file
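
With this change the `filters` query parameter is parsed as a disjunction of conjunctions: conditions joined by `,` within a group must all hold for a row, and `;`-separated groups are ORed together by `create_boolean_mask`. Below is a minimal sketch of the expected parse, assuming the crate paths used in tests/parsers.rs; the commented output shape is inferred from `parse_filters` above, not from a recorded run.

    use std::collections::HashMap;

    use lsdb_server::loaders::parquet;

    fn main() {
        let mut params = HashMap::new();
        // `;` separates OR groups, `,` separates AND conditions inside a group.
        params.insert(
            "filters".to_string(),
            "RA>=30.1241,DEC<=-30.3;DEC>45".to_string(),
        );

        // Expected shape (per parse_filters above):
        // Some([[("RA", ">=", "30.1241"), ("DEC", "<=", "-30.3")], [("DEC", ">", "45")]])
        let filters = parquet::parse_params::parse_filters(&params);
        println!("{:#?}", filters);
    }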