diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 79584efd..baaed306 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -50,6 +50,9 @@ jobs: - name: Build with only mysql run: cargo check --no-default-features --features mysql + - name: Build with only trino + run: cargo check --no-default-features --features trino + integration-test-mysql: name: Tests mysql runs-on: ubuntu-latest @@ -154,3 +157,24 @@ jobs: - name: Run tests run: cargo test --features mongodb + + integration-test-trino: + name: Tests trino + runs-on: ubuntu-latest + + env: + TRINO_DOCKER_IMAGE: trinodb/trino:latest + + steps: + - uses: actions/checkout@v4 + + - uses: dtolnay/rust-toolchain@stable + + - uses: ./.github/actions/setup-integration-test + + - name: Pull the Trino image + run: | + docker pull ${{ env.TRINO_DOCKER_IMAGE }} + + - name: Run tests + run: cargo test --features trino diff --git a/Cargo.toml b/Cargo.toml index 2dbfa41f..691e369f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -67,6 +67,9 @@ itertools = "0.13.0" dyn-clone = { version = "1.0.17", optional = true } geo-types = "0.7.13" fundu = "2.0.1" +reqwest = { version = "0.12.5", optional = true, features = ["json", "rustls-tls"] } +regex = { version = "1.11.1", optional = true } +chrono-tz = { version = "0.8", optional = true } [dev-dependencies] anyhow = "1.0.86" @@ -82,6 +85,7 @@ tokio-stream = { version = "0.1.15", features = ["net"] } insta = { version = "1.40.0", features = ["filters"] } datafusion-physical-plan = { version = "49.0.2" } tempfile = "3.8.1" +mockito = "1.7.0" [features] mysql = ["dep:mysql_async", "dep:async-stream"] @@ -111,6 +115,14 @@ mongodb = [ "dep:rust_decimal", "dep:num-traits", ] +trino = [ + "dep:arrow-schema", + "dep:async-stream", + "dep:base64", + "dep:regex", + "dep:reqwest", + "dep:chrono-tz", +] [patch.crates-io] datafusion-federation = { git = "https://github.com/spiceai/datafusion-federation.git", rev = "5ad2f52b9bafc6eaa50851f2e1fcf0585fb5184d" } # spiceai-49 @@ -172,3 +184,8 @@ required-features = ["postgres"] name = "mongodb" path = "examples/mongodb.rs" required-features = ["mongodb"] + +[[example]] +name = "trino" +path = "examples/trino.rs" +required-features = ["trino"] diff --git a/README.md b/README.md index 7da1fb03..5e7de842 100644 --- a/README.md +++ b/README.md @@ -132,3 +132,16 @@ EOF # Run from repo folder cargo run -p datafusion-table-providers --example mongodb --features mongodb ``` + +### Trino + +In order to run the Trino example, you need to have a Trino server running. You can use the following command to start a Trino server in a Docker container the example can use: + +```bash +docker run -d --name trino -p 8080:8080 trinodb/trino:latest +# Wait for the Trino server to start +sleep 30 + +# Run from repo folder +cargo run -p datafusion-table-providers --example trino --features trino +``` diff --git a/examples/trino.rs b/examples/trino.rs new file mode 100644 index 00000000..b70d3961 --- /dev/null +++ b/examples/trino.rs @@ -0,0 +1,61 @@ +use std::{collections::HashMap, sync::Arc}; + +use datafusion::prelude::SessionContext; +use datafusion::sql::TableReference; +use datafusion_table_providers::sql::db_connection_pool::trinodbpool::TrinoConnectionPool; +use datafusion_table_providers::trino::TrinoTableFactory; +use datafusion_table_providers::util::secrets::to_secret_map; + +/// This example demonstrates how to: +/// 1. Create a Trino connection pool +/// 2. Create and use TrinoTableFactory to generate TableProvider +/// 3. 
Use SQL queries to access Trino table data
+///
+/// Prerequisites:
+/// Start a Trino server using Docker:
+/// ```bash
+/// docker run -d --name trino -p 8080:8080 trinodb/trino:latest
+/// # Wait for the Trino server to start
+/// sleep 30
+/// ```
+#[tokio::main]
+async fn main() {
+    // Create Trino connection parameters
+    let trino_params = to_secret_map(HashMap::from([
+        ("host".to_string(), "localhost".to_string()),
+        ("port".to_string(), "8080".to_string()),
+        ("catalog".to_string(), "tpch".to_string()),
+        ("schema".to_string(), "tiny".to_string()),
+        ("user".to_string(), "test".to_string()),
+        ("sslmode".to_string(), "disabled".to_string()),
+    ]));
+
+    // Create Trino connection pool
+    let trino_pool = Arc::new(
+        TrinoConnectionPool::new(trino_params)
+            .await
+            .expect("unable to create Trino connection pool"),
+    );
+
+    // Create Trino table provider factory
+    let table_factory = TrinoTableFactory::new(trino_pool.clone());
+
+    // Create DataFusion session context
+    let ctx = SessionContext::new();
+
+    // Register the Trino "region" table as "region"
+    ctx.register_table(
+        "region",
+        table_factory
+            .table_provider(TableReference::bare("region"))
+            .await
+            .expect("failed to register table provider"),
+    )
+    .expect("failed to register table");
+
+    let df = ctx
+        .sql("SELECT * FROM region")
+        .await
+        .expect("select failed");
+    df.show().await.expect("show failed");
+}
diff --git a/src/lib.rs b/src/lib.rs
index a87f2304..95c973b2 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -16,6 +16,8 @@ pub mod mysql;
 pub mod postgres;
 #[cfg(feature = "sqlite")]
 pub mod sqlite;
+#[cfg(feature = "trino")]
+pub mod trino;
 
 #[derive(Debug, Snafu)]
 pub enum Error {
diff --git a/src/sql/arrow_sql_gen/mod.rs b/src/sql/arrow_sql_gen/mod.rs
index a315abf2..179ed94c 100644
--- a/src/sql/arrow_sql_gen/mod.rs
+++ b/src/sql/arrow_sql_gen/mod.rs
@@ -50,3 +50,5 @@ pub mod postgres;
 #[cfg(feature = "sqlite")]
 pub mod sqlite;
 pub mod statement;
+#[cfg(feature = "trino")]
+pub mod trino;
diff --git a/src/sql/arrow_sql_gen/trino.rs b/src/sql/arrow_sql_gen/trino.rs
new file mode 100644
index 00000000..e754fc6c
--- /dev/null
+++ b/src/sql/arrow_sql_gen/trino.rs
@@ -0,0 +1,47 @@
+use snafu::Snafu;
+
+pub mod arrow;
+pub mod schema;
+
+pub type Result<T, E = Error> = std::result::Result<T, E>;
+
+#[derive(Debug, Snafu)]
+pub enum Error {
+    #[snafu(display("Failed to build record batch: {source}"))]
+    FailedToBuildRecordBatch {
+        source: datafusion::arrow::error::ArrowError,
+    },
+
+    #[snafu(display("Failed to find field {column_name} in schema"))]
+    FailedToFindFieldInSchema { column_name: String },
+
+    #[snafu(display("Unsupported Trino type: {trino_type}"))]
+    UnsupportedTrinoType { trino_type: String },
+
+    #[snafu(display("Unsupported Arrow type: {arrow_type}"))]
+    UnsupportedArrowType { arrow_type: String },
+
+    #[snafu(display("Invalid date value: {value}"))]
+    InvalidDateValue { value: String },
+
+    #[snafu(display("Invalid time value: {value}"))]
+    InvalidTimeValue { value: String },
+
+    #[snafu(display("Invalid timestamp value: {value}"))]
+    InvalidTimestampValue { value: String },
+
+    #[snafu(display("Failed to parse decimal value: {value}"))]
+    FailedToParseDecimal { value: String },
+
+    #[snafu(display("Failed to downcast builder to expected type: {expected}"))]
+    BuilderDowncastError { expected: String },
+
+    #[snafu(display("Invalid or unsupported data type: {data_type}"))]
+    InvalidDataType { data_type: String },
+
+    #[snafu(display("Failed to parse precision from type: '{trino_type}'"))]
+    InvalidPrecision { trino_type: String },
+
+    #[snafu(display("Failed to compile regex pattern: {source}"))]
+    RegexError { source: regex::Error },
+}
diff --git a/src/sql/arrow_sql_gen/trino/arrow.rs b/src/sql/arrow_sql_gen/trino/arrow.rs
new file mode 100644
index 00000000..6222551e
--- /dev/null
+++ b/src/sql/arrow_sql_gen/trino/arrow.rs
@@ -0,0 +1,2851 @@
+use super::{Error, FailedToBuildRecordBatchSnafu, Result};
+use arrow::{
+    array::{
+        ArrayBuilder, ArrayRef, BinaryBuilder, BooleanBuilder, Date32Builder, Decimal128Builder,
+        Decimal256Builder, Float32Builder, Float64Builder, Int16Builder, Int32Builder,
+        Int64Builder, Int8Builder, ListBuilder, NullBuilder, RecordBatch, StringBuilder,
+        StructBuilder, Time32MillisecondBuilder, Time64MicrosecondBuilder, Time64NanosecondBuilder,
+        TimestampMicrosecondBuilder, TimestampMillisecondBuilder, TimestampNanosecondBuilder,
+    },
+    datatypes::{i256, DataType, Date32Type, Field, Fields, TimeUnit},
+};
+use arrow_schema::{ArrowError, SchemaRef};
+use base64::engine::general_purpose::STANDARD as BASE64;
+use base64::Engine;
+use bigdecimal::BigDecimal;
+use bigdecimal::ToPrimitive;
+use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Timelike, Utc};
+use chrono_tz::Tz;
+use serde_json::Value;
+use snafu::ResultExt;
+use std::any::Any;
+use std::str::FromStr;
+use std::{collections::HashMap, sync::Arc};
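+
+// Illustrative sketch (assumes `Schema`/`Field` from `arrow::datatypes` and
+// serde_json's `json!` macro, which are not imported here): `rows_to_arrow`
+// consumes the row-major `data` payload of a Trino JSON response, e.g. for a
+// single nullable Int32 column named "id":
+//
+//     let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, true)]));
+//     let rows = vec![vec![json!(1)], vec![json!(2)]];
+//     let batch = rows_to_arrow(&rows, schema)?; // 2 rows, 1 column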
+
+pub fn rows_to_arrow(rows: &[Vec<Value>], schema: SchemaRef) -> Result<RecordBatch> {
+    if rows.is_empty() {
+        return Ok(RecordBatch::new_empty(Arc::clone(&schema)));
+    }
+
+    let mut builders = create_builders(&schema, rows.len())?;
+
+    for row in rows {
+        append_row_to_builders(row, &schema, &mut builders)?;
+    }
+
+    let arrays = finish_builders(builders, &schema)?;
+
+    RecordBatch::try_new(Arc::clone(&schema), arrays).context(FailedToBuildRecordBatchSnafu)
+}
+
+type BuilderMap = HashMap<String, Box<dyn ArrayBuilder>>;
+
+fn create_builders(schema: &SchemaRef, capacity: usize) -> Result<BuilderMap> {
+    let mut builders: BuilderMap = HashMap::new();
+
+    for field in schema.fields() {
+        let builder: Box<dyn ArrayBuilder> = create_arrow_builder_for_field(field, capacity)?;
+        builders.insert(field.name().clone(), builder);
+    }
+
+    Ok(builders)
+}
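+
+// Builders are looked up by column name when rows are appended, so a result
+// schema with duplicate field names would collide in `BuilderMap`; unique
+// column names are assumed throughout this module.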
+
+fn create_arrow_builder_for_field(field: &Field, capacity: usize) -> Result<Box<dyn ArrayBuilder>> {
+    match field.data_type() {
+        DataType::Null => Ok(Box::new(NullBuilder::new())),
+        DataType::Boolean => Ok(Box::new(BooleanBuilder::with_capacity(capacity))),
+        DataType::Int8 => Ok(Box::new(Int8Builder::with_capacity(capacity))),
+        DataType::Int16 => Ok(Box::new(Int16Builder::with_capacity(capacity))),
+        DataType::Int32 => Ok(Box::new(Int32Builder::with_capacity(capacity))),
+        DataType::Int64 => Ok(Box::new(Int64Builder::with_capacity(capacity))),
+        DataType::Float32 => Ok(Box::new(Float32Builder::with_capacity(capacity))),
+        DataType::Float64 => Ok(Box::new(Float64Builder::with_capacity(capacity))),
+        DataType::Decimal128(precision, scale) => {
+            let builder = Decimal128BuilderWrapper::new(capacity, *precision, *scale)
+                .map_err(|e| Error::FailedToBuildRecordBatch { source: e })?;
+            Ok(Box::new(builder))
+        }
+        DataType::Decimal256(precision, scale) => {
+            let builder = Decimal256BuilderWrapper::new(capacity, *precision, *scale)
+                .map_err(|e| Error::FailedToBuildRecordBatch { source: e })?;
+            Ok(Box::new(builder))
+        }
+        DataType::Utf8 => Ok(Box::new(StringBuilder::with_capacity(capacity, 1024))),
+        DataType::Binary => Ok(Box::new(BinaryBuilder::with_capacity(capacity, 1024))),
+        DataType::Date32 => Ok(Box::new(Date32Builder::with_capacity(capacity))),
+        DataType::Time32(TimeUnit::Millisecond) => {
+            Ok(Box::new(Time32MillisecondBuilder::with_capacity(capacity)))
+        }
+        DataType::Time64(TimeUnit::Microsecond) => {
+            Ok(Box::new(Time64MicrosecondBuilder::with_capacity(capacity)))
+        }
+        DataType::Time64(TimeUnit::Nanosecond) => {
+            Ok(Box::new(Time64NanosecondBuilder::with_capacity(capacity)))
+        }
+        DataType::Timestamp(time_unit, tz_opt) => match time_unit {
+            TimeUnit::Second | TimeUnit::Millisecond => Ok(Box::new(
+                TimestampMillisecondBuilder::with_capacity(capacity)
+                    .with_timezone_opt(tz_opt.clone()),
+            )),
+            TimeUnit::Microsecond => Ok(Box::new(
+                TimestampMicrosecondBuilder::with_capacity(capacity)
+                    .with_timezone_opt(tz_opt.clone()),
+            )),
+            TimeUnit::Nanosecond => Ok(Box::new(
+                TimestampNanosecondBuilder::with_capacity(capacity)
+                    .with_timezone_opt(tz_opt.clone()),
+            )),
+        },
+        DataType::List(field) => create_list_builder_for_field(field, capacity),
+        DataType::Struct(fields) => {
+            let mut field_builders = Vec::new();
+            for field in fields {
+                field_builders.push(create_arrow_builder_for_field(field, capacity)?);
+            }
+            Ok(Box::new(StructBuilder::new(fields.clone(), field_builders)))
+        }
+        arrow_type => Err(Error::UnsupportedArrowType {
+            arrow_type: arrow_type.to_string(),
+        }),
+    }
+}
+
+struct Decimal128BuilderWrapper {
+    inner: Box<Decimal128Builder>,
+    precision: u8,
+    scale: i8,
+}
+
+impl Decimal128BuilderWrapper {
+    fn new(capacity: usize, precision: u8, scale: i8) -> std::result::Result<Self, ArrowError> {
+        let inner = Decimal128Builder::with_capacity(capacity)
+            .with_precision_and_scale(precision, scale)?;
+
+        Ok(Self {
+            inner: Box::new(inner),
+            precision,
+            scale,
+        })
+    }
+
+    fn append_value(&mut self, value: i128) {
+        self.inner.append_value(value);
+    }
+
+    fn append_null(&mut self) {
+        self.inner.append_null();
+    }
+
+    fn data_type(&self) -> DataType {
+        DataType::Decimal128(self.precision, self.scale)
+    }
+}
+
+impl ArrayBuilder for Decimal128BuilderWrapper {
+    fn len(&self) -> usize {
+        self.inner.len()
+    }
+
+    fn finish(&mut self) -> ArrayRef {
+        Arc::new(self.inner.finish())
+    }
+
+    fn finish_cloned(&self) -> ArrayRef {
+        Arc::new(self.inner.finish_cloned())
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn as_any_mut(&mut self) -> &mut dyn Any {
+        self
+    }
+
+    fn into_box_any(self: Box<Self>) -> Box<dyn Any> {
+        self
+    }
+}
+
+struct Decimal256BuilderWrapper {
+    inner: Box<Decimal256Builder>,
+    precision: u8,
+    scale: i8,
+}
+
+impl Decimal256BuilderWrapper {
+    fn new(capacity: usize, precision: u8, scale: i8) -> std::result::Result<Self, ArrowError> {
+        let inner = Decimal256Builder::with_capacity(capacity)
+            .with_precision_and_scale(precision, scale)?;
+
+        Ok(Self {
+            inner: Box::new(inner),
+            precision,
+            scale,
+        })
+    }
+
+    fn append_value(&mut self, value: i256) {
+        self.inner.append_value(value);
+    }
+
+    fn append_null(&mut self) {
+        self.inner.append_null();
+    }
+
+    fn data_type(&self) -> DataType {
+        DataType::Decimal256(self.precision, self.scale)
+    }
+}
+
+impl ArrayBuilder for Decimal256BuilderWrapper {
+    fn len(&self) -> usize {
+        self.inner.len()
+    }
+
+    fn finish(&mut self) -> ArrayRef {
+        Arc::new(self.inner.finish())
+    }
+
+    fn finish_cloned(&self) -> ArrayRef {
+        Arc::new(self.inner.finish_cloned())
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn as_any_mut(&mut self) -> &mut dyn Any {
+        self
+    }
+
+    fn into_box_any(self: Box<Self>) -> Box<dyn Any> {
+        self
+    }
+}
+
+// `capacity * 4` below is a heuristic pre-allocation for list element
+// builders, assuming an average of four elements per list value.
+fn create_list_builder_for_field(
+    inner_field: &Field,
+    capacity: usize,
+) -> Result<Box<dyn ArrayBuilder>> {
+    match inner_field.data_type() {
+        DataType::Null => {
+            let values_builder = NullBuilder::new();
+            Ok(Box::new(ListBuilder::new(values_builder)))
+        }
+        DataType::Boolean => {
+            let values_builder = BooleanBuilder::with_capacity(capacity * 4);
+            Ok(Box::new(ListBuilder::new(values_builder)))
+        }
+        DataType::Int8 => {
+            let values_builder = Int8Builder::with_capacity(capacity * 4);
+            Ok(Box::new(ListBuilder::new(values_builder)))
+        }
+        DataType::Int16 => {
+            let values_builder = Int16Builder::with_capacity(capacity * 4);
+            Ok(Box::new(ListBuilder::new(values_builder)))
+        }
+        DataType::Int32 => {
+            let values_builder = Int32Builder::with_capacity(capacity * 4);
+            Ok(Box::new(ListBuilder::new(values_builder)))
+        }
+        DataType::Int64 => {
+            let values_builder = Int64Builder::with_capacity(capacity * 4);
+            Ok(Box::new(ListBuilder::new(values_builder)))
+        }
+        DataType::Float32 => {
+            let values_builder = Float32Builder::with_capacity(capacity * 4);
+            Ok(Box::new(ListBuilder::new(values_builder)))
+        }
+        DataType::Float64 => {
+            let values_builder = Float64Builder::with_capacity(capacity * 4);
+            Ok(Box::new(ListBuilder::new(values_builder)))
+        }
+        DataType::Utf8 => {
+            let values_builder = StringBuilder::with_capacity(capacity * 4, 1024);
+            Ok(Box::new(ListBuilder::new(values_builder)))
+        }
+        DataType::Binary => {
+            let values_builder = BinaryBuilder::with_capacity(capacity * 4, 1024);
+            Ok(Box::new(ListBuilder::new(values_builder)))
+        }
+        DataType::Date32 => {
+            let values_builder = Date32Builder::with_capacity(capacity * 4);
+            Ok(Box::new(ListBuilder::new(values_builder)))
+        }
+        DataType::Time32(TimeUnit::Second | TimeUnit::Millisecond) => {
+            let values_builder = Time32MillisecondBuilder::with_capacity(capacity * 4);
+            Ok(Box::new(ListBuilder::new(values_builder)))
+        }
+        DataType::Time64(TimeUnit::Microsecond) => {
+            let values_builder = Time64MicrosecondBuilder::with_capacity(capacity * 4);
+            Ok(Box::new(ListBuilder::new(values_builder)))
+        }
+        DataType::Time64(TimeUnit::Nanosecond) => {
+            let values_builder = Time64NanosecondBuilder::with_capacity(capacity * 4);
+            Ok(Box::new(ListBuilder::new(values_builder)))
+        }
+        DataType::Timestamp(time_unit, tz_opt) => match time_unit {
+            TimeUnit::Second | TimeUnit::Millisecond => {
+                let values_builder = TimestampMillisecondBuilder::with_capacity(capacity * 4)
+                    .with_timezone_opt(tz_opt.clone());
+                Ok(Box::new(ListBuilder::new(values_builder)))
+            }
+            TimeUnit::Microsecond => {
+                let values_builder = TimestampMicrosecondBuilder::with_capacity(capacity * 4)
+                    .with_timezone_opt(tz_opt.clone());
+                Ok(Box::new(ListBuilder::new(values_builder)))
+            }
+            TimeUnit::Nanosecond => {
+                // Nanosecond timestamps need a nanosecond builder to match the
+                // schema's declared unit.
+                let values_builder = TimestampNanosecondBuilder::with_capacity(capacity * 4)
+                    .with_timezone_opt(tz_opt.clone());
+                Ok(Box::new(ListBuilder::new(values_builder)))
+            }
+        },
+        DataType::Decimal128(precision, scale) => {
+            let values_builder = Decimal128BuilderWrapper::new(capacity * 4, *precision, *scale)
+                .map_err(|e| Error::FailedToBuildRecordBatch { source: e })?;
+            Ok(Box::new(ListBuilder::new(values_builder)))
+        }
+        DataType::Decimal256(precision, scale) => {
+            let values_builder = Decimal256BuilderWrapper::new(capacity * 4, *precision, *scale)
+                .map_err(|e| Error::FailedToBuildRecordBatch { source: e })?;
+            Ok(Box::new(ListBuilder::new(values_builder)))
+        }
+        arrow_type => Err(Error::UnsupportedArrowType {
+            arrow_type: arrow_type.to_string(),
+        }),
+    }
+}
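+
+// Rows arrive row-major (one JSON array per row, as produced by the `data`
+// field of Trino's response); each value is appended to the builder for the
+// field at the same index in the schema.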
+fn append_row_to_builders(
+    row: &[Value],
+    schema: &SchemaRef,
+    builders: &mut BuilderMap,
+) -> Result<()> {
+    for (field_idx, field) in schema.fields().iter().enumerate() {
+        let field_name = field.name();
+        let value = row.get(field_idx);
+
+        if let Some(builder) = builders.get_mut(field_name) {
+            append_value_to_builder(builder.as_mut(), value, field.data_type())?;
+        }
+    }
+    Ok(())
+}
+
+fn append_value_to_builder(
+    builder: &mut dyn ArrayBuilder,
+    value: Option<&Value>,
+    data_type: &DataType,
+) -> Result<()> {
+    match data_type {
+        DataType::Boolean => {
+            let bool_builder = builder
+                .as_any_mut()
+                .downcast_mut::<BooleanBuilder>()
+                .ok_or_else(|| Error::BuilderDowncastError {
+                    expected: "BooleanBuilder".to_string(),
+                })?;
+            match value {
+                Some(v) if v.is_null() => bool_builder.append_null(),
+                Some(Value::Bool(b)) => bool_builder.append_value(*b),
+                Some(_) => bool_builder.append_null(),
+                None => bool_builder.append_null(),
+            }
+        }
+        DataType::Int8 => {
+            let int_builder = builder
+                .as_any_mut()
+                .downcast_mut::<Int8Builder>()
+                .ok_or_else(|| Error::BuilderDowncastError {
+                    expected: "Int8Builder".to_string(),
+                })?;
+            append_int8_value(int_builder, value);
+        }
+        DataType::Int16 => {
+            let int_builder = builder
+                .as_any_mut()
+                .downcast_mut::<Int16Builder>()
+                .ok_or_else(|| Error::BuilderDowncastError {
+                    expected: "Int16Builder".to_string(),
+                })?;
+            append_int16_value(int_builder, value);
+        }
+        DataType::Int32 => {
+            let int_builder = builder
+                .as_any_mut()
+                .downcast_mut::<Int32Builder>()
+                .ok_or_else(|| Error::BuilderDowncastError {
+                    expected: "Int32Builder".to_string(),
+                })?;
+            append_int32_value(int_builder, value);
+        }
+        DataType::Int64 => {
+            let int_builder = builder
+                .as_any_mut()
+                .downcast_mut::<Int64Builder>()
+                .ok_or_else(|| Error::BuilderDowncastError {
+                    expected: "Int64Builder".to_string(),
+                })?;
+            append_int64_value(int_builder, value);
+        }
+        DataType::Float32 => {
+            let float_builder = builder
+                .as_any_mut()
+                .downcast_mut::<Float32Builder>()
+                .ok_or_else(|| Error::BuilderDowncastError {
+                    expected: "Float32Builder".to_string(),
+                })?;
+            append_float32_value(float_builder, value);
+        }
+        DataType::Float64 => {
+            let float_builder = builder
+                .as_any_mut()
+                .downcast_mut::<Float64Builder>()
+                .ok_or_else(|| Error::BuilderDowncastError {
+                    expected: "Float64Builder".to_string(),
+                })?;
+            append_float64_value(float_builder, value);
+        }
+        DataType::Utf8 => {
+            let string_builder = builder
+                .as_any_mut()
+                .downcast_mut::<StringBuilder>()
+                .ok_or_else(|| Error::BuilderDowncastError {
+                    expected: "StringBuilder".to_string(),
+                })?;
+            append_string_value(string_builder, value);
+        }
+        DataType::Binary => {
+            let binary_builder = builder
+                .as_any_mut()
+                .downcast_mut::<BinaryBuilder>()
+                .ok_or_else(|| Error::BuilderDowncastError {
+                    expected: "BinaryBuilder".to_string(),
+                })?;
+            append_binary_value(binary_builder, value)?;
+        }
+        DataType::Date32 => {
+            let date_builder = builder
+                .as_any_mut()
+                .downcast_mut::<Date32Builder>()
+                .ok_or_else(|| Error::BuilderDowncastError {
+                    expected: "Date32Builder".to_string(),
+                })?;
+            append_date32_value(date_builder, value)?;
+        }
+        DataType::Time32(TimeUnit::Millisecond) => {
+            let time_builder = builder
+                .as_any_mut()
+                .downcast_mut::<Time32MillisecondBuilder>()
+                .ok_or_else(|| Error::BuilderDowncastError {
+                    expected: "Time32MillisecondBuilder".to_string(),
+                })?;
+            append_time32_millisecond_value(time_builder, value)?;
+        }
+        DataType::Time64(TimeUnit::Microsecond) => {
+            let time_builder = builder
+                .as_any_mut()
+                .downcast_mut::<Time64MicrosecondBuilder>()
+                .ok_or_else(|| Error::BuilderDowncastError {
+                    expected: "Time64MicrosecondBuilder".to_string(),
+                })?;
+            append_time64_microsecond_value(time_builder, value)?;
+        }
+        DataType::Time64(TimeUnit::Nanosecond) => {
+            let time_builder = builder
+                .as_any_mut()
+                .downcast_mut::<Time64NanosecondBuilder>()
+                .ok_or_else(|| Error::BuilderDowncastError {
+                    expected: "Time64NanosecondBuilder".to_string(),
+                })?;
+            append_time64_nanosecond_value(time_builder, value)?;
+        }
+        DataType::Timestamp(TimeUnit::Millisecond, _) => {
+            let timestamp_builder = builder
+                .as_any_mut()
+                .downcast_mut::<TimestampMillisecondBuilder>()
+                .ok_or_else(|| Error::BuilderDowncastError {
+                    expected: "TimestampMillisecondBuilder".to_string(),
+                })?;
+            append_timestamp_millisecond_value(timestamp_builder, value)?;
+        }
+        DataType::Timestamp(TimeUnit::Microsecond, _) => {
+            let timestamp_builder = builder
+                .as_any_mut()
+                .downcast_mut::<TimestampMicrosecondBuilder>()
+                .ok_or_else(|| Error::BuilderDowncastError {
+                    expected: "TimestampMicrosecondBuilder".to_string(),
+                })?;
+            append_timestamp_microsecond_value(timestamp_builder, value)?;
+        }
+        DataType::Timestamp(TimeUnit::Nanosecond, _) => {
+            let timestamp_builder = builder
+                .as_any_mut()
+                .downcast_mut::<TimestampNanosecondBuilder>()
+                .ok_or_else(|| Error::BuilderDowncastError {
+                    expected: "TimestampNanosecondBuilder".to_string(),
+                })?;
+            append_timestamp_nanosecond_value(timestamp_builder, value)?;
+        }
+        DataType::Decimal128(_, _) => {
+            let decimal_builder = builder
+                .as_any_mut()
+                .downcast_mut::<Decimal128BuilderWrapper>()
+                .ok_or_else(|| Error::BuilderDowncastError {
+                    expected: "Decimal128BuilderWrapper".to_string(),
+                })?;
+            append_decimal128_value(decimal_builder, value)?;
+        }
+        DataType::Decimal256(_, _) => {
+            let decimal_builder = builder
+                .as_any_mut()
+                .downcast_mut::<Decimal256BuilderWrapper>()
+                .ok_or_else(|| Error::BuilderDowncastError {
+                    expected: "Decimal256BuilderWrapper".to_string(),
+                })?;
+            append_decimal256_value(decimal_builder, value)?;
+        }
+        DataType::List(_) => {
+            append_list_value(builder, value)?;
+        }
+        DataType::Struct(fields) => {
+            let struct_builder = builder
+                .as_any_mut()
+                .downcast_mut::<StructBuilder>()
+                .ok_or_else(|| Error::BuilderDowncastError {
+                    expected: "StructBuilder".to_string(),
+                })?;
+            append_struct_value(struct_builder, value, fields)?;
+        }
+        DataType::Null => {
+            let null_builder = builder
+                .as_any_mut()
+                .downcast_mut::<NullBuilder>()
+                .ok_or_else(|| Error::BuilderDowncastError {
+                    expected: "NullBuilder".to_string(),
+                })?;
+            null_builder.append_null();
+        }
+        arrow_type => {
+            return Err(Error::UnsupportedArrowType {
+                arrow_type: arrow_type.to_string(),
+            });
+        }
+    }
+    Ok(())
+}
+
+fn append_int8_value(builder: &mut Int8Builder, value: Option<&Value>) {
+    match value {
+        Some(v) if v.is_null() => builder.append_null(),
+        Some(Value::Number(n)) if n.is_i64() => {
+            if let Some(i) = n.as_i64() {
+                if i >= i8::MIN as i64 && i <= i8::MAX as i64 {
+                    builder.append_value(i as i8);
+                } else {
+                    builder.append_null();
+                }
+            } else {
+                builder.append_null();
+            }
+        }
+        Some(_) => builder.append_null(),
+        None => builder.append_null(),
+    }
+}
+
+fn append_int16_value(builder: &mut Int16Builder, value: Option<&Value>) {
+    match value {
+        Some(v) if v.is_null() => builder.append_null(),
+        Some(Value::Number(n)) if n.is_i64() => {
+            if let Some(i) = n.as_i64() {
+                if i >= i16::MIN as i64 && i <= i16::MAX as i64 {
+                    builder.append_value(i as i16);
+                } else {
+                    builder.append_null();
+                }
+            } else {
+                builder.append_null();
+            }
+        }
+        Some(_) => builder.append_null(),
+        None => builder.append_null(),
+    }
+}
+
+fn append_int32_value(builder: &mut Int32Builder, value: Option<&Value>) {
+    match value {
+        Some(v) if v.is_null() => builder.append_null(),
+        Some(Value::Number(n)) if n.is_i64() => {
+            if let Some(i) = n.as_i64() {
+                if i >= i32::MIN as i64 && i <= i32::MAX as i64 {
+                    builder.append_value(i as i32);
+                } else {
+                    builder.append_null();
+                }
+            } else {
+                builder.append_null();
+            }
+        }
+        Some(_) => builder.append_null(),
+
None => builder.append_null(), + } +} + +fn append_int64_value(builder: &mut Int64Builder, value: Option<&Value>) { + match value { + Some(v) if v.is_null() => builder.append_null(), + Some(Value::Number(n)) if n.is_i64() => { + if let Some(i) = n.as_i64() { + builder.append_value(i); + } else { + builder.append_null(); + } + } + Some(_) => builder.append_null(), + None => builder.append_null(), + } +} + +fn append_float32_value(builder: &mut Float32Builder, value: Option<&Value>) { + match value { + Some(v) if v.is_null() => builder.append_null(), + Some(Value::Number(n)) if n.is_f64() => { + if let Some(f) = n.as_f64() { + builder.append_value(f as f32); + } else { + builder.append_null(); + } + } + Some(_) => builder.append_null(), + None => builder.append_null(), + } +} + +fn append_float64_value(builder: &mut Float64Builder, value: Option<&Value>) { + match value { + Some(v) if v.is_null() => builder.append_null(), + Some(Value::Number(n)) if n.is_f64() => { + if let Some(f) = n.as_f64() { + builder.append_value(f); + } else { + builder.append_null(); + } + } + Some(_) => builder.append_null(), + None => builder.append_null(), + } +} + +fn append_string_value(builder: &mut StringBuilder, value: Option<&Value>) { + match value { + Some(v) if v.is_null() => builder.append_null(), + Some(Value::String(s)) => builder.append_value(s), + Some(other) => { + let str_val = serde_json::to_string(other).unwrap_or_default(); + builder.append_value(&str_val); + } + None => builder.append_null(), + } +} + +fn append_binary_value(builder: &mut BinaryBuilder, value: Option<&Value>) -> Result<()> { + match value { + Some(v) if v.is_null() => builder.append_null(), + Some(Value::String(s)) => { + if let Ok(bytes) = BASE64.decode(s) { + builder.append_value(bytes); + } else { + builder.append_value(s.as_bytes()); + } + } + Some(_) => builder.append_null(), + None => builder.append_null(), + } + Ok(()) +} + +fn append_date32_value(builder: &mut Date32Builder, value: Option<&Value>) -> Result<()> { + match value { + Some(v) if v.is_null() => builder.append_null(), + Some(Value::String(date_str)) => { + if let Ok(date) = NaiveDate::parse_from_str(date_str, "%Y-%m-%d") { + builder.append_value(Date32Type::from_naive_date(date)); + } else { + return Err(Error::InvalidDateValue { + value: date_str.to_string(), + }); + } + } + Some(_) => builder.append_null(), + None => builder.append_null(), + } + Ok(()) +} + +fn append_time32_millisecond_value( + builder: &mut Time32MillisecondBuilder, + value: Option<&Value>, +) -> Result<()> { + match value { + Some(v) if v.is_null() => builder.append_null(), + Some(Value::String(time_str)) => { + if let Ok(time) = NaiveTime::parse_from_str(time_str, "%H:%M:%S%.f") { + let millis = i32::try_from( + i64::from(time.num_seconds_from_midnight()) * 1_000 + + (time.nanosecond() / 1_000_000) as i64, + ) + .map_err(|_| Error::InvalidTimeValue { + value: time_str.to_string(), + })?; + builder.append_value(millis); + } else { + return Err(Error::InvalidTimeValue { + value: time_str.to_string(), + }); + } + } + Some(_) => builder.append_null(), + None => builder.append_null(), + } + Ok(()) +} + +fn append_time64_microsecond_value( + builder: &mut Time64MicrosecondBuilder, + value: Option<&Value>, +) -> Result<()> { + match value { + Some(v) if v.is_null() => builder.append_null(), + Some(Value::String(time_str)) => { + if let Ok(time) = NaiveTime::parse_from_str(time_str, "%H:%M:%S%.f") { + let micros = i64::from(time.num_seconds_from_midnight()) * 1_000_000 + + (time.nanosecond() / 
1_000) as i64;
+                builder.append_value(micros);
+            } else {
+                return Err(Error::InvalidTimeValue {
+                    value: time_str.to_string(),
+                });
+            }
+        }
+        Some(_) => builder.append_null(),
+        None => builder.append_null(),
+    }
+    Ok(())
+}
+
+fn append_time64_nanosecond_value(
+    builder: &mut Time64NanosecondBuilder,
+    value: Option<&Value>,
+) -> Result<()> {
+    match value {
+        Some(v) if v.is_null() => builder.append_null(),
+        Some(Value::String(time_str)) => {
+            if let Ok(time) = NaiveTime::parse_from_str(time_str, "%H:%M:%S%.f") {
+                let nanos = i64::from(time.num_seconds_from_midnight()) * 1_000_000_000
+                    + i64::from(time.nanosecond());
+                builder.append_value(nanos);
+            } else {
+                return Err(Error::InvalidTimeValue {
+                    value: time_str.to_string(),
+                });
+            }
+        }
+        Some(_) => builder.append_null(),
+        None => builder.append_null(),
+    }
+    Ok(())
+}
+
+pub fn append_timestamp_millisecond_value(
+    builder: &mut TimestampMillisecondBuilder,
+    value: Option<&Value>,
+) -> Result<()> {
+    match value {
+        Some(v) if v.is_null() => builder.append_null(),
+
+        Some(Value::String(timestamp_str)) => {
+            let ts = timestamp_str.trim();
+
+            if let Some(utc_dt) = parse_timestamp_to_utc_datetime(ts)? {
+                builder.append_value(utc_dt.timestamp_millis());
+            } else {
+                return Err(Error::InvalidTimestampValue {
+                    value: ts.to_string(),
+                });
+            }
+        }
+
+        Some(_) | None => builder.append_null(),
+    }
+
+    Ok(())
+}
+
+pub fn append_timestamp_microsecond_value(
+    builder: &mut TimestampMicrosecondBuilder,
+    value: Option<&Value>,
+) -> Result<()> {
+    match value {
+        Some(v) if v.is_null() => builder.append_null(),
+
+        Some(Value::String(timestamp_str)) => {
+            let ts = timestamp_str.trim();
+
+            if let Some(utc_dt) = parse_timestamp_to_utc_datetime(ts)? {
+                builder.append_value(utc_dt.timestamp_micros());
+            } else {
+                return Err(Error::InvalidTimestampValue {
+                    value: ts.to_string(),
+                });
+            }
+        }
+
+        Some(_) | None => builder.append_null(),
+    }
+
+    Ok(())
+}
+
+pub fn append_timestamp_nanosecond_value(
+    builder: &mut TimestampNanosecondBuilder,
+    value: Option<&Value>,
+) -> Result<()> {
+    match value {
+        Some(v) if v.is_null() => builder.append_null(),
+
+        Some(Value::String(timestamp_str)) => {
+            let ts = timestamp_str.trim();
+
+            if let Some(utc_dt) = parse_timestamp_to_utc_datetime(ts)? {
+                let nanos =
+                    utc_dt
+                        .timestamp_nanos_opt()
+                        .ok_or_else(|| Error::InvalidTimestampValue {
+                            value: ts.to_string(),
+                        })?;
+                builder.append_value(nanos);
+            } else {
+                return Err(Error::InvalidTimestampValue {
+                    value: ts.to_string(),
+                });
+            }
+        }
+
+        Some(_) | None => builder.append_null(),
+    }
+
+    Ok(())
+}
+
+fn parse_timestamp_to_utc_datetime(ts: &str) -> Result<Option<DateTime<Utc>>> {
+    // 1. Try parsing with an IANA timezone (e.g., "2023-12-25 15:30:00 America/New_York")
+    if let Some((datetime_part, tz_part)) = ts.rsplit_once(' ') {
+        if let Ok(tz) = Tz::from_str(tz_part) {
+            // Parse the datetime part without timezone
+            if let Ok(naive_dt) =
+                NaiveDateTime::parse_from_str(datetime_part, "%Y-%m-%d %H:%M:%S%.f")
+            {
+                // Convert naive datetime to the specified timezone, then to UTC
+                let dt_with_tz = tz.from_local_datetime(&naive_dt).single().ok_or_else(|| {
+                    Error::InvalidTimestampValue {
+                        value: ts.to_string(),
+                    }
+                })?;
+                let utc_dt = dt_with_tz.with_timezone(&Utc);
+                return Ok(Some(utc_dt));
+            }
+        }
+    }
+
+    // 2. Try parsing with a numeric timezone offset (e.g., "+05:30", "-08:00")
+    if let Ok(dt_with_tz) = DateTime::parse_from_str(ts, "%Y-%m-%d %H:%M:%S%.f %z") {
+        let utc_dt = dt_with_tz.with_timezone(&Utc);
+        return Ok(Some(utc_dt));
+    }
+
+    // 3. Try parsing with a timezone abbreviation (e.g., "PST", "EST")
+    if let Some((datetime_part, tz_part)) = ts.rsplit_once(' ') {
+        if let Some(tz) = parse_timezone_abbreviation(tz_part) {
+            if let Ok(naive_dt) =
+                NaiveDateTime::parse_from_str(datetime_part, "%Y-%m-%d %H:%M:%S%.f")
+            {
+                let dt_with_tz = tz.from_local_datetime(&naive_dt).single().ok_or_else(|| {
+                    Error::InvalidTimestampValue {
+                        value: ts.to_string(),
+                    }
+                })?;
+                let utc_dt = dt_with_tz.with_timezone(&Utc);
+                return Ok(Some(utc_dt));
+            }
+        }
+    }
+
+    // 4. Try a naive datetime (assumed to be in UTC)
+    if let Ok(naive_dt) = NaiveDateTime::parse_from_str(ts, "%Y-%m-%d %H:%M:%S%.f") {
+        let utc_dt = Utc.from_utc_datetime(&naive_dt);
+        return Ok(Some(utc_dt));
+    }
+
+    // 5. Fall back to ISO 8601 / RFC 3339 with an explicit timezone
+    if let Ok(dt) = DateTime::parse_from_rfc3339(ts) {
+        let utc_dt = dt.with_timezone(&Utc);
+        return Ok(Some(utc_dt));
+    }
+
+    Ok(None)
+}
+
+fn parse_timezone_abbreviation(tz_abbr: &str) -> Option<Tz> {
+    // Map common timezone abbreviations to IANA identifiers
+    match tz_abbr {
+        "PST" | "PDT" => Some(Tz::America__Los_Angeles),
+        "MST" | "MDT" => Some(Tz::America__Denver),
+        "CST" | "CDT" => Some(Tz::America__Chicago),
+        "EST" | "EDT" => Some(Tz::America__New_York),
+        "GMT" | "UTC" => Some(Tz::UTC),
+        "JST" => Some(Tz::Asia__Tokyo),
+        "CET" | "CEST" => Some(Tz::Europe__Berlin),
+        "BST" => Some(Tz::Europe__London),
+        "IST" => Some(Tz::Asia__Kolkata),
+        "AEST" | "AEDT" => Some(Tz::Australia__Sydney),
+        "HST" => Some(Tz::Pacific__Honolulu),
+        _ => None,
+    }
+}
+
+fn append_decimal128_value(
+    builder: &mut Decimal128BuilderWrapper,
+    value: Option<&Value>,
+) -> Result<()> {
+    match value {
+        Some(v) if v.is_null() => builder.append_null(),
+        Some(Value::String(decimal_str)) => {
+            if let Ok(big_decimal) = decimal_str.parse::<BigDecimal>() {
+                if let DataType::Decimal128(_, scale) = builder.data_type() {
+                    let scale_factor = BigDecimal::from(10_i128.pow(scale as u32));
+                    let scaled_decimal = big_decimal * scale_factor;
+
+                    if let Some(decimal_value) = scaled_decimal.to_i128() {
+                        builder.append_value(decimal_value);
+                    } else {
+                        builder.append_null();
+                    }
+                } else if let Some(decimal_value) = big_decimal.to_i128() {
+                    builder.append_value(decimal_value);
+                } else {
+                    builder.append_null();
+                }
+            } else {
+                return Err(Error::FailedToParseDecimal {
+                    value: decimal_str.to_string(),
+                });
+            }
+        }
+        Some(_) => builder.append_null(),
+        None => builder.append_null(),
+    }
+    Ok(())
+}
+
+fn append_decimal256_value(
+    builder: &mut Decimal256BuilderWrapper,
+    value: Option<&Value>,
+) -> Result<()> {
+    match value {
+        Some(v) if v.is_null() => builder.append_null(),
+        Some(Value::String(decimal_str)) => {
+            if let Ok(big_decimal) = decimal_str.parse::<BigDecimal>() {
+                if let DataType::Decimal256(_, scale) = builder.data_type() {
+                    let scale_factor = BigDecimal::from(10_i128.pow(scale as u32));
+                    let scaled_decimal = big_decimal * scale_factor;
+                    builder.append_value(to_decimal_256(&scaled_decimal));
+                } else {
+                    builder.append_value(to_decimal_256(&big_decimal));
+                }
+            } else {
+                return Err(Error::FailedToParseDecimal {
+                    value: decimal_str.to_string(),
+                });
+            }
+        }
+        Some(_) => builder.append_null(),
+        None => builder.append_null(),
+    }
+    Ok(())
+}
+
+fn append_list_value(builder: &mut dyn ArrayBuilder, value: Option<&Value>) ->
Result<()> { + match value { + Some(v) if v.is_null() => { + append_null_to_list_builder(builder)?; + } + Some(Value::Array(arr)) => { + append_array_to_list_builder(builder, arr)?; + } + Some(_) => { + append_null_to_list_builder(builder)?; + } + None => { + append_null_to_list_builder(builder)?; + } + } + Ok(()) +} + +fn append_null_to_list_builder(builder: &mut dyn ArrayBuilder) -> Result<()> { + if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + list_builder.append_null(); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + list_builder.append_null(); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + list_builder.append_null(); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + list_builder.append_null(); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + list_builder.append_null(); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + list_builder.append_null(); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + list_builder.append_null(); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + list_builder.append_null(); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + list_builder.append_null(); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + list_builder.append_null(); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + list_builder.append_null(); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + list_builder.append_null(); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + list_builder.append_null(); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + list_builder.append_null(); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + list_builder.append_null(); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + list_builder.append_null(); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + list_builder.append_null(); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + list_builder.append_null(); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + list_builder.append_null(); + } else { + return Err(Error::BuilderDowncastError { + expected: "ListBuilder".to_string(), + }); + } + Ok(()) +} + +fn append_array_to_list_builder(builder: &mut dyn ArrayBuilder, arr: &Vec) -> Result<()> { + if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + for item in arr { + append_string_value(list_builder.values(), Some(item)); + } + list_builder.append(true); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + for item in arr { + match item { + Value::Bool(b) => list_builder.values().append_value(*b), + Value::Null => list_builder.values().append_null(), + _ => list_builder.values().append_null(), + } + } + list_builder.append(true); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + for item in arr { + append_int8_value(list_builder.values(), Some(item)); + } + list_builder.append(true); + } else if let Some(list_builder) = builder + 
.as_any_mut() + .downcast_mut::>() + { + for item in arr { + append_int16_value(list_builder.values(), Some(item)); + } + list_builder.append(true); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + for item in arr { + append_int32_value(list_builder.values(), Some(item)); + } + list_builder.append(true); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + for item in arr { + append_int64_value(list_builder.values(), Some(item)); + } + list_builder.append(true); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + for item in arr { + append_float32_value(list_builder.values(), Some(item)); + } + list_builder.append(true); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + for item in arr { + append_float64_value(list_builder.values(), Some(item)); + } + list_builder.append(true); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + for item in arr { + append_binary_value(list_builder.values(), Some(item))?; + } + list_builder.append(true); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + for item in arr { + append_date32_value(list_builder.values(), Some(item))?; + } + list_builder.append(true); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + for item in arr { + append_time32_millisecond_value(list_builder.values(), Some(item))?; + } + list_builder.append(true); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + for item in arr { + append_time64_microsecond_value(list_builder.values(), Some(item))?; + } + list_builder.append(true); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + for item in arr { + append_time64_nanosecond_value(list_builder.values(), Some(item))?; + } + list_builder.append(true); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + for item in arr { + append_timestamp_millisecond_value(list_builder.values(), Some(item))?; + } + list_builder.append(true); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + for item in arr { + append_timestamp_microsecond_value(list_builder.values(), Some(item))?; + } + list_builder.append(true); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + for item in arr { + append_timestamp_nanosecond_value(list_builder.values(), Some(item))?; + } + list_builder.append(true); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + for item in arr { + append_decimal128_value(list_builder.values(), Some(item))?; + } + list_builder.append(true); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + for item in arr { + append_decimal256_value(list_builder.values(), Some(item))?; + } + list_builder.append(true); + } else if let Some(list_builder) = builder + .as_any_mut() + .downcast_mut::>() + { + for _ in arr { + list_builder.values().append_null(); + } + list_builder.append(true); + } else { + return Err(Error::BuilderDowncastError { + expected: "ListBuilder".to_string(), + }); + } + Ok(()) +} + +fn finish_builders(mut builders: BuilderMap, schema: &SchemaRef) -> Result> { + let mut arrays = Vec::new(); + + for field in schema.fields() { + let field_name = field.name(); + if let Some(mut builder) = builders.remove(field_name) { + 
arrays.push(builder.finish()); + } else { + return Err(Error::FailedToFindFieldInSchema { + column_name: field_name.to_string(), + }); + } + } + + Ok(arrays) +} + +fn to_decimal_256(decimal: &BigDecimal) -> i256 { + let (bigint_value, _) = decimal.as_bigint_and_exponent(); + let mut bigint_bytes = bigint_value.to_signed_bytes_le(); + + let is_negative = bigint_value.sign() == num_bigint::Sign::Minus; + let fill_byte = if is_negative { 0xFF } else { 0x00 }; + + if bigint_bytes.len() > 32 { + bigint_bytes.truncate(32); + } else { + bigint_bytes.resize(32, fill_byte); + }; + + let mut array = [0u8; 32]; + array.copy_from_slice(&bigint_bytes); + + i256::from_le_bytes(array) +} + +fn append_to_struct_field_builder( + builder: &mut StructBuilder, + field_index: usize, + value: Option<&Value>, + field_data_type: &DataType, +) -> Result<()> { + match field_data_type { + DataType::Boolean => { + let field_builder = builder + .field_builder::(field_index) + .ok_or_else(|| Error::BuilderDowncastError { + expected: "BooleanBuilder".to_string(), + })?; + match value { + Some(v) if v.is_null() => field_builder.append_null(), + Some(Value::Bool(b)) => field_builder.append_value(*b), + _ => field_builder.append_null(), + } + } + DataType::Int8 => { + let field_builder = builder + .field_builder::(field_index) + .ok_or_else(|| Error::BuilderDowncastError { + expected: "Int8Builder".to_string(), + })?; + append_int8_value(field_builder, value); + } + DataType::Int16 => { + let field_builder = builder + .field_builder::(field_index) + .ok_or_else(|| Error::BuilderDowncastError { + expected: "Int16Builder".to_string(), + })?; + append_int16_value(field_builder, value); + } + DataType::Int32 => { + let field_builder = builder + .field_builder::(field_index) + .ok_or_else(|| Error::BuilderDowncastError { + expected: "Int32Builder".to_string(), + })?; + append_int32_value(field_builder, value); + } + DataType::Int64 => { + let field_builder = builder + .field_builder::(field_index) + .ok_or_else(|| Error::BuilderDowncastError { + expected: "Int64Builder".to_string(), + })?; + append_int64_value(field_builder, value); + } + DataType::Float32 => { + let field_builder = builder + .field_builder::(field_index) + .ok_or_else(|| Error::BuilderDowncastError { + expected: "Float32Builder".to_string(), + })?; + append_float32_value(field_builder, value); + } + DataType::Float64 => { + let field_builder = builder + .field_builder::(field_index) + .ok_or_else(|| Error::BuilderDowncastError { + expected: "Float64Builder".to_string(), + })?; + append_float64_value(field_builder, value); + } + DataType::Utf8 => { + let field_builder = builder + .field_builder::(field_index) + .ok_or_else(|| Error::BuilderDowncastError { + expected: "StringBuilder".to_string(), + })?; + append_string_value(field_builder, value); + } + DataType::Binary => { + let field_builder = builder + .field_builder::(field_index) + .ok_or_else(|| Error::BuilderDowncastError { + expected: "BinaryBuilder".to_string(), + })?; + append_binary_value(field_builder, value)?; + } + DataType::Date32 => { + let field_builder = builder + .field_builder::(field_index) + .ok_or_else(|| Error::BuilderDowncastError { + expected: "Date32Builder".to_string(), + })?; + append_date32_value(field_builder, value)?; + } + DataType::Time32(TimeUnit::Millisecond) => { + let field_builder = builder + .field_builder::(field_index) + .ok_or_else(|| Error::BuilderDowncastError { + expected: "Time32MillisecondBuilder".to_string(), + })?; + 
append_time32_millisecond_value(field_builder, value)?; + } + DataType::Time64(TimeUnit::Microsecond) => { + let field_builder = builder + .field_builder::(field_index) + .ok_or_else(|| Error::BuilderDowncastError { + expected: "Time64MicrosecondBuilder".to_string(), + })?; + append_time64_microsecond_value(field_builder, value)?; + } + DataType::Time64(TimeUnit::Nanosecond) => { + let field_builder = builder + .field_builder::(field_index) + .ok_or_else(|| Error::BuilderDowncastError { + expected: "Time64NanosecondBuilder".to_string(), + })?; + append_time64_nanosecond_value(field_builder, value)?; + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + let field_builder = builder + .field_builder::(field_index) + .ok_or_else(|| Error::BuilderDowncastError { + expected: "TimestampMillisecondBuilder".to_string(), + })?; + append_timestamp_millisecond_value(field_builder, value)?; + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + let field_builder = builder + .field_builder::(field_index) + .ok_or_else(|| Error::BuilderDowncastError { + expected: "TimestampMicrosecondBuilder".to_string(), + })?; + append_timestamp_microsecond_value(field_builder, value)?; + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + let field_builder = builder + .field_builder::(field_index) + .ok_or_else(|| Error::BuilderDowncastError { + expected: "TimestampNanosecondBuilder".to_string(), + })?; + append_timestamp_nanosecond_value(field_builder, value)?; + } + DataType::Decimal128(_, _) => { + let field_builder = builder + .field_builder::(field_index) + .ok_or_else(|| Error::BuilderDowncastError { + expected: "Decimal128BuilderWrapper".to_string(), + })?; + append_decimal128_value(field_builder, value)?; + } + DataType::Decimal256(_, _) => { + let field_builder = builder + .field_builder::(field_index) + .ok_or_else(|| Error::BuilderDowncastError { + expected: "Decimal256BuilderWrapper".to_string(), + })?; + append_decimal256_value(field_builder, value)?; + } + DataType::Struct(nested_fields) => { + let nested_struct_builder = builder + .field_builder::(field_index) + .ok_or_else(|| Error::BuilderDowncastError { + expected: "StructBuilder".to_string(), + })?; + append_struct_value(nested_struct_builder, value, nested_fields)?; + } + DataType::Null => { + let field_builder = builder + .field_builder::(field_index) + .ok_or_else(|| Error::BuilderDowncastError { + expected: "NullBuilder".to_string(), + })?; + field_builder.append_null(); + } + _ => { + let field_builder = builder + .field_builder::(field_index) + .ok_or_else(|| Error::BuilderDowncastError { + expected: "StringBuilder (fallback) - 2".to_string(), + })?; + append_string_value(field_builder, value); + } + } + Ok(()) +} + +fn append_struct_value( + builder: &mut StructBuilder, + value: Option<&Value>, + fields: &Fields, +) -> Result<()> { + match value { + Some(v) if v.is_null() => { + for (i, field) in fields.iter().enumerate() { + append_to_struct_field_builder(builder, i, None, field.data_type())?; + } + builder.append_null(); + } + Some(Value::Object(obj)) => { + for (i, field) in fields.iter().enumerate() { + let field_value = obj.get(field.name()); + append_to_struct_field_builder(builder, i, field_value, field.data_type())?; + } + builder.append(true); + } + Some(Value::Array(arr)) => { + for (i, field) in fields.iter().enumerate() { + let field_value = arr.get(i); + append_to_struct_field_builder(builder, i, field_value, field.data_type())?; + } + builder.append(true); + } + Some(_) => { + for (i, field) in 
fields.iter().enumerate() { + append_to_struct_field_builder(builder, i, None, field.data_type())?; + } + builder.append_null(); + } + None => { + for (i, field) in fields.iter().enumerate() { + append_to_struct_field_builder(builder, i, None, field.data_type())?; + } + builder.append_null(); + } + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::sql::arrow_sql_gen::trino::schema::trino_data_type_to_arrow_type; + use arrow::array::*; + use arrow::datatypes::Schema; + use serde_json::{json, Value}; + + fn create_test_schema(columns: Vec<(&str, &str)>) -> SchemaRef { + let mut fields = Vec::new(); + for (name, data_type) in columns { + let arrow_type = trino_data_type_to_arrow_type(data_type, None).unwrap(); + fields.push(Field::new(name, arrow_type, true)); + } + + Arc::new(Schema::new(fields)) + } + + fn assert_boolean_array(record_batch: &RecordBatch, column_index: usize, expected: Vec) { + let array = record_batch + .column(column_index) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(array.len(), expected.len(), "Array length mismatch"); + for (i, expected_value) in expected.iter().enumerate() { + assert_eq!(array.value(i), *expected_value, "Mismatch at index {i}"); + } + } + + fn assert_int8_array(record_batch: &RecordBatch, column_index: usize, expected: Vec) { + let array = record_batch + .column(column_index) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(array.len(), expected.len(), "Array length mismatch"); + for (i, expected_value) in expected.iter().enumerate() { + assert_eq!(array.value(i), *expected_value, "Mismatch at index {i}"); + } + } + + fn assert_int16_array(record_batch: &RecordBatch, column_index: usize, expected: Vec) { + let array = record_batch + .column(column_index) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(array.len(), expected.len(), "Array length mismatch"); + for (i, expected_value) in expected.iter().enumerate() { + assert_eq!(array.value(i), *expected_value, "Mismatch at index {i}"); + } + } + + fn assert_int32_array(record_batch: &RecordBatch, column_index: usize, expected: Vec) { + let array = record_batch + .column(column_index) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(array.len(), expected.len(), "Array length mismatch"); + for (i, expected_value) in expected.iter().enumerate() { + assert_eq!(array.value(i), *expected_value, "Mismatch at index {i}"); + } + } + + fn assert_int64_array(record_batch: &RecordBatch, column_index: usize, expected: Vec) { + let array = record_batch + .column(column_index) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(array.len(), expected.len(), "Array length mismatch"); + for (i, expected_value) in expected.iter().enumerate() { + assert_eq!(array.value(i), *expected_value, "Mismatch at index {i}"); + } + } + + fn assert_float32_array(record_batch: &RecordBatch, column_index: usize, expected: Vec) { + let array = record_batch + .column(column_index) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(array.len(), expected.len(), "Array length mismatch"); + for (i, expected_value) in expected.iter().enumerate() { + assert!( + (array.value(i) - expected_value).abs() < f32::EPSILON, + "Mismatch at index {}: expected {}, got {}", + i, + expected_value, + array.value(i) + ); + } + } + + fn assert_float64_array(record_batch: &RecordBatch, column_index: usize, expected: Vec) { + let array = record_batch + .column(column_index) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(array.len(), expected.len(), "Array 
length mismatch"); + for (i, expected_value) in expected.iter().enumerate() { + assert!( + (array.value(i) - expected_value).abs() < f64::EPSILON, + "Mismatch at index {}: expected {}, got {}", + i, + expected_value, + array.value(i) + ); + } + } + + fn assert_string_array(record_batch: &RecordBatch, column_index: usize, expected: Vec<&str>) { + let array = record_batch + .column(column_index) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(array.len(), expected.len(), "Array length mismatch"); + for (i, expected_value) in expected.iter().enumerate() { + assert_eq!(array.value(i), *expected_value, "Mismatch at index {i}"); + } + } + + fn assert_int32_array_with_nulls( + record_batch: &RecordBatch, + column_index: usize, + expected: Vec>, + ) { + let array = record_batch + .column(column_index) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(array.len(), expected.len(), "Array length mismatch"); + for (i, expected_value) in expected.iter().enumerate() { + match expected_value { + Some(val) => { + assert!(!array.is_null(i), "Expected non-null at index {i}"); + assert_eq!(array.value(i), *val, "Mismatch at index {i}"); + } + None => { + assert!(array.is_null(i), "Expected null at index {i}"); + } + } + } + } + + fn assert_string_array_with_nulls( + record_batch: &RecordBatch, + column_index: usize, + expected: Vec>, + ) { + let array = record_batch + .column(column_index) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(array.len(), expected.len(), "Array length mismatch"); + for (i, expected_value) in expected.iter().enumerate() { + match expected_value { + Some(val) => { + assert!(!array.is_null(i), "Expected non-null at index {i}"); + assert_eq!(array.value(i), *val, "Mismatch at index {i}"); + } + None => { + assert!(array.is_null(i), "Expected null at index {i}"); + } + } + } + } + + fn assert_date32_array(record_batch: &RecordBatch, column_index: usize, expected: Vec) { + let array = record_batch + .column(column_index) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(array.len(), expected.len(), "Array length mismatch"); + for (i, expected_value) in expected.iter().enumerate() { + assert_eq!(array.value(i), *expected_value, "Mismatch at index {i}"); + } + } + + fn assert_time32_millisecond_array( + record_batch: &RecordBatch, + column_index: usize, + expected: Vec, + ) { + let array = record_batch + .column(column_index) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(array.len(), expected.len(), "Array length mismatch"); + for (i, expected_value) in expected.iter().enumerate() { + assert_eq!(array.value(i), *expected_value, "Mismatch at index {i}"); + } + } + + fn assert_time64_microsecond_array( + record_batch: &RecordBatch, + column_index: usize, + expected: Vec, + ) { + let array = record_batch + .column(column_index) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(array.len(), expected.len(), "Array length mismatch"); + for (i, expected_value) in expected.iter().enumerate() { + assert_eq!(array.value(i), *expected_value, "Mismatch at index {i}"); + } + } + + fn assert_time64_nanosecond_array( + record_batch: &RecordBatch, + column_index: usize, + expected: Vec, + ) { + let array = record_batch + .column(column_index) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(array.len(), expected.len(), "Array length mismatch"); + for (i, expected_value) in expected.iter().enumerate() { + assert_eq!(array.value(i), *expected_value, "Mismatch at index {i}"); + } + } + + fn assert_timestamp_millisecond_array( + 
record_batch: &RecordBatch, + column_index: usize, + expected: Vec, + ) { + let array = record_batch + .column(column_index) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(array.len(), expected.len(), "Array length mismatch"); + for (i, expected_value) in expected.iter().enumerate() { + assert_eq!(array.value(i), *expected_value, "Mismatch at index {i}"); + } + } + + fn assert_timestamp_microsecond_array( + record_batch: &RecordBatch, + column_index: usize, + expected: Vec, + ) { + let array = record_batch + .column(column_index) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(array.len(), expected.len(), "Array length mismatch"); + for (i, expected_value) in expected.iter().enumerate() { + assert_eq!(array.value(i), *expected_value, "Mismatch at index {i}"); + } + } + + fn assert_timestamp_nanosecond_array( + record_batch: &RecordBatch, + column_index: usize, + expected: Vec, + ) { + let array = record_batch + .column(column_index) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(array.len(), expected.len(), "Array length mismatch"); + for (i, expected_value) in expected.iter().enumerate() { + assert_eq!(array.value(i), *expected_value, "Mismatch at index {i}"); + } + } + + fn assert_decimal128_array( + record_batch: &RecordBatch, + column_index: usize, + expected: Vec, + ) { + let array = record_batch + .column(column_index) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(array.len(), expected.len(), "Array length mismatch"); + for (i, expected_value) in expected.iter().enumerate() { + assert_eq!(array.value(i), *expected_value, "Mismatch at index {i}"); + } + } + + fn assert_decimal256_array( + record_batch: &RecordBatch, + column_index: usize, + expected: Vec, + ) { + let array = record_batch + .column(column_index) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(array.len(), expected.len(), "Array length mismatch"); + for (i, expected_value) in expected.iter().enumerate() { + assert_eq!(array.value(i), *expected_value, "Mismatch at index {i}"); + } + } + + fn decimal_to_scaled_int128(decimal_str: &str, scale: u8) -> i128 { + let decimal = decimal_str.parse::().unwrap(); + let scale_factor = 10_i128.pow(scale as u32); + (decimal * bigdecimal::BigDecimal::from(scale_factor)) + .to_i128() + .unwrap() + } + + fn decimal_to_scaled_int256(decimal_str: &str, scale: u8) -> i256 { + let decimal = decimal_str.parse::().unwrap(); + let scale_factor = bigdecimal::BigDecimal::from(10_i128.pow(scale as u32)); + let scaled_decimal = decimal * scale_factor; + + let (bigint_value, _) = scaled_decimal.as_bigint_and_exponent(); + let mut bigint_bytes = bigint_value.to_signed_bytes_le(); + + let is_negative = bigint_value.sign() == num_bigint::Sign::Minus; + let fill_byte = if is_negative { 0xFF } else { 0x00 }; + + if bigint_bytes.len() > 32 { + bigint_bytes.truncate(32); + } else { + bigint_bytes.resize(32, fill_byte); + }; + + let mut array = [0u8; 32]; + array.copy_from_slice(&bigint_bytes); + i256::from_le_bytes(array) + } + + fn assert_binary_array(record_batch: &RecordBatch, column_index: usize, expected: Vec<&[u8]>) { + let column = record_batch.column(column_index); + + let array = record_batch + .column(column_index) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(array.len(), expected.len(), "Array length mismatch"); + for (i, expected_value) in expected.iter().enumerate() { + assert_eq!(array.value(i), *expected_value, "Mismatch at index {i}"); + } + } + + fn assert_list_of_strings_array( + record_batch: &RecordBatch, + 
+    fn decimal_to_scaled_int256(decimal_str: &str, scale: u8) -> i256 {
+        let decimal = decimal_str.parse::<bigdecimal::BigDecimal>().unwrap();
+        let scale_factor = bigdecimal::BigDecimal::from(10_i128.pow(scale as u32));
+        let scaled_decimal = decimal * scale_factor;
+
+        let (bigint_value, _) = scaled_decimal.as_bigint_and_exponent();
+        let mut bigint_bytes = bigint_value.to_signed_bytes_le();
+
+        let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
+        let fill_byte = if is_negative { 0xFF } else { 0x00 };
+
+        if bigint_bytes.len() > 32 {
+            bigint_bytes.truncate(32);
+        } else {
+            bigint_bytes.resize(32, fill_byte);
+        }
+
+        let mut array = [0u8; 32];
+        array.copy_from_slice(&bigint_bytes);
+        i256::from_le_bytes(array)
+    }
+
+    fn assert_binary_array(record_batch: &RecordBatch, column_index: usize, expected: Vec<&[u8]>) {
+        let array = record_batch
+            .column(column_index)
+            .as_any()
+            .downcast_ref::<BinaryArray>()
+            .unwrap();
+
+        assert_eq!(array.len(), expected.len(), "Array length mismatch");
+        for (i, expected_value) in expected.iter().enumerate() {
+            assert_eq!(array.value(i), *expected_value, "Mismatch at index {i}");
+        }
+    }
+
+    fn assert_list_of_strings_array(
+        record_batch: &RecordBatch,
+        column_index: usize,
+        expected: Vec<Option<Vec<&str>>>,
+    ) {
+        let array = record_batch
+            .column(column_index)
+            .as_any()
+            .downcast_ref::<ListArray>()
+            .unwrap();
+
+        assert_eq!(array.len(), expected.len(), "Array length mismatch");
+
+        for (i, expected_value) in expected.iter().enumerate() {
+            match expected_value {
+                Some(expected_list) => {
+                    assert!(!array.is_null(i), "Expected non-null at index {i}");
+                    let list_array = array.value(i);
+                    let string_array = list_array.as_any().downcast_ref::<StringArray>().unwrap();
+
+                    assert_eq!(
+                        string_array.len(),
+                        expected_list.len(),
+                        "List length mismatch at index {i}"
+                    );
+
+                    for (j, expected_item) in expected_list.iter().enumerate() {
+                        assert_eq!(
+                            string_array.value(j),
+                            *expected_item,
+                            "Mismatch at index {i} item {j}"
+                        );
+                    }
+                }
+                None => {
+                    assert!(array.is_null(i), "Expected null at index {i}");
+                }
+            }
+        }
+    }
+
+    fn assert_list_of_integers_array(
+        record_batch: &RecordBatch,
+        column_index: usize,
+        expected: Vec<Option<Vec<i32>>>,
+    ) {
+        let array = record_batch
+            .column(column_index)
+            .as_any()
+            .downcast_ref::<ListArray>()
+            .unwrap();
+
+        assert_eq!(array.len(), expected.len(), "Array length mismatch");
+
+        for (i, expected_value) in expected.iter().enumerate() {
+            match expected_value {
+                Some(expected_list) => {
+                    assert!(!array.is_null(i), "Expected non-null at index {i}");
+                    let list_array = array.value(i);
+                    let int_array = list_array.as_any().downcast_ref::<Int32Array>().unwrap();
+
+                    assert_eq!(
+                        int_array.len(),
+                        expected_list.len(),
+                        "List length mismatch at index {i}"
+                    );
+
+                    for (j, expected_item) in expected_list.iter().enumerate() {
+                        assert_eq!(
+                            int_array.value(j),
+                            *expected_item,
+                            "Mismatch at index {i} item {j}"
+                        );
+                    }
+                }
+                None => {
+                    assert!(array.is_null(i), "Expected null at index {i}");
+                }
+            }
+        }
+    }
+
+    fn assert_struct_array(
+        record_batch: &RecordBatch,
+        column_index: usize,
+        expected: Vec<Option<Value>>,
+    ) {
+        let array = record_batch
+            .column(column_index)
+            .as_any()
+            .downcast_ref::<StructArray>()
+            .unwrap();
+
+        assert_eq!(array.len(), expected.len(), "Array length mismatch");
+
+        for (i, expected_value) in expected.iter().enumerate() {
+            match expected_value {
+                Some(expected_struct) => {
+                    assert!(!array.is_null(i), "Expected non-null at index {i}");
+
+                    let mut actual_struct = serde_json::Map::new();
+
+                    for (field_idx, field) in array.fields().iter().enumerate() {
+                        let field_array = array.column(field_idx);
+                        let field_name = field.name();
+
+                        let field_value = match field.data_type() {
+                            DataType::Utf8 => {
+                                let string_array =
+                                    field_array.as_any().downcast_ref::<StringArray>().unwrap();
+                                if string_array.is_null(i) {
+                                    Value::Null
+                                } else {
+                                    Value::String(string_array.value(i).to_string())
+                                }
+                            }
+                            DataType::Int32 => {
+                                let int_array =
+                                    field_array.as_any().downcast_ref::<Int32Array>().unwrap();
+                                if int_array.is_null(i) {
+                                    Value::Null
+                                } else {
+                                    Value::Number(serde_json::Number::from(int_array.value(i)))
+                                }
+                            }
+                            _ => Value::Null,
+                        };
+
+                        actual_struct.insert(field_name.clone(), field_value);
+                    }
+
+                    let actual_json = Value::Object(actual_struct);
+                    assert_eq!(
+                        actual_json, *expected_struct,
+                        "Struct mismatch at index {i}"
+                    );
+                }
+                None => {
+                    assert!(array.is_null(i), "Expected null at index {i}");
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn test_empty_rows_empty_schema() {
+        let rows: Vec<Vec<Value>> = vec![];
+        let schema = create_test_schema(vec![]);
+
+        let result = rows_to_arrow(&rows, schema).unwrap();
+        assert_eq!(result.num_rows(), 0);
+        assert_eq!(result.num_columns(), 0);
+    }
+
+    #[test]
+    fn test_empty_rows_with_columns() {
+        let rows: Vec<Vec<Value>> = vec![];
+        let schema =
create_test_schema(vec![ + ("id", "bigint"), + ("name", "varchar"), + ("active", "boolean"), + ]); + + let result = rows_to_arrow(&rows, schema).unwrap(); + assert_eq!(result.num_rows(), 0); + assert_eq!(result.num_columns(), 3); + + let schema = result.schema(); + assert_eq!(schema.field(0).name(), "id"); + assert_eq!(schema.field(1).name(), "name"); + assert_eq!(schema.field(2).name(), "active"); + } + + #[test] + fn test_basic_data_types() { + let schema = create_test_schema(vec![ + ("bool_col", "boolean"), + ("int8_col", "tinyint"), + ("int16_col", "smallint"), + ("int32_col", "integer"), + ("int64_col", "bigint"), + ("float32_col", "real"), + ("float64_col", "double"), + ("string_col", "varchar"), + ]); + + let rows = vec![ + vec![ + json!(true), + json!(127), + json!(32767), + json!(2147483647), + json!(9223372036854775807i64), + json!(3.14f32), + json!(2.718281828), + json!("hello"), + ], + vec![ + json!(false), + json!(-128), + json!(-32768), + json!(-2147483648), + json!(-9223372036854775808i64), + json!(-1.23f32), + json!(-9.876543210), + json!("world"), + ], + ]; + + let result = rows_to_arrow(&rows, schema).unwrap(); + assert_eq!(result.num_rows(), 2); + assert_eq!(result.num_columns(), 8); + + assert_boolean_array(&result, 0, vec![true, false]); + assert_int8_array(&result, 1, vec![127, -128]); + assert_int16_array(&result, 2, vec![32767, -32768]); + assert_int32_array(&result, 3, vec![2147483647, -2147483648]); + assert_int64_array( + &result, + 4, + vec![9223372036854775807i64, -9223372036854775808i64], + ); + assert_float32_array(&result, 5, vec![3.14f32, -1.23f32]); + assert_float64_array(&result, 6, vec![2.718281828, -9.876543210]); + assert_string_array(&result, 7, vec!["hello", "world"]); + } + + #[test] + fn test_null_values() { + let schema = create_test_schema(vec![ + ("nullable_int", "integer"), + ("nullable_string", "varchar"), + ]); + + let rows = vec![ + vec![json!(42), json!("test")], + vec![Value::Null, Value::Null], + vec![json!(100), json!("another")], + ]; + + let result = rows_to_arrow(&rows, schema).unwrap(); + assert_eq!(result.num_rows(), 3); + + assert_int32_array_with_nulls(&result, 0, vec![Some(42), None, Some(100)]); + assert_string_array_with_nulls(&result, 1, vec![Some("test"), None, Some("another")]); + } + + #[test] + fn test_time() { + let schema = create_test_schema(vec![ + ("col0", "time(0)"), + ("col1", "time(1)"), + ("col2", "time(2)"), + ("col3", "time(3)"), + ("col4", "time(4)"), + ("col5", "time(5)"), + ("col6", "time(6)"), + ("col7", "time(7)"), + ("col8", "time(8)"), + ("col9", "time(9)"), + ]); + + let row = vec![ + json!("14:30:45"), + json!("14:30:45.1"), + json!("14:30:45.12"), + json!("14:30:45.123"), + json!("14:30:45.1234"), + json!("14:30:45.12345"), + json!("14:30:45.123456"), + json!("14:30:45.1234567"), + json!("14:30:45.12345678"), + json!("14:30:45.123456789"), + ]; + + let result = rows_to_arrow(&[row], schema).unwrap(); + + let t = |s| NaiveTime::parse_from_str(s, "%H:%M:%S%.f").unwrap(); + + assert_time32_millisecond_array( + &result, + 0, + vec![t("14:30:45").num_seconds_from_midnight() as i32 * 1000], + ); + assert_time32_millisecond_array( + &result, + 1, + vec![ + t("14:30:45.1").num_seconds_from_midnight() as i32 * 1000 + + (t("14:30:45.1").nanosecond() / 1_000_000) as i32, + ], + ); + assert_time32_millisecond_array( + &result, + 2, + vec![ + t("14:30:45.12").num_seconds_from_midnight() as i32 * 1000 + + (t("14:30:45.12").nanosecond() / 1_000_000) as i32, + ], + ); + assert_time32_millisecond_array( + &result, + 3, 
+ vec![ + t("14:30:45.123").num_seconds_from_midnight() as i32 * 1000 + + (t("14:30:45.123").nanosecond() / 1_000_000) as i32, + ], + ); + + // Microsecond precision + assert_time64_microsecond_array( + &result, + 4, + vec![ + t("14:30:45.1234").num_seconds_from_midnight() as i64 * 1_000_000 + + (t("14:30:45.1234").nanosecond() / 1_000) as i64, + ], + ); + assert_time64_microsecond_array( + &result, + 5, + vec![ + t("14:30:45.12345").num_seconds_from_midnight() as i64 * 1_000_000 + + (t("14:30:45.12345").nanosecond() / 1_000) as i64, + ], + ); + assert_time64_microsecond_array( + &result, + 6, + vec![ + t("14:30:45.123456").num_seconds_from_midnight() as i64 * 1_000_000 + + (t("14:30:45.123456").nanosecond() / 1_000) as i64, + ], + ); + + // Nanosecond precision + assert_time64_nanosecond_array( + &result, + 7, + vec![ + t("14:30:45.1234567").num_seconds_from_midnight() as i64 * 1_000_000_000 + + t("14:30:45.1234567").nanosecond() as i64, + ], + ); + assert_time64_nanosecond_array( + &result, + 8, + vec![ + t("14:30:45.12345678").num_seconds_from_midnight() as i64 * 1_000_000_000 + + t("14:30:45.12345678").nanosecond() as i64, + ], + ); + assert_time64_nanosecond_array( + &result, + 9, + vec![ + t("14:30:45.123456789").num_seconds_from_midnight() as i64 * 1_000_000_000 + + t("14:30:45.123456789").nanosecond() as i64, + ], + ); + } + + #[test] + fn test_timestamp() { + let schema = create_test_schema(vec![ + ("col0", "timestamp(0)"), + ("col1", "timestamp(1)"), + ("col2", "timestamp(2)"), + ("col3", "timestamp(3)"), + ("col4", "timestamp(4)"), + ("col5", "timestamp(5)"), + ("col6", "timestamp(6)"), + ("col7", "timestamp(7)"), + ("col8", "timestamp(8)"), + ("col9", "timestamp(9)"), + ]); + + let row = vec![ + json!("2023-12-25 14:30:45"), + json!("2023-12-25 14:30:45.1"), + json!("2023-12-25 14:30:45.12"), + json!("2023-12-25 14:30:45.123"), + json!("2023-12-25 14:30:45.1234"), + json!("2023-12-25 14:30:45.12345"), + json!("2023-12-25 14:30:45.123456"), + json!("2023-12-25 14:30:45.1234567"), + json!("2023-12-25 14:30:45.12345678"), + json!("2023-12-25 14:30:45.123456789"), + ]; + + let result = rows_to_arrow(&[row], schema).unwrap(); + + let ts = |s: &str| { + chrono::DateTime::parse_from_rfc3339(s) + .unwrap() + .timestamp_nanos_opt() + .unwrap() + }; + + assert_timestamp_millisecond_array( + &result, + 0, + vec![ts("2023-12-25T14:30:45Z") / 1_000_000], + ); + assert_timestamp_millisecond_array( + &result, + 1, + vec![ts("2023-12-25T14:30:45.1Z") / 1_000_000], + ); + assert_timestamp_millisecond_array( + &result, + 2, + vec![ts("2023-12-25T14:30:45.12Z") / 1_000_000], + ); + assert_timestamp_millisecond_array( + &result, + 3, + vec![ts("2023-12-25T14:30:45.123Z") / 1_000_000], + ); + assert_timestamp_microsecond_array( + &result, + 4, + vec![ts("2023-12-25T14:30:45.1234Z") / 1_000], + ); + assert_timestamp_microsecond_array( + &result, + 5, + vec![ts("2023-12-25T14:30:45.12345Z") / 1_000], + ); + assert_timestamp_microsecond_array( + &result, + 6, + vec![ts("2023-12-25T14:30:45.123456Z") / 1_000], + ); + assert_timestamp_nanosecond_array(&result, 7, vec![ts("2023-12-25T14:30:45.1234567Z")]); + assert_timestamp_nanosecond_array(&result, 8, vec![ts("2023-12-25T14:30:45.12345678Z")]); + assert_timestamp_nanosecond_array(&result, 9, vec![ts("2023-12-25T14:30:45.123456789Z")]); + } + + #[test] + fn test_timestamp_with_timezone() { + let schema = create_test_schema(vec![ + ("col0", "timestamp(0) with time zone"), + ("col1", "timestamp(1) with time zone"), + ("col2", "timestamp(2) with time 
zone"), + ("col3", "timestamp(3) with time zone"), + ("col4", "timestamp(4) with time zone"), + ("col5", "timestamp(5) with time zone"), + ("col6", "timestamp(6) with time zone"), + ("col7", "timestamp(7) with time zone"), + ("col8", "timestamp(8) with time zone"), + ("col9", "timestamp(9) with time zone"), + ]); + + let row = vec![ + json!("2023-12-25 14:30:45 UTC"), + json!("2023-12-25 14:30:45.1 UTC"), + json!("2023-12-25 15:30:45.12 +01:00"), + json!("2023-12-25 06:30:45.123 America/Los_Angeles"), + json!("2023-12-25 14:30:45.1234 UTC"), + json!("2023-12-25 16:30:45.12345 +02:00"), + json!("2023-12-25 09:30:45.123456 America/New_York"), + json!("2023-12-25 14:30:45.1234567 UTC"), + json!("2023-12-25 11:30:45.12345678 -03:00"), + json!("2023-12-25 15:30:45.123456789 Europe/Amsterdam"), + ]; + + let result = rows_to_arrow(&[row], schema).unwrap(); + + let ts_millis = |s: &str| { + chrono::DateTime::parse_from_rfc3339(s) + .unwrap() + .timestamp_millis() + }; + + let ts_micros = |s: &str| { + chrono::DateTime::parse_from_rfc3339(s) + .unwrap() + .timestamp_nanos_opt() + .unwrap() + / 1_000 + }; + + let ts_nanos = |s: &str| { + chrono::DateTime::parse_from_rfc3339(s) + .unwrap() + .timestamp_nanos_opt() + .unwrap() + }; + + assert_timestamp_millisecond_array(&result, 0, vec![ts_millis("2023-12-25T14:30:45Z")]); + assert_timestamp_millisecond_array(&result, 1, vec![ts_millis("2023-12-25T14:30:45.1Z")]); + assert_timestamp_millisecond_array(&result, 2, vec![ts_millis("2023-12-25T14:30:45.12Z")]); + assert_timestamp_millisecond_array(&result, 3, vec![ts_millis("2023-12-25T14:30:45.123Z")]); + assert_timestamp_microsecond_array( + &result, + 4, + vec![ts_micros("2023-12-25T14:30:45.1234Z")], + ); + assert_timestamp_microsecond_array( + &result, + 5, + vec![ts_micros("2023-12-25T14:30:45.12345Z")], + ); + assert_timestamp_microsecond_array( + &result, + 6, + vec![ts_micros("2023-12-25T14:30:45.123456Z")], + ); + assert_timestamp_nanosecond_array( + &result, + 7, + vec![ts_nanos("2023-12-25T14:30:45.1234567Z")], + ); + assert_timestamp_nanosecond_array( + &result, + 8, + vec![ts_nanos("2023-12-25T14:30:45.12345678Z")], + ); + assert_timestamp_nanosecond_array( + &result, + 9, + vec![ts_nanos("2023-12-25T14:30:45.123456789Z")], + ); + } + + #[test] + fn test_date() { + let schema = create_test_schema(vec![("date_col", "date")]); + + let rows = vec![vec![json!("2023-12-25")], vec![json!("1970-01-01")]]; + + let result = rows_to_arrow(&rows, schema).unwrap(); + assert_eq!(result.num_rows(), 2); + assert_eq!(result.num_columns(), 1); + + let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + let date1 = NaiveDate::from_ymd_opt(2023, 12, 25) + .unwrap() + .signed_duration_since(epoch) + .num_days() as i32; + let date2 = 0; + + assert_date32_array(&result, 0, vec![date1, date2]); + } + + #[test] + fn test_invalid_date_format() { + let schema = create_test_schema(vec![("date_col", "date")]); + let rows = vec![vec![json!("invalid-date")]]; + + let result = rows_to_arrow(&rows, schema); + assert!(result.is_err()); + } + + #[test] + fn test_decimal_types() { + let schema = create_test_schema(vec![ + ("decimal128_col", "decimal(10,2)"), + ("decimal256_col", "decimal(42,4)"), + ]); + + let rows = vec![ + vec![json!("123.45"), json!("999999999999.9999")], + vec![json!("0.00"), json!("0.0000")], + ]; + + let result = rows_to_arrow(&rows, schema).unwrap(); + assert_eq!(result.num_rows(), 2); + assert_eq!(result.num_columns(), 2); + + let decimal128_1 = decimal_to_scaled_int128("123.45", 2); + let 
decimal128_2 = decimal_to_scaled_int128("0.00", 2); + + let decimal256_1 = decimal_to_scaled_int256("999999999999.9999", 4); + let decimal256_2 = decimal_to_scaled_int256("0.0000", 4); + + assert_decimal128_array(&result, 0, vec![decimal128_1, decimal128_2]); + assert_decimal256_array(&result, 1, vec![decimal256_1, decimal256_2]); + } + + #[test] + fn test_invalid_decimal_format() { + let schema = create_test_schema(vec![("decimal_col", "decimal(10,2)")]); + let rows = vec![vec![json!("not-a-number")]]; + + let result = rows_to_arrow(&rows, schema); + assert!(result.is_err()); + } + + #[test] + fn test_binary_data() { + let schema = create_test_schema(vec![("binary_col", "varbinary")]); + + let base64_data = BASE64.encode(b"hello world"); + let rows = vec![vec![json!(base64_data)], vec![json!("plain text")]]; + + let result = rows_to_arrow(&rows, schema).unwrap(); + + assert_binary_array(&result, 0, vec![b"hello world", b"plain text"]); + } + + #[test] + fn test_list_of_strings() { + let schema = create_test_schema(vec![("list_col", "array(varchar)")]); + + let rows = vec![ + vec![json!(["item1", "item2", "item3"])], + vec![json!(["single"])], + vec![Value::Null], + ]; + + let result = rows_to_arrow(&rows, schema).unwrap(); + + assert_list_of_strings_array( + &result, + 0, + vec![ + Some(vec!["item1", "item2", "item3"]), + Some(vec!["single"]), + None, + ], + ); + } + + #[test] + fn test_list_of_integers() { + let schema = create_test_schema(vec![("list_col", "array(integer)")]); + + let rows = vec![ + vec![json!([100, 200, 300])], + vec![json!([-10000000])], + vec![Value::Null], + ]; + + let result = rows_to_arrow(&rows, schema).unwrap(); + + assert_list_of_integers_array( + &result, + 0, + vec![Some(vec![100, 200, 300]), Some(vec![-10000000]), None], + ); + } + + #[test] + fn test_list_of_lists() { + let schema = create_test_schema(vec![("list_col", "array(array(integer))")]); + + let rows = vec![ + vec![json!([[1, 2, 3], [4, 5], [6]])], + vec![json!([[10, 20]])], + vec![json!([])], + vec![Value::Null], + ]; + + let result = rows_to_arrow(&rows, schema).unwrap(); + + assert_list_of_strings_array( + &result, + 0, + vec![ + Some(vec!["[1,2,3]", "[4,5]", "[6]"]), + Some(vec!["[10,20]"]), + Some(vec![]), + None, + ], + ); + } + + #[test] + fn test_list_of_structs() { + let schema = + create_test_schema(vec![("list_col", "array(row(name varchar, age integer))")]); + + let rows = vec![ + vec![json!([{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}])], + vec![json!([{"name": "Charlie", "age": 35}])], + vec![json!([])], + vec![Value::Null], + ]; + + let result = rows_to_arrow(&rows, schema).unwrap(); + + assert_list_of_strings_array( + &result, + 0, + vec![ + Some(vec![ + r#"{"age":30,"name":"Alice"}"#, + r#"{"age":25,"name":"Bob"}"#, + ]), + Some(vec![r#"{"age":35,"name":"Charlie"}"#]), + Some(vec![]), + None, + ], + ); + } + + #[test] + fn test_struct() { + let schema = create_test_schema(vec![("struct_col", "row(name varchar, age integer)")]); + + let rows = vec![ + vec![json!({"name": "Alice", "age": 30})], + vec![json!(["Bob", 25])], + vec![Value::Null], + ]; + + let result = rows_to_arrow(&rows, schema).unwrap(); + + assert_struct_array( + &result, + 0, + vec![ + Some(json!({"name": "Alice", "age": 30})), + Some(json!({"name": "Bob", "age": 25})), + None, + ], + ); + } + + #[test] + fn test_struct_with_list() { + let schema = create_test_schema(vec![( + "struct_col", + "row(name varchar, tags array(varchar))", + )]); + + let rows = vec![ + vec![json!({"name": "Alice", "tags": 
["tag1", "tag2", "tag3"]})], + vec![json!({"name": "Bob", "tags": ["single_tag"]})], + vec![json!({"name": "Charlie", "tags": []})], + vec![Value::Null], + ]; + + let result = rows_to_arrow(&rows, schema).unwrap(); + + assert_struct_array( + &result, + 0, + vec![ + Some(json!({"name": "Alice", "tags": "[\"tag1\",\"tag2\",\"tag3\"]"})), + Some(json!({"name": "Bob", "tags": "[\"single_tag\"]"})), + Some(json!({"name": "Charlie", "tags": "[]"})), + None, + ], + ); + } + + #[test] + fn test_struct_with_nested_struct() { + let schema = create_test_schema(vec![( + "struct_col", + "row(name varchar, address row(street varchar, city varchar))", + )]); + + let rows = vec![ + vec![ + json!({"name": "Alice", "address": {"street": "123 Main St", "city": "New York"}}), + ], + vec![json!({"name": "Bob", "address": {"street": "456 Oak Ave", "city": "Boston"}})], + vec![json!({"name": "Charlie", "address": {}})], // empty nested struct + vec![Value::Null], + ]; + + let result = rows_to_arrow(&rows, schema).unwrap(); + + assert_struct_array( + &result, + 0, + vec![ + Some( + json!({"name": "Alice", "address": "{\"city\":\"New York\",\"street\":\"123 Main St\"}"}), + ), + Some( + json!({"name": "Bob", "address": "{\"city\":\"Boston\",\"street\":\"456 Oak Ave\"}"}), + ), + Some(json!({"name": "Charlie", "address": "{}"})), + None, + ], + ); + } + + #[test] + fn test_integer_overflow_handling() { + let schema = create_test_schema(vec![ + ("int8_col", "tinyint"), + ("int16_col", "smallint"), + ("int32_col", "integer"), + ]); + + let rows = vec![vec![ + json!(1000), + json!(100000), + json!(9223372036854775807i64), + ]]; + + let result = rows_to_arrow(&rows, schema).unwrap(); + assert_eq!(result.num_rows(), 1); + + let int8_array = result + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(int8_array.is_null(0)); + + let int16_array = result + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(int16_array.is_null(0)); + + let int32_array = result + .column(2) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(int32_array.is_null(0)); + } + + #[test] + fn test_type_coercion_fallback() { + let schema = create_test_schema(vec![ + ("bool_col", "boolean"), + ("int_col", "integer"), + ("string_col", "varchar"), + ]); + + let rows = vec![vec![ + json!("not a boolean"), + json!("not a number"), + json!(42), + ]]; + + let result = rows_to_arrow(&rows, schema).unwrap(); + assert_eq!(result.num_rows(), 1); + + let bool_array = result + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(bool_array.is_null(0)); + + let int_array = result + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(int_array.is_null(0)); + + let string_array = result + .column(2) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(!string_array.is_null(0)); + assert_eq!(string_array.value(0), "42"); + } + + #[test] + fn test_large_dataset() { + let schema = create_test_schema(vec![("id", "bigint"), ("value", "varchar")]); + + let mut rows = Vec::new(); + for i in 0..1000 { + rows.push(vec![json!(i), json!(format!("value_{}", i))]); + } + + let result = rows_to_arrow(&rows, schema).unwrap(); + assert_eq!(result.num_rows(), 1000); + assert_eq!(result.num_columns(), 2); + + let id_array = result + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(id_array.value(0), 0); + assert_eq!(id_array.value(999), 999); + + let value_array = result + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(value_array.value(0), "value_0"); + 
+    #[test]
+    fn test_type_coercion_fallback() {
+        let schema = create_test_schema(vec![
+            ("bool_col", "boolean"),
+            ("int_col", "integer"),
+            ("string_col", "varchar"),
+        ]);
+
+        let rows = vec![vec![
+            json!("not a boolean"),
+            json!("not a number"),
+            json!(42),
+        ]];
+
+        let result = rows_to_arrow(&rows, schema).unwrap();
+        assert_eq!(result.num_rows(), 1);
+
+        let bool_array = result
+            .column(0)
+            .as_any()
+            .downcast_ref::<BooleanArray>()
+            .unwrap();
+        assert!(bool_array.is_null(0));
+
+        let int_array = result
+            .column(1)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+        assert!(int_array.is_null(0));
+
+        let string_array = result
+            .column(2)
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .unwrap();
+        assert!(!string_array.is_null(0));
+        assert_eq!(string_array.value(0), "42");
+    }
+
+    #[test]
+    fn test_large_dataset() {
+        let schema = create_test_schema(vec![("id", "bigint"), ("value", "varchar")]);
+
+        let mut rows = Vec::new();
+        for i in 0..1000 {
+            rows.push(vec![json!(i), json!(format!("value_{}", i))]);
+        }
+
+        let result = rows_to_arrow(&rows, schema).unwrap();
+        assert_eq!(result.num_rows(), 1000);
+        assert_eq!(result.num_columns(), 2);
+
+        let id_array = result
+            .column(0)
+            .as_any()
+            .downcast_ref::<Int64Array>()
+            .unwrap();
+        assert_eq!(id_array.value(0), 0);
+        assert_eq!(id_array.value(999), 999);
+
+        let value_array = result
+            .column(1)
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .unwrap();
+        assert_eq!(value_array.value(0), "value_0");
+        assert_eq!(value_array.value(999), "value_999");
+    }
+
+    #[test]
+    fn test_mixed_null_and_valid_data() {
+        let schema =
+            create_test_schema(vec![("mixed_int", "integer"), ("mixed_string", "varchar")]);
+
+        let rows = vec![
+            vec![json!(1), json!("first")],
+            vec![Value::Null, json!("second")],
+            vec![json!(3), Value::Null],
+            vec![Value::Null, Value::Null],
+            vec![json!(5), json!("fifth")],
+        ];
+
+        let result = rows_to_arrow(&rows, schema).unwrap();
+        assert_eq!(result.num_rows(), 5);
+
+        let int_array = result
+            .column(0)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+
+        assert!(!int_array.is_null(0));
+        assert!(int_array.is_null(1));
+        assert!(!int_array.is_null(2));
+        assert!(int_array.is_null(3));
+        assert!(!int_array.is_null(4));
+
+        assert_eq!(int_array.value(0), 1);
+        assert_eq!(int_array.value(2), 3);
+        assert_eq!(int_array.value(4), 5);
+    }
+
+    #[test]
+    fn test_null_builder_type() {
+        let schema = create_test_schema(vec![("null_col", "null")]);
+        let rows = vec![vec![Value::Null], vec![Value::Null]];
+
+        let result = rows_to_arrow(&rows, schema).unwrap();
+        assert_eq!(result.num_rows(), 2);
+        assert_eq!(result.num_columns(), 1);
+
+        let null_array = result
+            .column(0)
+            .as_any()
+            .downcast_ref::<NullArray>()
+            .unwrap();
+        assert_eq!(null_array.len(), 2);
+    }
+}
diff --git a/src/sql/arrow_sql_gen/trino/schema.rs b/src/sql/arrow_sql_gen/trino/schema.rs
new file mode 100644
index 00000000..03cb9cdc
--- /dev/null
+++ b/src/sql/arrow_sql_gen/trino/schema.rs
@@ -0,0 +1,631 @@
+use super::{Error, RegexSnafu, Result};
+use arrow::datatypes::DataType;
+use arrow_schema::{Field, Fields, TimeUnit};
+use regex::Regex;
+use snafu::ResultExt;
+use std::sync::Arc;
+
+pub(crate) fn trino_data_type_to_arrow_type(
+    trino_type: &str,
+    tz: Option<&str>,
+) -> Result<DataType> {
+    let normalized_type = trino_type.to_lowercase();
+
+    if normalized_type.starts_with("time(") {
+        let time_unit = time_unit_from_precision(extract_precision(&normalized_type, "time")?);
+
+        return match time_unit {
+            TimeUnit::Second | TimeUnit::Millisecond => Ok(DataType::Time32(TimeUnit::Millisecond)),
+            time_unit => Ok(DataType::Time64(time_unit)),
+        };
+    }
+
+    if normalized_type.contains("with time zone") && normalized_type.starts_with("timestamp(") {
+        let time_unit = time_unit_from_precision(extract_precision(&normalized_type, "timestamp")?);
+        return Ok(DataType::Timestamp(
+            time_unit,
+            Some(Arc::from(tz.unwrap_or("UTC"))),
+        ));
+    }
+
+    if normalized_type.starts_with("timestamp(") {
+        let time_unit = time_unit_from_precision(extract_precision(&normalized_type, "timestamp")?);
+        return Ok(DataType::Timestamp(time_unit, None));
+    }
+
+    match normalized_type.as_str() {
+        "null" => Ok(DataType::Null),
+        "boolean" => Ok(DataType::Boolean),
+        "tinyint" => Ok(DataType::Int8),
+        "smallint" => Ok(DataType::Int16),
+        "integer" => Ok(DataType::Int32),
+        "bigint" => Ok(DataType::Int64),
+        "real" => Ok(DataType::Float32),
+        "double" => Ok(DataType::Float64),
+        "varchar" | "char" => Ok(DataType::Utf8),
+        "varbinary" => Ok(DataType::Binary),
+        "date" => Ok(DataType::Date32),
+        _ if normalized_type.starts_with("decimal") || normalized_type.starts_with("numeric") => {
+            parse_decimal_type(&normalized_type)
+        }
+        _ if normalized_type.starts_with("varchar") => Ok(DataType::Utf8),
+        _ if normalized_type.starts_with("char") => Ok(DataType::Utf8),
+        _ if normalized_type.starts_with("varbinary") => Ok(DataType::Binary),
+        _ if normalized_type.starts_with("array") => parse_array_type(&normalized_type, tz),
+        _ if normalized_type.starts_with("row") => parse_row_type(&normalized_type, tz),
+        _ => Err(Error::UnsupportedTrinoType {
+            trino_type: trino_type.to_string(),
+        }),
+    }
+}
+
+pub fn extract_precision(s: &str, prefix: &str) -> Result<u32> {
+    let pattern = format!(r"^{}(?:\((\d+)\))?", regex::escape(prefix));
+    let re = Regex::new(&pattern).context(RegexSnafu)?;
+    let caps = re.captures(s).ok_or_else(|| Error::InvalidPrecision {
+        trino_type: s.to_string(),
+    })?;
+
+    let precision = match caps.get(1) {
+        Some(m) => m
+            .as_str()
+            .parse::<u32>()
+            .map_err(|_| Error::InvalidPrecision {
+                trino_type: s.to_string(),
+            })?,
+        None => {
+            return Err(Error::InvalidPrecision {
+                trino_type: s.to_string(),
+            })
+        }
+    };
+
+    Ok(precision)
+}
+
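+// Trino declares fractional-second precision as 0..=9 digits; map it to the
+// coarsest Arrow time unit that still represents the value exactly:
+// 0-3 digits fit in milliseconds, 4-6 in microseconds, and 7-9 in nanoseconds.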
+fn time_unit_from_precision(p: u32) -> TimeUnit {
+    match p {
+        0..=3 => TimeUnit::Millisecond,
+        4..=6 => TimeUnit::Microsecond,
+        _ => TimeUnit::Nanosecond,
+    }
+}
+
+fn parse_decimal_type(type_str: &str) -> Result<DataType> {
+    if let Some(start) = type_str.find('(') {
+        if let Some(end) = type_str.find(')') {
+            let params = &type_str[start + 1..end];
+            let parts: Vec<&str> = params.split(',').collect();
+
+            let precision = parts[0].trim().parse::<u8>().unwrap_or(38);
+            let scale = if parts.len() > 1 {
+                parts[1].trim().parse::<i8>().unwrap_or(0)
+            } else {
+                0
+            };
+
+            if precision > 38 {
+                Ok(DataType::Decimal256(precision, scale))
+            } else {
+                Ok(DataType::Decimal128(precision, scale))
+            }
+        } else {
+            Ok(DataType::Decimal128(18, 6))
+        }
+    } else {
+        Ok(DataType::Decimal128(18, 6))
+    }
+}
+
+fn parse_array_type(type_str: &str, tz: Option<&str>) -> Result<DataType> {
+    if let Some(start) = type_str.find('(') {
+        if let Some(end) = type_str.rfind(')') {
+            let element_type_str = &type_str[start + 1..end];
+            return match trino_data_type_to_arrow_type(element_type_str, tz)? {
+                DataType::Struct(_) | DataType::List(_) | DataType::Map(_, _) => Ok(
+                    DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
+                ),
+                inner_arrow_type => Ok(DataType::List(Arc::new(Field::new(
+                    "item",
+                    inner_arrow_type,
+                    true,
+                )))),
+            };
+        }
+    }
+
+    Err(Error::UnsupportedTrinoType {
+        trino_type: type_str.to_string(),
+    })
+}
+
+fn parse_row_type(type_str: &str, tz: Option<&str>) -> Result<DataType> {
+    if let Some(start) = type_str.find('(') {
+        if let Some(end) = type_str.rfind(')') {
+            let inner = &type_str[start + 1..end];
+            let mut fields = Vec::new();
+
+            let field_definitions = split_respecting_parentheses(inner, ',');
+
+            for field_def in field_definitions {
+                let field_def = field_def.trim();
+                if let Some(space_pos) = field_def.find(' ') {
+                    let field_name = field_def[..space_pos].trim();
+                    let field_type = field_def[space_pos + 1..].trim();
+                    let arrow_type = match trino_data_type_to_arrow_type(field_type, tz)?
{ + DataType::Struct(_) | DataType::List(_) | DataType::Map(_, _) => { + DataType::Utf8 + } + inner_arrow_type => inner_arrow_type, + }; + fields.push(Field::new(field_name, arrow_type, true)); + } + } + + return Ok(DataType::Struct(Fields::from(fields))); + } + } + Err(Error::UnsupportedTrinoType { + trino_type: type_str.to_string(), + }) +} + +fn split_respecting_parentheses(s: &str, delimiter: char) -> Vec { + let mut result = Vec::new(); + let mut current = String::new(); + let mut paren_depth = 0; + + for ch in s.chars() { + match ch { + '(' => { + paren_depth += 1; + current.push(ch); + } + ')' => { + paren_depth -= 1; + current.push(ch); + } + ch if ch == delimiter && paren_depth == 0 => { + result.push(current.trim().to_string()); + current.clear(); + } + _ => { + current.push(ch); + } + } + } + + if !current.is_empty() { + result.push(current.trim().to_string()); + } + + result +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::{DataType, Field, Fields, TimeUnit}; + use std::sync::Arc; + + #[test] + fn test_basic_types() { + assert_eq!( + trino_data_type_to_arrow_type("null", None).unwrap(), + DataType::Null + ); + assert_eq!( + trino_data_type_to_arrow_type("boolean", None).unwrap(), + DataType::Boolean + ); + assert_eq!( + trino_data_type_to_arrow_type("tinyint", None).unwrap(), + DataType::Int8 + ); + assert_eq!( + trino_data_type_to_arrow_type("smallint", None).unwrap(), + DataType::Int16 + ); + assert_eq!( + trino_data_type_to_arrow_type("integer", None).unwrap(), + DataType::Int32 + ); + assert_eq!( + trino_data_type_to_arrow_type("bigint", None).unwrap(), + DataType::Int64 + ); + assert_eq!( + trino_data_type_to_arrow_type("real", None).unwrap(), + DataType::Float32 + ); + assert_eq!( + trino_data_type_to_arrow_type("double", None).unwrap(), + DataType::Float64 + ); + } + + #[test] + fn test_string_types() { + assert_eq!( + trino_data_type_to_arrow_type("varchar", None).unwrap(), + DataType::Utf8 + ); + assert_eq!( + trino_data_type_to_arrow_type("char", None).unwrap(), + DataType::Utf8 + ); + assert_eq!( + trino_data_type_to_arrow_type("varbinary", None).unwrap(), + DataType::Binary + ); + + assert_eq!( + trino_data_type_to_arrow_type("varchar(255)", None).unwrap(), + DataType::Utf8 + ); + assert_eq!( + trino_data_type_to_arrow_type("char(10)", None).unwrap(), + DataType::Utf8 + ); + assert_eq!( + trino_data_type_to_arrow_type("varbinary(1000)", None).unwrap(), + DataType::Binary + ); + } + + #[test] + fn test_date() { + assert_eq!( + trino_data_type_to_arrow_type("date", None).unwrap(), + DataType::Date32 + ); + } + + #[test] + fn test_time() { + assert_eq!( + trino_data_type_to_arrow_type("time(0)", None).unwrap(), + DataType::Time32(TimeUnit::Millisecond) + ); + assert_eq!( + trino_data_type_to_arrow_type("time(1)", None).unwrap(), + DataType::Time32(TimeUnit::Millisecond) + ); + assert_eq!( + trino_data_type_to_arrow_type("time(2)", None).unwrap(), + DataType::Time32(TimeUnit::Millisecond) + ); + assert_eq!( + trino_data_type_to_arrow_type("time(3)", None).unwrap(), + DataType::Time32(TimeUnit::Millisecond) + ); + + assert_eq!( + trino_data_type_to_arrow_type("time(4)", None).unwrap(), + DataType::Time64(TimeUnit::Microsecond) + ); + assert_eq!( + trino_data_type_to_arrow_type("time(5)", None).unwrap(), + DataType::Time64(TimeUnit::Microsecond) + ); + assert_eq!( + trino_data_type_to_arrow_type("time(6)", None).unwrap(), + DataType::Time64(TimeUnit::Microsecond) + ); + + assert_eq!( + trino_data_type_to_arrow_type("time(7)", None).unwrap(), + 
DataType::Time64(TimeUnit::Nanosecond) + ); + assert_eq!( + trino_data_type_to_arrow_type("time(8)", None).unwrap(), + DataType::Time64(TimeUnit::Nanosecond) + ); + assert_eq!( + trino_data_type_to_arrow_type("time(9)", None).unwrap(), + DataType::Time64(TimeUnit::Nanosecond) + ); + } + + #[test] + fn test_timestamp() { + assert_eq!( + trino_data_type_to_arrow_type("timestamp(0)", None).unwrap(), + DataType::Timestamp(TimeUnit::Millisecond, None) + ); + assert_eq!( + trino_data_type_to_arrow_type("timestamp(1)", None).unwrap(), + DataType::Timestamp(TimeUnit::Millisecond, None) + ); + assert_eq!( + trino_data_type_to_arrow_type("timestamp(2)", None).unwrap(), + DataType::Timestamp(TimeUnit::Millisecond, None) + ); + assert_eq!( + trino_data_type_to_arrow_type("timestamp(3)", None).unwrap(), + DataType::Timestamp(TimeUnit::Millisecond, None) + ); + + assert_eq!( + trino_data_type_to_arrow_type("timestamp(4)", None).unwrap(), + DataType::Timestamp(TimeUnit::Microsecond, None) + ); + assert_eq!( + trino_data_type_to_arrow_type("timestamp(5)", None).unwrap(), + DataType::Timestamp(TimeUnit::Microsecond, None) + ); + assert_eq!( + trino_data_type_to_arrow_type("timestamp(6)", None).unwrap(), + DataType::Timestamp(TimeUnit::Microsecond, None) + ); + + assert_eq!( + trino_data_type_to_arrow_type("timestamp(7)", None).unwrap(), + DataType::Timestamp(TimeUnit::Nanosecond, None) + ); + assert_eq!( + trino_data_type_to_arrow_type("timestamp(8)", None).unwrap(), + DataType::Timestamp(TimeUnit::Nanosecond, None) + ); + assert_eq!( + trino_data_type_to_arrow_type("timestamp(9)", None).unwrap(), + DataType::Timestamp(TimeUnit::Nanosecond, None) + ); + } + + #[test] + fn test_timestamp_with_timezone() { + assert_eq!( + trino_data_type_to_arrow_type("timestamp(0) with time zone", None).unwrap(), + DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())) + ); + assert_eq!( + trino_data_type_to_arrow_type("timestamp(1) with time zone", None).unwrap(), + DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())) + ); + assert_eq!( + trino_data_type_to_arrow_type("timestamp(2) with time zone", None).unwrap(), + DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())) + ); + assert_eq!( + trino_data_type_to_arrow_type("timestamp(3) with time zone", None).unwrap(), + DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())) + ); + + assert_eq!( + trino_data_type_to_arrow_type("timestamp(4) with time zone", None).unwrap(), + DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())) + ); + assert_eq!( + trino_data_type_to_arrow_type("timestamp(5) with time zone", None).unwrap(), + DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())) + ); + assert_eq!( + trino_data_type_to_arrow_type("timestamp(6) with time zone", None).unwrap(), + DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())) + ); + + assert_eq!( + trino_data_type_to_arrow_type("timestamp(7) with time zone", None).unwrap(), + DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())) + ); + assert_eq!( + trino_data_type_to_arrow_type("timestamp(8) with time zone", None).unwrap(), + DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())) + ); + assert_eq!( + trino_data_type_to_arrow_type("timestamp(9) with time zone", None).unwrap(), + DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())) + ); + } + + #[test] + fn test_case_insensitive() { + assert_eq!( + trino_data_type_to_arrow_type("BOOLEAN", None).unwrap(), + DataType::Boolean + ); + assert_eq!( + trino_data_type_to_arrow_type("Boolean", 
None).unwrap(),
+            DataType::Boolean
+        );
+        assert_eq!(
+            trino_data_type_to_arrow_type("VARCHAR", None).unwrap(),
+            DataType::Utf8
+        );
+    }
+
+    #[test]
+    fn test_decimal_types() {
+        assert_eq!(
+            trino_data_type_to_arrow_type("decimal", None).unwrap(),
+            DataType::Decimal128(18, 6)
+        );
+
+        assert_eq!(
+            trino_data_type_to_arrow_type("decimal(10)", None).unwrap(),
+            DataType::Decimal128(10, 0)
+        );
+
+        assert_eq!(
+            trino_data_type_to_arrow_type("decimal(10,2)", None).unwrap(),
+            DataType::Decimal128(10, 2)
+        );
+
+        assert_eq!(
+            trino_data_type_to_arrow_type("decimal(50,10)", None).unwrap(),
+            DataType::Decimal256(50, 10)
+        );
+
+        assert_eq!(
+            trino_data_type_to_arrow_type("numeric(10,2)", None).unwrap(),
+            DataType::Decimal128(10, 2)
+        );
+
+        assert_eq!(
+            trino_data_type_to_arrow_type("decimal(38,0)", None).unwrap(),
+            DataType::Decimal128(38, 0)
+        );
+
+        assert_eq!(
+            trino_data_type_to_arrow_type("decimal(39,0)", None).unwrap(),
+            DataType::Decimal256(39, 0)
+        );
+    }
+
+    #[test]
+    fn test_array_types() {
+        assert_eq!(
+            trino_data_type_to_arrow_type("array(integer)", None).unwrap(),
+            DataType::List(Arc::new(Field::new("item", DataType::Int32, true)))
+        );
+
+        assert_eq!(
+            trino_data_type_to_arrow_type("array(varchar)", None).unwrap(),
+            DataType::List(Arc::new(Field::new("item", DataType::Utf8, true)))
+        );
+
+        assert_eq!(
+            trino_data_type_to_arrow_type("array(decimal(10,2))", None).unwrap(),
+            DataType::List(Arc::new(Field::new(
+                "item",
+                DataType::Decimal128(10, 2),
+                true
+            )))
+        );
+    }
+
+    #[test]
+    fn test_nested_array_types() {
+        // Array of arrays becomes an array of strings
+        assert_eq!(
+            trino_data_type_to_arrow_type("array(array(integer))", None).unwrap(),
+            DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
+        );
+
+        // Array of rows becomes an array of strings
+        assert_eq!(
+            trino_data_type_to_arrow_type("array(row(name varchar, age integer))", None).unwrap(),
+            DataType::List(Arc::new(Field::new("item", DataType::Utf8, true)))
+        );
+
+        // Array of maps is not supported
+        let res = trino_data_type_to_arrow_type("array(map(varchar, integer))", None);
+        assert!(res.is_err());
+    }
+
+    #[test]
+    fn test_map_types() {
+        // Maps are not supported
+        let res = trino_data_type_to_arrow_type("map(varchar, integer)", None);
+        assert!(res.is_err());
+    }
+
+    #[test]
+    fn test_row_type_simple() {
+        let expected = DataType::Struct(Fields::from(vec![
+            Field::new("name", DataType::Utf8, true),
+            Field::new("age", DataType::Int32, true),
+        ]));
+        assert_eq!(
+            trino_data_type_to_arrow_type("row(name varchar, age integer)", None).unwrap(),
+            expected
+        );
+    }
+
+    #[test]
+    fn test_row_type_complex() {
+        let expected_multi = DataType::Struct(Fields::from(vec![
+            Field::new("id", DataType::Int64, true),
+            Field::new("name", DataType::Utf8, true),
+            Field::new("salary", DataType::Decimal128(10, 2), true),
+            Field::new("active", DataType::Boolean, true),
+        ]));
+        assert_eq!(
+            trino_data_type_to_arrow_type(
+                "row(id bigint, name varchar, salary decimal(10,2), active boolean)",
+                None
+            )
+            .unwrap(),
+            expected_multi
+        );
+    }
+
+    #[test]
+    fn test_row_type_with_array() {
+        let expected_row_array = DataType::Struct(Fields::from(vec![
+            Field::new("name", DataType::Utf8, true),
+            Field::new("scores", DataType::Utf8, true),
+        ]));
+        assert_eq!(
+            trino_data_type_to_arrow_type("row(name varchar, scores array(integer))", None)
+                .unwrap(),
+            expected_row_array
+        );
+    }
+
+    #[test]
+    fn test_row_type_with_map() {
+        // row of maps is not supported
+        let res =
trino_data_type_to_arrow_type("row(name varchar, scores map(varchar, integer))", None); + assert!(res.is_err()); + } + + #[test] + fn test_row_type_with_nested_row() { + let expected_row_array = DataType::Struct(Fields::from(vec![ + Field::new("name", DataType::Utf8, true), + Field::new("scores", DataType::Utf8, true), + ])); + assert_eq!( + trino_data_type_to_arrow_type("row(name varchar, scores row(value integer))", None) + .unwrap(), + expected_row_array + ); + } + + #[test] + fn test_unsupported_types() { + let result = trino_data_type_to_arrow_type("unknown_type", None); + assert!(result.is_err()); + if let Err(Error::UnsupportedTrinoType { trino_type }) = result { + assert_eq!(trino_type, "unknown_type"); + } + } + + #[test] + fn test_extract_precision() { + let result = extract_precision("time(3)", "time"); + assert_eq!(result.unwrap(), 3); + + let result = extract_precision("timestamp(6)", "timestamp"); + assert_eq!(result.unwrap(), 6); + + let result = extract_precision("timestamp(9) with time zone", "timestamp"); + assert_eq!(result.unwrap(), 9); + + let result = extract_precision("time", "time"); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + Error::InvalidPrecision { .. } + )); + + let result = extract_precision("timestamp(x)", "timestamp"); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + Error::InvalidPrecision { .. } + )); + + let result = extract_precision("row(x integer)", "timestamp"); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + Error::InvalidPrecision { .. } + )); + + let result = extract_precision("array(timestamp(6))", "timestamp"); + assert!(result.is_err()); + } +} diff --git a/src/sql/db_connection_pool/dbconnection.rs b/src/sql/db_connection_pool/dbconnection.rs index 939324f2..5c87d74a 100644 --- a/src/sql/db_connection_pool/dbconnection.rs +++ b/src/sql/db_connection_pool/dbconnection.rs @@ -13,6 +13,8 @@ pub mod mysqlconn; pub mod postgresconn; #[cfg(feature = "sqlite")] pub mod sqliteconn; +#[cfg(feature = "trino")] +pub mod trinoconn; pub type GenericError = Box; type Result = std::result::Result; diff --git a/src/sql/db_connection_pool/dbconnection/trinoconn.rs b/src/sql/db_connection_pool/dbconnection/trinoconn.rs new file mode 100644 index 00000000..56d043a3 --- /dev/null +++ b/src/sql/db_connection_pool/dbconnection/trinoconn.rs @@ -0,0 +1,286 @@ +use super::AsyncDbConnection; +use super::DbConnection; +use super::Result; +use crate::sql::arrow_sql_gen::trino::{ + self, arrow::rows_to_arrow, schema::trino_data_type_to_arrow_type, +}; +use crate::UnsupportedTypeAction; +use arrow::datatypes::Field; +use arrow::datatypes::Schema; +use arrow::datatypes::SchemaRef; +use async_stream::stream; +use async_stream::try_stream; +use datafusion::error::DataFusionError; +use datafusion::execution::SendableRecordBatchStream; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use datafusion::sql::TableReference; +use futures::Stream; +use futures::StreamExt; +use serde_json::Value; +use snafu::prelude::*; +use std::pin::Pin; +use std::time::Duration; +use std::{any::Any, sync::Arc}; +use tokio::time::sleep; + +pub type QueryStream = Pin>, Error>> + Send>>; + +#[derive(Debug, Snafu)] +pub enum Error { + #[snafu(display("Query execution failed.\n{source}\nFor details, refer to the Trino documentation: https://trino.io/docs/"))] + QueryError { source: reqwest::Error }, + + #[snafu(display("Failed to convert query result to Arrow.\n{source}\nReport a bug to request support: 
+pub type QueryStream =
+    Pin<Box<dyn Stream<Item = Result<Vec<Vec<Value>>, Error>> + Send>>;
+
+#[derive(Debug, Snafu)]
+pub enum Error {
+    #[snafu(display("Query execution failed.\n{source}\nFor details, refer to the Trino documentation: https://trino.io/docs/"))]
+    QueryError { source: reqwest::Error },
+
+    #[snafu(display("Failed to convert query result to Arrow.\n{source}\nReport a bug to request support: https://github.com/datafusion-contrib/datafusion-table-providers/issues"))]
+    ConversionError { source: trino::Error },
+
+    #[snafu(display("Authentication failed."))]
+    AuthenticationFailedError,
+
+    #[snafu(display("Trino server error: {status_code} - {message}"))]
+    TrinoServerError { status_code: u16, message: String },
+
+    #[snafu(display("Unsupported data type '{data_type}' for field '{column_name}'.\nReport a bug to request support: https://github.com/datafusion-contrib/datafusion-table-providers/issues"))]
+    UnsupportedDataTypeError {
+        column_name: String,
+        data_type: String,
+    },
+
+    #[snafu(display("Failed to find the field '{field}'.\nReport a bug to request support: https://github.com/datafusion-contrib/datafusion-table-providers/issues"))]
+    MissingFieldError { field: String },
+
+    #[snafu(display("No schema was provided.\nReport a bug to request support: https://github.com/datafusion-contrib/datafusion-table-providers/issues"))]
+    NoSchema,
+}
+
+pub const DEFAULT_POLL_WAIT_TIME_MS: u64 = 50;
+
+pub struct TrinoConnection {
+    client: Arc<reqwest::Client>,
+    base_url: String,
+    unsupported_type_action: UnsupportedTypeAction,
+    poll_wait_time: Duration,
+    tz: Option<String>,
+}
+
+impl<'a> DbConnection<Arc<reqwest::Client>, &'a str> for TrinoConnection {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn as_any_mut(&mut self) -> &mut dyn Any {
+        self
+    }
+
+    fn as_async(&self) -> Option<&dyn AsyncDbConnection<Arc<reqwest::Client>, &'a str>> {
+        Some(self)
+    }
+}
+
+#[async_trait::async_trait]
+impl<'a> AsyncDbConnection<Arc<reqwest::Client>, &'a str> for TrinoConnection {
+    fn new(client: Arc<reqwest::Client>) -> Self {
+        TrinoConnection {
+            client,
+            base_url: String::new(),
+            unsupported_type_action: UnsupportedTypeAction::default(),
+            poll_wait_time: Duration::from_millis(DEFAULT_POLL_WAIT_TIME_MS),
+            tz: None,
+        }
+    }
+
+    async fn get_schema(
+        &self,
+        table_reference: &TableReference,
+    ) -> Result<SchemaRef, super::Error> {
+        let sql = format!("DESCRIBE {table_reference}");
+        let mut query_stream = self.execute_query(&sql);
+
+        let mut fields = Vec::new();
+
+        while let Some(batch_data) = query_stream.next().await {
+            let batch_data = batch_data.map_err(|e| super::Error::UnableToGetSchema {
+                source: Box::new(e),
+            })?;
+
+            for row_data in batch_data {
+                if row_data.len() >= 2 {
+                    let column_name =
+                        row_data[0]
+                            .as_str()
+                            .ok_or_else(|| super::Error::UnableToGetSchema {
+                                source: Box::new(Error::MissingFieldError {
+                                    field: "column_name".to_string(),
+                                }),
+                            })?;
+
+                    let data_type =
+                        row_data[1]
+                            .as_str()
+                            .ok_or_else(|| super::Error::UnableToGetSchema {
+                                source: Box::new(Error::MissingFieldError {
+                                    field: "data_type".to_string(),
+                                }),
+                            })?;
+
+                    let nullable = if row_data.len() > 2 {
+                        row_data[2].as_str().unwrap_or("true") != "false"
+                    } else {
+                        true
+                    };
+
+                    let Ok(arrow_type) =
+                        trino_data_type_to_arrow_type(data_type, self.tz.as_deref())
+                    else {
+                        return Err(super::Error::UnsupportedDataType {
+                            data_type: data_type.to_string(),
+                            field_name: column_name.to_string(),
+                        });
+                    };
+
+                    fields.push(Field::new(column_name, arrow_type, nullable));
+                }
+            }
+        }
+
+        let schema = Arc::new(Schema::new(fields));
+        Ok(schema)
+    }
+
+    async fn query_arrow(
+        &self,
+        sql: &str,
+        _params: &[&'a str],
+        projected_schema: Option<SchemaRef>,
+    ) -> Result<SendableRecordBatchStream> {
+        let schema_ref = projected_schema.ok_or(Error::NoSchema)?;
+
+        let mut query_stream = self.execute_query(sql);
+
+        let schema_for_stream = Arc::clone(&schema_ref);
+
+        let mut arrow_stream = Box::pin(stream!
{ + while let Some(batch_data) = query_stream.next().await { + let batch_data = batch_data.map_err(|e| super::Error::UnableToQueryArrow { + source: Box::new(e), + })?; + + if !batch_data.is_empty() { + let chunk_size = 4_000; + for chunk in batch_data.chunks(chunk_size) { + let rec = rows_to_arrow(chunk, Arc::clone(&schema_for_stream)) + .map_err(|e| super::Error::UnableToQueryArrow { + source: Box::new(Error::ConversionError { source: e }), + })?; + yield Ok::<_, super::Error>(rec); + } + } + } + }); + + Ok(Box::pin(RecordBatchStreamAdapter::new(schema_ref, { + stream! { + while let Some(batch) = arrow_stream.next().await { + yield batch + .map_err(|e| DataFusionError::Execution(format!("Failed to fetch batch: {e}"))) + } + } + }))) + } + + async fn execute(&self, _query: &str, _params: &[&'a str]) -> Result { + unimplemented!("Execute not implemented for Trino"); + } +} + +impl TrinoConnection { + pub fn new_with_config( + client: Arc, + base_url: String, + poll_wait_time: Duration, + tz: Option, + ) -> Self { + TrinoConnection { + client, + base_url, + unsupported_type_action: UnsupportedTypeAction::default(), + poll_wait_time, + tz, + } + } + + #[must_use] + pub fn with_unsupported_type_action(mut self, action: UnsupportedTypeAction) -> Self { + self.unsupported_type_action = action; + self + } + + fn execute_query(&self, sql: &str) -> QueryStream { + let client = self.client.clone(); + let url = format!("{}/v1/statement", self.base_url); + let poll_wait_time = self.poll_wait_time; + let sql = sql.to_string(); + + Box::pin(try_stream! { + let mut result: Value = client + .post(&url) + .body(sql) + .send() + .await + .context(QuerySnafu)? + .json() + .await + .context(QuerySnafu)?; + + loop { + let state = result["stats"]["state"].as_str().unwrap_or(""); + + if state == "FAILED" { + Err(Error::TrinoServerError { + status_code: 500, + message: "Query failed".to_string(), + })?; + } else if state == "CANCELED" { + Err(Error::TrinoServerError { + status_code: 499, + message: "Query was canceled".to_string(), + })?; + } + + if let Some(data) = result.get("data").and_then(|d| d.as_array()) { + let batch_data = data.iter() + .filter_map(|row| row.as_array().cloned()) + .collect::>(); + + if !batch_data.is_empty() { + yield batch_data; + } + } + + if state == "FINISHED" { + break; + } + + if let Some(next_uri) = result.get("nextUri").and_then(|u| u.as_str()) { + sleep(poll_wait_time).await; + + result = client + .get(next_uri) + .send() + .await + .context(QuerySnafu)? 
+ .json() + .await + .context(QuerySnafu)?; + } else { + if state != "FINISHED" { + Err(Error::TrinoServerError { + status_code: 500, + message: format!("Query stuck in state: {state}"), + })?; + } + break; + } + } + }) + } +} diff --git a/src/sql/db_connection_pool/mod.rs b/src/sql/db_connection_pool/mod.rs index ad534f82..932ddfec 100644 --- a/src/sql/db_connection_pool/mod.rs +++ b/src/sql/db_connection_pool/mod.rs @@ -11,6 +11,8 @@ pub mod mysqlpool; pub mod postgrespool; #[cfg(feature = "sqlite")] pub mod sqlitepool; +#[cfg(feature = "trino")] +pub mod trinodbpool; pub type Error = Box; type Result = std::result::Result; diff --git a/src/sql/db_connection_pool/trinodbpool.rs b/src/sql/db_connection_pool/trinodbpool.rs new file mode 100644 index 00000000..14b4386b --- /dev/null +++ b/src/sql/db_connection_pool/trinodbpool.rs @@ -0,0 +1,872 @@ +use super::DbConnectionPool; +use crate::sql::db_connection_pool::dbconnection::trinoconn::DEFAULT_POLL_WAIT_TIME_MS; +use crate::{ + sql::db_connection_pool::{ + dbconnection::{trinoconn::TrinoConnection, DbConnection}, + JoinPushDown, + }, + util::{self, ns_lookup::verify_ns_lookup_and_tcp_connect}, + UnsupportedTypeAction, +}; +use async_trait::async_trait; +use base64::engine::general_purpose::STANDARD as BASE64; +use base64::Engine; +use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION}; +use reqwest::{Certificate, Client, ClientBuilder, Identity}; +use secrecy::{ExposeSecret, SecretString}; +use snafu::{ResultExt, Snafu}; +use std::path::PathBuf; +use std::{collections::HashMap, fs, sync::Arc, time::Duration}; + +pub type Result = std::result::Result; + +#[derive(Debug, Snafu)] +pub enum Error { + #[snafu(display("Trino connection failed.\n{source}\nFor details, refer to the Trino documentation: https://trino.io/docs/"))] + TrinoConnectionError { source: reqwest::Error }, + + #[snafu(display("Could not parse {parameter_name} into a valid integer. Ensure it is configured with a valid value."))] + InvalidIntegerParameterError { + parameter_name: String, + source: std::num::ParseIntError, + }, + + #[snafu(display("Cannot connect to Trino on {host}:{port}. Ensure the host and port are correct and reachable."))] + InvalidHostOrPortError { + source: util::ns_lookup::Error, + host: String, + port: u16, + }, + + #[snafu(display("Authentication failed."))] + AuthenticationFailedError, + + #[snafu(display("Invalid Trino URL: {url}. Ensure it starts with http:// or https://"))] + InvalidTrinoUrl { url: String }, + + #[snafu(display( + "Invalid sslmode: {value}. 
Expected values are: required, preferred, disabled"
+    ))]
+    InvalidSSLModeParameter { value: String },
+
+    #[snafu(display("Missing required parameter: {parameter_name}"))]
+    MissingRequiredParameter { parameter_name: String },
+
+    #[snafu(display("Failed to build HTTP client: {source}"))]
+    FailedToBuildTrinoHttpClient { source: reqwest::Error },
+
+    #[snafu(display("Trino server error: {status_code} - {message}"))]
+    TrinoServerError { status_code: u16, message: String },
+
+    #[snafu(display("Invalid Trino authentication configuration: {details}"))]
+    InvalidAuthConfig { details: String },
+
+    #[snafu(display("Failed to read identity PEM file at '{path}': {source}"))]
+    UnableToReadIdentityPem {
+        path: String,
+        source: std::io::Error,
+    },
+
+    #[snafu(display("Invalid identity PEM at '{path}': {source}"))]
+    InvalidIdentityPem {
+        path: String,
+        source: reqwest::Error,
+    },
+
+    #[snafu(display("Failed to read root cert file at '{path}': {source}"))]
+    UnableToReadRootCert {
+        path: String,
+        source: std::io::Error,
+    },
+
+    #[snafu(display("Invalid root cert at '{path}': {source}"))]
+    InvalidRootCert {
+        path: String,
+        source: reqwest::Error,
+    },
+}
+
+const DEFAULT_TIMEOUT_MS: u64 = 30_000;
+const DEFAULT_SSL_MODE: &str = "required";
+
+#[derive(Clone)]
+pub struct TrinoConnectionPool {
+    base_url: String,
+    catalog: String,
+    schema: String,
+    client: Arc<Client>,
+    join_push_down: JoinPushDown,
+    unsupported_type_action: UnsupportedTypeAction,
+    poll_wait_time: Duration,
+    tz: Option<String>,
+}
+
+impl std::fmt::Debug for TrinoConnectionPool {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("TrinoConnectionPool")
+            .field("base_url", &self.base_url)
+            .field("catalog", &self.catalog)
+            .field("schema", &self.schema)
+            .field("join_push_down", &self.join_push_down)
+            .field("unsupported_type_action", &self.unsupported_type_action)
+            .finish()
+    }
+}
+
+impl TrinoConnectionPool {
+    /// Creates a new instance of `TrinoConnectionPool`.
+    ///
+    /// # Arguments
+    ///
+    /// * `params` - A map of parameters to create the connection pool.
+    /// * `host` - The Trino coordinator host (required)
+    /// * `port` - The Trino coordinator port (optional, defaults to 8080)
+    /// * `catalog` - The default catalog to use (required)
+    /// * `schema` - The default schema to use (optional, defaults to "default")
+    /// * `user` - The user to authenticate with (required)
+    /// * `password` - The password for authentication (optional)
+    /// * `timeout_ms` - Request timeout in ms (optional, defaults to 30,000)
+    /// * `sslmode` - TLS/SSL mode for the connection. Supported values: 'disabled', 'required', 'preferred'. Defaults to 'required'. 'preferred' allows invalid certificates/hostnames.
+    /// * `identity_pem_path` - Path to a PEM file containing both the client certificate and private key for mTLS authentication. (optional)
+    /// * `bearer_token` - Bearer token for authentication (optional)
+    /// * `poll_wait_time_ms` - Waiting time in ms between polling Trino results (optional, defaults to 50)
+    /// * `time_zone` - The time zone applied to Trino `timestamp with time zone` columns (e.g., "UTC", "Europe/Amsterdam"). Defaults to "UTC".
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if there is a problem creating the connection pool.
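+    ///
+    /// # Example
+    ///
+    /// A minimal sketch; the host, port, and credential values below are
+    /// illustrative placeholders for a locally running Trino server:
+    ///
+    /// ```rust,no_run
+    /// use std::collections::HashMap;
+    /// use secrecy::SecretString;
+    /// use datafusion_table_providers::sql::db_connection_pool::trinodbpool::TrinoConnectionPool;
+    ///
+    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
+    /// let mut params: HashMap<String, SecretString> = HashMap::new();
+    /// params.insert("host".to_string(), SecretString::new("localhost".into()));
+    /// params.insert("port".to_string(), SecretString::new("8080".into()));
+    /// params.insert("catalog".to_string(), SecretString::new("my_catalog".into()));
+    /// params.insert("schema".to_string(), SecretString::new("my_schema".into()));
+    /// params.insert("user".to_string(), SecretString::new("trino_user".into()));
+    /// params.insert("sslmode".to_string(), SecretString::new("disabled".into()));
+    ///
+    /// let pool = TrinoConnectionPool::new(params).await?;
+    /// # Ok(())
+    /// # }
+    /// ```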
+    pub async fn new(params: HashMap<String, SecretString>) -> Result<Self> {
+        let params = util::remove_prefix_from_hashmap_keys(params, "trino_");
+
+        let (catalog, schema) = get_catalog_and_schema(&params)?;
+        let (user, password) = get_user_and_password(&params);
+        let bearer_token = params.get("bearer_token").cloned();
+
+        validate_auth(&params, &user, &password)?;
+
+        let headers = build_headers(&catalog, &schema, &user, &password, &bearer_token)?;
+
+        let timeout_ms = parse_u64_param(&params, "timeout_ms", DEFAULT_TIMEOUT_MS)?;
+        let poll_wait_time =
+            parse_u64_param(&params, "poll_wait_time_ms", DEFAULT_POLL_WAIT_TIME_MS)?;
+
+        let client_builder = Client::builder()
+            .default_headers(headers)
+            .timeout(Duration::from_millis(timeout_ms));
+
+        let (mut client_builder, protocol) = configure_tls(client_builder, &params)?;
+
+        if let Some(identity_path) = params.get("identity_pem_path") {
+            let pem =
+                fs::read(identity_path.expose_secret()).context(UnableToReadIdentityPemSnafu {
+                    path: identity_path.expose_secret().to_string(),
+                })?;
+
+            let identity = Identity::from_pem(&pem).context(InvalidIdentityPemSnafu {
+                path: identity_path.expose_secret().to_string(),
+            })?;
+            client_builder = client_builder.identity(identity);
+        }
+
+        let base_url = build_base_url(protocol, &params)?;
+
+        let client = client_builder
+            .build()
+            .context(FailedToBuildTrinoHttpClientSnafu)?;
+
+        Self::test_connection(&client, &base_url).await?;
+
+        let join_push_down = Self::get_join_context(&base_url, &catalog, &schema, &user);
+
+        Ok(Self {
+            base_url,
+            catalog,
+            schema,
+            client: Arc::new(client),
+            join_push_down,
+            unsupported_type_action: UnsupportedTypeAction::default(),
+            poll_wait_time: Duration::from_millis(poll_wait_time),
+            tz: params
+                .get("time_zone")
+                .map(|t| t.expose_secret().to_string()),
+        })
+    }
+
+    #[must_use]
+    pub fn with_unsupported_type_action(mut self, action: UnsupportedTypeAction) -> Self {
+        self.unsupported_type_action = action;
+        self
+    }
+
+    async fn test_connection(client: &Client, base_url: &str) -> Result<()> {
+        let url = format!("{base_url}/v1/info");
+
+        let response = client
+            .get(&url)
+            .send()
+            .await
+            .context(TrinoConnectionSnafu)?;
+
+        if response.status() == 401 {
+            return Err(Error::AuthenticationFailedError);
+        }
+
+        if !response.status().is_success() {
+            return Err(Error::TrinoServerError {
+                status_code: response.status().as_u16(),
+                message: format!("Connection test failed with HTTP {}", response.status()),
+            });
+        }
+
+        Ok(())
+    }
+
+    fn get_join_context(
+        base_url: &str,
+        catalog: &str,
+        schema: &str,
+        user: &Option<String>,
+    ) -> JoinPushDown {
+        let mut join_context = format!("url={base_url},catalog={catalog},schema={schema}");
+        if let Some(user) = user {
+            join_context.push_str(&format!(",user={user}"));
+        }
+
+        JoinPushDown::AllowedFor(join_context)
+    }
+}
+
+#[async_trait]
+impl DbConnectionPool<Arc<Client>, &'static str> for TrinoConnectionPool {
+    async fn connect(
+        &self,
+    ) -> super::Result<Box<dyn DbConnection<Arc<Client>, &'static str>>> {
+        let connection = TrinoConnection::new_with_config(
+            self.client.clone(),
+            self.base_url.clone(),
+            self.poll_wait_time,
+            self.tz.clone(),
+        )
+        .with_unsupported_type_action(self.unsupported_type_action);
+
+        Ok(Box::new(connection))
+    }
+
+    fn join_push_down(&self) -> JoinPushDown {
+        self.join_push_down.clone()
+    }
+}
+
+fn build_base_url(protocol: String, params: &HashMap<String, SecretString>) -> Result<String> {
+    let host = params
+        .get("host")
+        .map(ExposeSecret::expose_secret)
+        .ok_or_else(|| Error::MissingRequiredParameter {
+            parameter_name: "host".to_string(),
+        })?;
+
+    let port = parse_u16_param(params, "port",
+#[async_trait]
+impl DbConnectionPool<Arc<Client>, &'static str> for TrinoConnectionPool {
+    async fn connect(
+        &self,
+    ) -> super::Result<Box<dyn DbConnection<Arc<Client>, &'static str>>> {
+        let connection = TrinoConnection::new_with_config(
+            self.client.clone(),
+            self.base_url.clone(),
+            self.poll_wait_time,
+            self.tz.clone(),
+        )
+        .with_unsupported_type_action(self.unsupported_type_action);
+
+        Ok(Box::new(connection))
+    }
+
+    fn join_push_down(&self) -> JoinPushDown {
+        self.join_push_down.clone()
+    }
+}
+
+fn build_base_url(protocol: String, params: &HashMap<String, SecretString>) -> Result<String> {
+    let host = params
+        .get("host")
+        .map(ExposeSecret::expose_secret)
+        .ok_or_else(|| Error::MissingRequiredParameter {
+            parameter_name: "host".to_string(),
+        })?;
+
+    let port = parse_u16_param(params, "port", 8080)?;
+
+    // Eagerly verify DNS resolution and TCP reachability so misconfiguration
+    // surfaces at pool creation rather than at the first query.
+    futures::executor::block_on(verify_ns_lookup_and_tcp_connect(host, port))
+        .context(InvalidHostOrPortSnafu { host, port })?;
+
+    Ok(format!("{protocol}://{host}:{port}"))
+}
+
+fn get_catalog_and_schema(params: &HashMap<String, SecretString>) -> Result<(String, String)> {
+    let catalog = params
+        .get("catalog")
+        .map(ExposeSecret::expose_secret)
+        .ok_or_else(|| Error::MissingRequiredParameter {
+            parameter_name: "catalog".to_string(),
+        })?
+        .to_string();
+
+    let schema = params
+        .get("schema")
+        .map(ExposeSecret::expose_secret)
+        .unwrap_or("default")
+        .to_string();
+
+    Ok((catalog, schema))
+}
+
+fn get_user_and_password(
+    params: &HashMap<String, SecretString>,
+) -> (Option<String>, Option<SecretString>) {
+    let user = params.get("user").map(|u| u.expose_secret().to_string());
+    let password = params.get("password").cloned();
+    (user, password)
+}
+
+fn validate_auth(
+    params: &HashMap<String, SecretString>,
+    user: &Option<String>,
+    password: &Option<SecretString>,
+) -> Result<()> {
+    let has_user = user.is_some();
+    let has_user_pass = user.is_some() && password.is_some();
+    let has_identity = params.contains_key("identity_pem_path");
+    let has_token = params.contains_key("bearer_token");
+
+    if !has_user {
+        return Err(Error::InvalidAuthConfig {
+            details: "User is required".into(),
+        });
+    }
+
+    let auth_count = [has_user_pass, has_identity, has_token]
+        .into_iter()
+        .filter(|x| *x)
+        .count();
+
+    if auth_count > 1 {
+        return Err(Error::InvalidAuthConfig {
+            details: "At most one authentication method may be provided: basic auth, mTLS, or bearer token".into(),
+        });
+    }
+    Ok(())
+}
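+
+// `sslmode` semantics as implemented below:
+//   disabled  -> plain HTTP, no TLS
+//   required  -> HTTPS with full certificate and hostname verification (default)
+//   preferred -> HTTPS, but invalid certificates and hostnames are accepted
+// An optional `sslrootcert` PEM file is added as an extra root certificate in both
+// HTTPS modes.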
+
+fn configure_tls(
+    client_builder: ClientBuilder,
+    params: &HashMap<String, SecretString>,
+) -> Result<(ClientBuilder, String)> {
+    let ssl_mode = params
+        .get("sslmode")
+        .map(ExposeSecret::expose_secret)
+        .unwrap_or(DEFAULT_SSL_MODE)
+        .to_string()
+        .to_lowercase();
+
+    match ssl_mode.as_str() {
+        "disabled" | "required" | "preferred" => {}
+        _ => {
+            return Err(Error::InvalidSSLModeParameter {
+                value: ssl_mode.to_string(),
+            });
+        }
+    }
+
+    if ssl_mode == "disabled" {
+        return Ok((client_builder, "http".to_string()));
+    }
+
+    let ssl_rootcert = if let Some(cert_path) = params.get("sslrootcert") {
+        let path = PathBuf::from(cert_path.expose_secret());
+        let ca_cert = fs::read(path).context(UnableToReadRootCertSnafu {
+            path: cert_path.expose_secret().to_string(),
+        })?;
+        Some(
+            Certificate::from_pem(&ca_cert).context(InvalidRootCertSnafu {
+                path: cert_path.expose_secret().to_string(),
+            })?,
+        )
+    } else {
+        None
+    };
+
+    let client_builder = match (ssl_rootcert, ssl_mode.as_str()) {
+        // Root cert + preferred
+        (Some(cert), "preferred") => client_builder
+            .add_root_certificate(cert)
+            .danger_accept_invalid_certs(true)
+            .danger_accept_invalid_hostnames(true),
+
+        // Root cert + required
+        (Some(cert), _) => client_builder.add_root_certificate(cert),
+
+        // No root cert + preferred
+        (None, "preferred") => client_builder
+            .danger_accept_invalid_certs(true)
+            .danger_accept_invalid_hostnames(true),
+
+        // No root cert + required
+        (None, _) => client_builder,
+    };
+
+    Ok((client_builder, "https".to_string()))
+}
+
+fn build_headers(
+    catalog: &str,
+    schema: &str,
+    user: &Option<String>,
+    password: &Option<SecretString>,
+    bearer_token: &Option<SecretString>,
+) -> Result<HeaderMap> {
+    let mut headers = HeaderMap::new();
+    headers.insert("X-Trino-Catalog", catalog.parse().unwrap());
+    headers.insert("X-Trino-Schema", schema.parse().unwrap());
+
+    if let Some(user) = user {
+        headers.insert("X-Trino-User", user.parse().unwrap());
+    }
+
+    if let (Some(user), Some(password)) = (user, password) {
+        // HTTP Basic auth: base64("user:password") per RFC 7617
+        let credentials = format!("{}:{}", user, password.expose_secret());
+        let encoded = BASE64.encode(credentials);
+        headers.insert(
+            AUTHORIZATION,
+            HeaderValue::from_str(&format!("Basic {encoded}")).unwrap(),
+        );
+    } else if let Some(token) = bearer_token {
+        headers.insert(
+            AUTHORIZATION,
+            HeaderValue::from_str(&format!("Bearer {}", token.expose_secret())).unwrap(),
+        );
+    }
+
+    Ok(headers)
+}
+
+fn parse_u64_param(
+    params: &HashMap<String, SecretString>,
+    key: &str,
+    default: u64,
+) -> Result<u64> {
+    params
+        .get(key)
+        .map(ExposeSecret::expose_secret)
+        .unwrap_or(&default.to_string())
+        .parse::<u64>()
+        .context(InvalidIntegerParameterSnafu {
+            parameter_name: key,
+        })
+}
+
+fn parse_u16_param(
+    params: &HashMap<String, SecretString>,
+    key: &str,
+    default: u16,
+) -> Result<u16> {
+    params
+        .get(key)
+        .map(ExposeSecret::expose_secret)
+        .unwrap_or(&default.to_string())
+        .parse::<u16>()
+        .context(InvalidIntegerParameterSnafu {
+            parameter_name: key,
+        })
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use mockito::Server;
+    use secrecy::SecretString;
+    use std::collections::HashMap;
+
+    use tempfile::NamedTempFile;
+
+    fn create_basic_params() -> HashMap<String, SecretString> {
+        let mut params = HashMap::new();
+        params.insert(
+            "catalog".to_string(),
+            SecretString::new("test_catalog".into()),
+        );
+        params.insert(
+            "schema".to_string(),
+            SecretString::new("test_schema".into()),
+        );
+        params
+    }
+
+    fn create_mock_pem_file() -> NamedTempFile {
+        let pem_content = r#"-----BEGIN CERTIFICATE-----
+MIICljCCAX4CCQCKLy2PtfxYqjANBgkqhkiG9w0BAQsFADANMQswCQYDVQQGEwJV
+UzAeFw0yMzEwMDEwMDAwMDBaFw0yNDA5MzAyMzU5NTlaMA0xCzAJBgNVBAYTAlVT
+MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAyJ3yfgDHc...
+-----END CERTIFICATE-----
+-----BEGIN PRIVATE KEY-----
+MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDInfJ+AMdz...
+-----END PRIVATE KEY-----"#;
+
+        let mut file = NamedTempFile::new().expect("Failed to create temp file");
+        std::io::Write::write_all(&mut file, pem_content.as_bytes())
+            .expect("Failed to write to temp file");
+        file
+    }
+
+    #[tokio::test]
+    async fn test_new_with_url_basic_auth() {
+        let mut server = Server::new_async().await;
+        let mock = server
+            .mock("GET", "/v1/info")
+            .with_status(200)
+            .with_body(r#"{"nodeVersion":{"version":"1.0"}}"#)
+            .create_async()
+            .await;
+
+        let mut params = create_basic_params();
+        params.insert(
+            "host".to_string(),
+            SecretString::new(server.socket_address().ip().to_string().into()),
+        );
+        params.insert(
+            "port".to_string(),
+            SecretString::new(server.socket_address().port().to_string().into()),
+        );
+        params.insert("sslmode".to_string(), SecretString::new("disabled".into()));
+        params.insert("user".to_string(), SecretString::new("testuser".into()));
+        params.insert("password".to_string(), SecretString::new("testpass".into()));
+
+        let pool = TrinoConnectionPool::new(params).await;
+        assert!(pool.is_ok());
+
+        let pool = pool.unwrap();
+        assert_eq!(pool.base_url, server.url());
+        assert_eq!(pool.catalog, "test_catalog");
+        assert_eq!(pool.schema, "test_schema");
+
+        mock.assert_async().await;
+    }
+
+    #[tokio::test]
+    async fn test_new_with_bearer_token() {
+        let mut server = Server::new_async().await;
+        let mock = server
+            .mock("GET", "/v1/info")
+            .with_status(200)
+            .match_header("Authorization", "Bearer test-token-123")
+            .with_body(r#"{"nodeVersion":{"version":"1.0"}}"#)
+            .create_async()
+            .await;
+
+        let mut params = create_basic_params();
+        params.insert(
+            "host".to_string(),
+            SecretString::new(server.socket_address().ip().to_string().into()),
+        );
+        params.insert(
+            "port".to_string(),
+            SecretString::new(server.socket_address().port().to_string().into()),
+        );
+        params.insert("sslmode".to_string(), SecretString::new("disabled".into()));
+        params.insert(
+            "bearer_token".to_string(),
+            SecretString::new("test-token-123".into()),
+        );
+        params.insert("user".to_string(), SecretString::new("testuser".into()));
+
+        let pool = TrinoConnectionPool::new(params).await;
+        assert!(pool.is_ok());
+
+        mock.assert_async().await;
+    }
+
+    #[tokio::test]
+    async fn test_new_with_host_port() {
+        let mut server = Server::new_async().await;
+        let port = server.socket_address().port();
+
+        let mock = server
+            .mock("GET", "/v1/info")
+            .with_status(200)
+            .with_body(r#"{"nodeVersion":{"version":"1.0"}}"#)
+            .create_async()
+            .await;
+
+        let mut params = create_basic_params();
+        params.insert(
+            "host".to_string(),
+            SecretString::new(server.socket_address().ip().to_string().into()),
+        );
+        params.insert("sslmode".to_string(), SecretString::new("disabled".into()));
+        params.insert(
+            "port".to_string(),
+            SecretString::new(port.to_string().into()),
+        );
+        params.insert("user".to_string(), SecretString::new("testuser".into()));
+
+        let _pool = TrinoConnectionPool::new(params)
+            .await
+            .expect("Failed to create TrinoConnectionPool");
+
+        mock.assert_async().await;
+    }
.with_body("Internal Server Error") + .create_async() + .await; + + let mut params = create_basic_params(); + params.insert( + "host".to_string(), + SecretString::new(server.socket_address().ip().to_string().into()), + ); + params.insert( + "port".to_string(), + SecretString::new(server.socket_address().port().to_string().into()), + ); + params.insert("sslmode".to_string(), SecretString::new("disabled".into())); + params.insert("user".to_string(), SecretString::new("testuser".into())); + + let result = TrinoConnectionPool::new(params).await; + assert!(result.is_err()); + + if let Err(Error::TrinoServerError { status_code, .. }) = result { + assert_eq!(status_code, 500); + } else { + panic!("Expected TrinoServerError"); + } + + mock.assert_async().await; + } + + #[tokio::test] + async fn test_multiple_auth_methods_error() { + let mut params = create_basic_params(); + params.insert( + "url".to_string(), + SecretString::new("http://localhost:8080".into()), + ); + params.insert("host".to_string(), SecretString::new("localhost".into())); + params.insert("sslmode".to_string(), SecretString::new("disabled".into())); + params.insert("user".to_string(), SecretString::new("testuser".into())); + params.insert("password".to_string(), SecretString::new("testpass".into())); + params.insert( + "bearer_token".to_string(), + SecretString::new("token123".into()), + ); + + let result = TrinoConnectionPool::new(params).await; + assert!(result.is_err()); + + if let Err(Error::InvalidAuthConfig { details }) = result { + assert!(details.contains("At most one authentication method")); + } else { + panic!("Expected InvalidAuthConfig error"); + } + } + + #[tokio::test] + async fn test_no_auth_method_allowed() { + let mut server = Server::new_async().await; + let mock = server + .mock("GET", "/v1/info") + .with_status(200) + .with_body(r#"{"nodeVersion":{"version":"1.0"}}"#) + .create_async() + .await; + + let mut params = create_basic_params(); + params.insert( + "host".to_string(), + SecretString::new(server.socket_address().ip().to_string().into()), + ); + params.insert( + "port".to_string(), + SecretString::new(server.socket_address().port().to_string().into()), + ); + params.insert("sslmode".to_string(), SecretString::new("disabled".into())); + params.insert("user".to_string(), SecretString::new("testuser".into())); + + let result = TrinoConnectionPool::new(params).await; + assert!(result.is_ok()); + + mock.assert_async().await; + } + + #[test] + fn test_build_headers_basic_auth() { + let user = Some("testuser".to_string()); + let password = Some(SecretString::new("testpass".into())); + let bearer_token = None; + + let headers = build_headers("catalog", "schema", &user, &password, &bearer_token).unwrap(); + + assert_eq!(headers.get("X-Trino-Catalog").unwrap(), "catalog"); + assert_eq!(headers.get("X-Trino-Schema").unwrap(), "schema"); + assert_eq!(headers.get("X-Trino-User").unwrap(), "testuser"); + + let auth_header = headers.get("Authorization").unwrap().to_str().unwrap(); + assert!(auth_header.starts_with("Basic ")); + + // Decode and verify the basic auth + let encoded = auth_header.strip_prefix("Basic ").unwrap(); + let decoded = String::from_utf8(BASE64.decode(encoded).unwrap()).unwrap(); + assert_eq!(decoded, "testuser:testpass"); + } + + #[test] + fn test_build_headers_bearer_token() { + let user = None; + let password = None; + let bearer_token = Some(SecretString::new("test-token-123".into())); + + let headers = build_headers("catalog", "schema", &user, &password, &bearer_token).unwrap(); + + 
assert_eq!(headers.get("X-Trino-Catalog").unwrap(), "catalog"); + assert_eq!(headers.get("X-Trino-Schema").unwrap(), "schema"); + assert!(headers.get("X-Trino-User").is_none()); + + let auth_header = headers.get("Authorization").unwrap().to_str().unwrap(); + assert_eq!(auth_header, "Bearer test-token-123"); + } + + #[test] + fn test_build_headers_no_auth() { + let user = None; + let password = None; + let bearer_token = None; + + let headers = build_headers("catalog", "schema", &user, &password, &bearer_token).unwrap(); + + assert_eq!(headers.get("X-Trino-Catalog").unwrap(), "catalog"); + assert_eq!(headers.get("X-Trino-Schema").unwrap(), "schema"); + assert!(headers.get("X-Trino-User").is_none()); + assert!(headers.get("Authorization").is_none()); + } + + #[test] + fn test_parse_parameters() { + let mut params = HashMap::new(); + params.insert("timeout".to_string(), SecretString::new("120".into())); + params.insert("port".to_string(), SecretString::new("9080".into())); + params.insert( + "ssl_verification".to_string(), + SecretString::new("false".into()), + ); + + assert_eq!(parse_u64_param(¶ms, "timeout", 300).unwrap(), 120); + assert_eq!(parse_u16_param(¶ms, "port", 8080).unwrap(), 9080); + + // Test defaults + assert_eq!(parse_u64_param(¶ms, "nonexistent", 300).unwrap(), 300); + assert_eq!(parse_u16_param(¶ms, "nonexistent", 8080).unwrap(), 8080); + } + + #[test] + fn test_parse_invalid_parameters() { + let mut params = HashMap::new(); + params.insert("timeout".to_string(), SecretString::new("invalid".into())); + params.insert("port".to_string(), SecretString::new("99999".into())); // Too large for u16 + + assert!(parse_u64_param(¶ms, "timeout", 300).is_err()); + assert!(parse_u16_param(¶ms, "port", 8080).is_err()); + } + + #[test] + fn test_validate_auth() { + // Test valid cases + let mut params = HashMap::new(); + let user = Some("user".to_string()); + let password = Some(SecretString::new("pass".into())); + assert!(validate_auth(¶ms, &user, &password).is_ok()); + + let password = None; + params.insert( + "bearer_token".to_string(), + SecretString::new("token".into()), + ); + assert!(validate_auth(¶ms, &user, &password).is_ok()); + + // Test invalid case - multiple auth methods + let user = Some("user".to_string()); + let password = Some(SecretString::new("pass".into())); + assert!(validate_auth(¶ms, &user, &password).is_err()); + + // User is required + let user = None; + let password = None; + params.insert( + "bearer_token".to_string(), + SecretString::new("token".into()), + ); + assert!(validate_auth(¶ms, &user, &password).is_err()); + } + + #[test] + fn test_get_catalog_and_schema() { + let mut params = HashMap::new(); + params.insert("catalog".to_string(), SecretString::new("test_cat".into())); + params.insert( + "schema".to_string(), + SecretString::new("test_schema".into()), + ); + + let (catalog, schema) = get_catalog_and_schema(¶ms).unwrap(); + assert_eq!(catalog, "test_cat"); + assert_eq!(schema, "test_schema"); + + params.remove("schema"); + let (catalog, schema) = get_catalog_and_schema(¶ms).unwrap(); + assert_eq!(catalog, "test_cat"); + assert_eq!(schema, "default"); + } + + #[test] + fn test_get_user_and_password() { + let mut params = HashMap::new(); + params.insert("user".to_string(), SecretString::new("testuser".into())); + params.insert("password".to_string(), SecretString::new("testpass".into())); + + let (user, password) = get_user_and_password(¶ms); + assert_eq!(user, Some("testuser".to_string())); + assert!(password.is_some()); + } +} diff --git 
diff --git a/src/sql/sql_provider_datafusion/expr.rs b/src/sql/sql_provider_datafusion/expr.rs
index b3cba714..3d5ab945 100644
--- a/src/sql/sql_provider_datafusion/expr.rs
+++ b/src/sql/sql_provider_datafusion/expr.rs
@@ -30,6 +30,7 @@ pub enum Engine {
     ODBC,
     Postgres,
     MySQL,
+    Trino,
 }
 
 impl Engine {
@@ -39,7 +40,9 @@ impl Engine {
             Engine::SQLite => Arc::new(SqliteDialect {}),
             Engine::Postgres => Arc::new(PostgreSqlDialect {}),
             Engine::MySQL => Arc::new(MySqlDialect {}),
-            Engine::Spark | Engine::DuckDB | Engine::ODBC => Arc::new(DefaultDialect {}),
+            Engine::Spark | Engine::DuckDB | Engine::ODBC | Engine::Trino => {
+                Arc::new(DefaultDialect {})
+            }
         }
     }
 }
diff --git a/src/trino.rs b/src/trino.rs
new file mode 100644
index 00000000..207b963e
--- /dev/null
+++ b/src/trino.rs
@@ -0,0 +1,47 @@
+/*
+Copyright 2024 The Spice.ai OSS Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+use crate::sql::db_connection_pool::trinodbpool::TrinoConnectionPool;
+use datafusion::{datasource::TableProvider, sql::TableReference};
+use sql_table::TrinoTable;
+use std::sync::Arc;
+
+pub mod sql_table;
+
+pub struct TrinoTableFactory {
+    pool: Arc<TrinoConnectionPool>,
+}
+
+impl TrinoTableFactory {
+    #[must_use]
+    pub fn new(pool: Arc<TrinoConnectionPool>) -> Self {
+        Self { pool }
+    }
+
+    pub async fn table_provider(
+        &self,
+        table_reference: TableReference,
+    ) -> Result<Arc<dyn TableProvider>, Box<dyn std::error::Error + Send + Sync>> {
+        let pool = Arc::clone(&self.pool);
+        let table_provider = Arc::new(
+            TrinoTable::new(&pool, table_reference)
+                .await
+                .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?,
+        );
+
+        Ok(table_provider)
+    }
+}
diff --git a/src/trino/sql_table.rs b/src/trino/sql_table.rs
new file mode 100644
index 00000000..a06aa930
--- /dev/null
+++ b/src/trino/sql_table.rs
@@ -0,0 +1,202 @@
+use crate::sql::db_connection_pool::trinodbpool::TrinoConnectionPool;
+use crate::sql::db_connection_pool::DbConnectionPool;
+use crate::sql::sql_provider_datafusion::expr::Engine;
+use async_trait::async_trait;
+use datafusion::catalog::Session;
+use datafusion::sql::unparser::dialect::DefaultDialect;
+use futures::TryStreamExt;
+use reqwest::Client;
+use std::fmt::Display;
+use std::{any::Any, fmt, sync::Arc};
+
+use crate::sql::sql_provider_datafusion::{
+    self, get_stream, to_execution_error, Result as SqlResult, SqlExec, SqlTable,
+};
+use datafusion::{
+    arrow::datatypes::SchemaRef,
+    datasource::TableProvider,
+    error::Result as DataFusionResult,
+    execution::TaskContext,
+    logical_expr::{Expr, TableProviderFilterPushDown, TableType},
+    physical_plan::{
+        stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan,
+        PlanProperties, SendableRecordBatchStream,
+    },
+    sql::TableReference,
+};
+
+pub struct TrinoTable {
+    pool: Arc<TrinoConnectionPool>,
+    pub(crate) base_table: SqlTable<Arc<Client>, &'static str>,
+}
+
+impl std::fmt::Debug for TrinoTable {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("TrinoTable")
+            .field("base_table", &self.base_table)
+            .finish()
+    }
+}
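+
+// `TrinoTable` is a thin wrapper: schema discovery, filter push-down analysis, and
+// SQL generation are all delegated to the generic `SqlTable`, which is tagged with
+// `Engine::Trino` so engine-specific unparsing rules can hook in at one place.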
+
+impl TrinoTable {
+    pub async fn new(
+        pool: &Arc<TrinoConnectionPool>,
+        table_reference: impl Into<TableReference>,
+    ) -> Result<Self, sql_provider_datafusion::Error> {
+        let dyn_pool =
+            Arc::clone(pool) as Arc<dyn DbConnectionPool<Arc<Client>, &'static str> + Send + Sync>;
+
+        let base_table = SqlTable::new("schema", &dyn_pool, table_reference, Some(Engine::Trino))
+            .await?
+            .with_dialect(Arc::new(DefaultDialect {}));
+
+        Ok(Self {
+            pool: Arc::clone(pool),
+            base_table,
+        })
+    }
+
+    fn create_physical_plan(
+        &self,
+        projections: Option<&Vec<usize>>,
+        schema: &SchemaRef,
+        filters: &[Expr],
+        limit: Option<usize>,
+    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
+        Ok(Arc::new(TrinoSQLExec::new(
+            projections,
+            schema,
+            &self.base_table.table_reference,
+            Arc::clone(&self.pool),
+            filters,
+            limit,
+        )?))
+    }
+}
+
+#[async_trait]
+impl TableProvider for TrinoTable {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        self.base_table.schema()
+    }
+
+    fn table_type(&self) -> TableType {
+        self.base_table.table_type()
+    }
+
+    async fn scan(
+        &self,
+        _state: &dyn Session,
+        projection: Option<&Vec<usize>>,
+        filters: &[Expr],
+        limit: Option<usize>,
+    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
+        self.create_physical_plan(projection, &self.schema(), filters, limit)
+    }
+
+    fn supports_filters_pushdown(
+        &self,
+        filters: &[&Expr],
+    ) -> DataFusionResult<Vec<TableProviderFilterPushDown>> {
+        self.base_table.supports_filters_pushdown(filters)
+    }
+}
+
+impl Display for TrinoTable {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "TrinoTable {}", self.base_table.name())
+    }
+}
+
+struct TrinoSQLExec {
+    base_exec: SqlExec<Arc<Client>, &'static str>,
+}
+
+impl TrinoSQLExec {
+    fn new(
+        projections: Option<&Vec<usize>>,
+        schema: &SchemaRef,
+        table_reference: &TableReference,
+        pool: Arc<TrinoConnectionPool>,
+        filters: &[Expr],
+        limit: Option<usize>,
+    ) -> DataFusionResult<Self> {
+        let base_exec = SqlExec::new(
+            projections,
+            schema,
+            table_reference,
+            pool,
+            filters,
+            limit,
+            Some(Engine::Trino),
+        )?;
+
+        Ok(Self { base_exec })
+    }
+
+    fn sql(&self) -> SqlResult<String> {
+        self.base_exec.sql()
+    }
+}
+
+impl std::fmt::Debug for TrinoSQLExec {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        let sql = self.sql().unwrap_or_default();
+        write!(f, "TrinoSQLExec sql={sql}")
+    }
+}
+
+impl DisplayAs for TrinoSQLExec {
+    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
+        let sql = self.sql().unwrap_or_default();
+        write!(f, "TrinoSQLExec sql={sql}")
+    }
+}
+
+impl ExecutionPlan for TrinoSQLExec {
+    fn name(&self) -> &'static str {
+        "TrinoSQLExec"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        self.base_exec.schema()
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        self.base_exec.properties()
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        self.base_exec.children()
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        _children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
+        Ok(self)
+    }
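+
+    // The SQL text is produced eagerly in `execute`, but the record batch stream is
+    // built lazily: `get_stream` is wrapped in a one-shot stream and flattened, so
+    // nothing is sent to Trino until the plan is first polled.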
+
+    fn execute(
+        &self,
+        _partition: usize,
+        _context: Arc<TaskContext>,
+    ) -> DataFusionResult<SendableRecordBatchStream> {
+        let sql = self.sql().map_err(to_execution_error)?;
+        tracing::debug!("TrinoSQLExec sql: {sql}");
+
+        let fut = get_stream(self.base_exec.clone_pool(), sql, Arc::clone(&self.schema()));
+
+        let stream = futures::stream::once(fut).try_flatten();
+        let schema = Arc::clone(&self.schema());
+        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
+    }
+}
diff --git a/tests/integration.rs b/tests/integration.rs
index 5f635855..3760592c 100644
--- a/tests/integration.rs
+++ b/tests/integration.rs
@@ -14,6 +14,8 @@ mod mysql;
 mod postgres;
 #[cfg(feature = "sqlite")]
 mod sqlite;
+#[cfg(feature = "trino")]
+mod trino;
 
 fn container_registry() -> String {
     std::env::var("CONTAINER_REGISTRY")
diff --git a/tests/trino/common.rs b/tests/trino/common.rs
new file mode 100644
index 00000000..adedfef2
--- /dev/null
+++ b/tests/trino/common.rs
@@ -0,0 +1,207 @@
+use crate::docker::{ContainerRunnerBuilder, RunningContainer};
+use bollard::secret::HealthConfig;
+use datafusion_table_providers::sql::db_connection_pool::trinodbpool::TrinoConnectionPool;
+use reqwest::header::HeaderMap;
+use reqwest::Client;
+use secrecy::SecretString;
+use serde_json::Value;
+use std::collections::HashMap;
+use std::time::Duration;
+use tokio::time::sleep;
+use tracing::instrument;
+
+const TRINO_DOCKER_CONTAINER: &str = "runtime-integration-test-trino";
+
+#[instrument]
+pub async fn start_trino_docker_container(port: usize) -> Result<RunningContainer, anyhow::Error> {
+    let container_name = format!("{TRINO_DOCKER_CONTAINER}-{port}");
+
+    let port = port.try_into().unwrap_or(8080);
+
+    let trino_docker_image =
+        std::env::var("TRINO_DOCKER_IMAGE").unwrap_or("trinodb/trino:latest".to_string());
+
+    let running_container = ContainerRunnerBuilder::new(container_name)
+        .image(trino_docker_image)
+        .add_port_binding(8080, port)
+        .healthcheck(HealthConfig {
+            test: Some(vec![
+                "CMD".to_string(),
+                "curl".to_string(),
+                "-f".to_string(),
+                "http://localhost:8080/v1/info".to_string(),
+            ]),
+            interval: Some(2_000_000_000),      // 2 seconds
+            timeout: Some(5_000_000_000),       // 5 seconds
+            retries: Some(15),
+            start_period: Some(30_000_000_000), // 30 seconds
+            start_interval: None,
+        })
+        .build()?
+        .run()
+        .await?;
+
+    tokio::time::sleep(std::time::Duration::from_secs(15)).await;
+    Ok(running_container)
+}
+
+pub(super) fn get_trino_params(port: usize) -> HashMap<String, SecretString> {
+    let mut params = HashMap::new();
+    params.insert(
+        "trino_host".to_string(),
+        SecretString::from("localhost".to_string()),
+    );
+    params.insert(
+        "trino_port".to_string(),
+        SecretString::from(port.to_string()),
+    );
+    params.insert(
+        "trino_catalog".to_string(),
+        SecretString::from("tpch".to_string()),
+    );
+    params.insert(
+        "trino_schema".to_string(),
+        SecretString::from("tiny".to_string()),
+    );
+    params.insert(
+        "trino_user".to_string(),
+        SecretString::from("test".to_string()),
+    );
+    params.insert(
+        "trino_sslmode".to_string(),
+        SecretString::from("disabled".to_string()),
+    );
+    params
+}
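+
+// The `trino_` prefix on every key above is stripped by `TrinoConnectionPool::new`
+// (via `remove_prefix_from_hashmap_keys`), so these map to the plain
+// `host`/`port`/`catalog`/... parameters documented on the pool.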
+
+#[instrument]
+pub(super) async fn get_trino_connection_pool(
+    port: usize,
+) -> Result<TrinoConnectionPool, anyhow::Error> {
+    let trino_pool = TrinoConnectionPool::new(get_trino_params(port))
+        .await
+        .expect("Failed to create Trino Connection Pool");
+
+    Ok(trino_pool)
+}
+
+pub struct TrinoClient {
+    reqwest_client: Client,
+    base_url: String,
+}
+
+impl TrinoClient {
+    pub fn new(port: usize) -> Self {
+        let mut headers = HeaderMap::new();
+        headers.insert("X-Trino-Catalog", "memory".parse().unwrap());
+        headers.insert("X-Trino-Schema", "default".parse().unwrap());
+        headers.insert("X-Trino-User", "test".parse().unwrap());
+
+        let client = Client::builder().default_headers(headers).build().unwrap();
+
+        Self {
+            reqwest_client: client,
+            base_url: format!("http://localhost:{port}"),
+        }
+    }
+
+    // Submits a query via the Trino REST protocol: POST to /v1/statement, then
+    // follow each `nextUri` (collecting any `data` pages) until the query reports
+    // the FINISHED state.
+    pub async fn execute(&self, query: &str) -> Result<Vec<Vec<Value>>, anyhow::Error> {
+        let url = format!("{}/v1/statement", self.base_url);
+        let response = self
+            .reqwest_client
+            .post(&url)
+            .body(query.to_string())
+            .send()
+            .await?;
+
+        if !response.status().is_success() {
+            return Err(anyhow::anyhow!(
+                "Failed to submit query: HTTP {}: {}",
+                response.status(),
+                response.text().await?
+            ));
+        }
+
+        let mut result: Value = response.json().await?;
+        let mut all_data = Vec::new();
+
+        loop {
+            let state = result["stats"]["state"].as_str().unwrap_or("");
+
+            // Extract data rows
+            if let Some(data) = result.get("data").and_then(|d| d.as_array()) {
+                for row in data {
+                    if let Some(row_array) = row.as_array() {
+                        all_data.push(row_array.clone());
+                    }
+                }
+            }
+
+            // Check if query is finished
+            if state == "FINISHED" {
+                break;
+            }
+
+            if let Some(next_uri) = result.get("nextUri").and_then(|u| u.as_str()) {
+                // Wait before polling
+                sleep(Duration::from_millis(50)).await;
+
+                let response = self.reqwest_client.get(next_uri).send().await?;
+
+                if !response.status().is_success() {
+                    let status_code = response.status().as_u16();
+                    let message = response.text().await.unwrap_or_default();
+                    return Err(anyhow::anyhow!(
+                        "Failed to submit query: HTTP {}: {}",
+                        status_code,
+                        message
+                    ));
+                }
+
+                result = response.json().await?;
+            } else {
+                // No nextUri and the query is not in the FINISHED state: the server
+                // violated the protocol, so give up instead of spinning.
+                return Err(anyhow::anyhow!(
+                    "Query not finished but no nextUri provided. State: {}",
+                    state
+                ));
+            }
+        }
+
+        Ok(all_data)
+    }
+}
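+
+// Trino can take a while to accept queries after the container starts, so the
+// helper below probes it with a trivial query and retries (up to ~15 seconds)
+// before handing the client to a test.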
+
+pub(super) async fn get_trino_client(port: usize) -> Result<TrinoClient, anyhow::Error> {
+    let client = TrinoClient::new(port);
+
+    // Test connection and setup memory catalog
+    let mut retries = 15;
+    let mut last_err = None;
+    while retries > 0 {
+        match client.execute("SELECT 1").await {
+            Ok(_) => {
+                println!("Trino client connected successfully");
+                return Ok(client);
+            }
+            Err(e) => {
+                last_err = Some(e);
+                tokio::time::sleep(std::time::Duration::from_millis(1000)).await;
+                retries -= 1;
+                println!("Trino connection test failed, retrying...");
+            }
+        }
+    }
+
+    if let Some(err) = last_err {
+        return Err(anyhow::anyhow!(
+            "Failed to connect to Trino after retries: {}",
+            err
+        ));
+    }
+
+    Ok(client)
+}
diff --git a/tests/trino/mod.rs b/tests/trino/mod.rs
new file mode 100644
index 00000000..4b5e0d7f
--- /dev/null
+++ b/tests/trino/mod.rs
@@ -0,0 +1,515 @@
+use datafusion::common::Result as DFResult;
+use datafusion::{error::DataFusionError, execution::context::SessionContext};
+use datafusion_table_providers::trino::sql_table::TrinoTable;
+use rstest::rstest;
+use std::sync::Arc;
+
+use arrow::{
+    array::*,
+    datatypes::{DataType, Field, Schema, TimeUnit},
+};
+
+use crate::docker::RunningContainer;
+
+mod common;
+
+async fn test_trino_datetime_types(port: usize) {
+    let client = common::get_trino_client(port)
+        .await
+        .expect("Trino client should be created");
+
+    let create_table_sql = r#"
+        CREATE TABLE memory.default.datetime_table (
+            timestamp_field TIMESTAMP,
+            timestamp_with_tz TIMESTAMP WITH TIME ZONE,
+            date_field DATE,
+            time_field TIME
+        )
+    "#;
+
+    client
+        .execute(create_table_sql)
+        .await
+        .expect("Table should be created");
+
+    let insert_sql = r#"
+        INSERT INTO memory.default.datetime_table VALUES
+        (TIMESTAMP '2024-09-12 10:00:00', TIMESTAMP '2024-09-12 10:00:00 UTC', DATE '2024-09-12', TIME '10:00:00')
+    "#;
+
+    client
+        .execute(insert_sql)
+        .await
+        .expect("Data should be inserted");
+
+    let schema = Arc::new(Schema::new(vec![
+        Field::new(
+            "timestamp_field",
+            DataType::Timestamp(TimeUnit::Millisecond, None),
+            true,
+        ),
+        Field::new(
+            "timestamp_with_tz",
+            DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())),
+            true,
+        ),
+        Field::new("date_field", DataType::Date32, true),
+        Field::new("time_field", DataType::Time32(TimeUnit::Millisecond), true),
+    ]));
+
+    let expected_record = RecordBatch::try_new(
+        Arc::clone(&schema),
+        vec![
+            // 2024-09-12T10:00:00Z as epoch milliseconds
+            Arc::new(TimestampMillisecondArray::from(vec![1_726_135_200_000])),
+            Arc::new(TimestampMillisecondArray::from(vec![1_726_135_200_000]).with_timezone("UTC")),
+            // Days since the Unix epoch for 2024-09-12
+            Arc::new(Date32Array::from(vec![19978])),
+            // 10:00:00 as milliseconds since midnight
+            Arc::new(Time32MillisecondArray::from(vec![36_000_000])),
+        ],
+    )
+    .expect("Failed to create arrow record batch");
+
+    arrow_trino_one_way(port, "datetime_table", expected_record).await;
+}
+
+async fn test_trino_numeric_types(port: usize) {
+    let client = common::get_trino_client(port)
+        .await
+        .expect("Trino client should be created");
+
+    let create_table_sql = r#"
+        CREATE TABLE memory.default.numeric_table (
+            int_field INTEGER,
+            bigint_field BIGINT,
+            double_field DOUBLE,
+            decimal_field DECIMAL(18,6)
+        )
+    "#;
+
+    client
+        .execute(create_table_sql)
+        .await
+        .expect("Table should be created");
+
+    let insert_sql = r#"
+        INSERT INTO memory.default.numeric_table VALUES
+        (2147483647, 9223372036854775807, 3.14159265359, DECIMAL '123.456000')
+    "#;
+
+    client
+        .execute(insert_sql)
+        .await
+        .expect("Data should be inserted");
+
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("int_field", DataType::Int32, true),
+        Field::new("bigint_field", DataType::Int64, true),
+        Field::new("double_field", DataType::Float64, true),
+        Field::new("decimal_field", DataType::Decimal128(18, 6), true),
+    ]));
+
+    let expected_record = RecordBatch::try_new(
+        Arc::clone(&schema),
+        vec![
+            Arc::new(Int32Array::from(vec![2147483647i32])),
+            Arc::new(Int64Array::from(vec![9223372036854775807i64])),
+            Arc::new(Float64Array::from(vec![3.14159265359])),
+            Arc::new(
+                Decimal128Array::from(vec![Some(123456000i128)])
+                    .with_precision_and_scale(18, 6)
+                    .unwrap(),
+            ),
+        ],
+    )
+    .expect("Failed to create arrow record batch");
+
+    arrow_trino_one_way(port, "numeric_table", expected_record).await;
+}
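+
+// Note on the DECIMAL(18,6) expectation above: `Decimal128` stores unscaled
+// integers, so 123.456000 at scale 6 is the i128 value 123.456000 * 10^6 = 123456000.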
+
+async fn test_trino_string_types(port: usize) {
+    let client = common::get_trino_client(port)
+        .await
+        .expect("Trino client should be created");
+
+    let create_table_sql = r#"
+        CREATE TABLE memory.default.string_table (
+            name VARCHAR,
+            description VARCHAR,
+            notes VARCHAR
+        )
+    "#;
+
+    client
+        .execute(create_table_sql)
+        .await
+        .expect("Table should be created");
+
+    let insert_sql = r#"
+        INSERT INTO memory.default.string_table VALUES
+        ('Alice', 'Software Engineer', NULL),
+        ('Bob', 'Data Scientist', 'Likes Trino')
+    "#;
+
+    client
+        .execute(insert_sql)
+        .await
+        .expect("Data should be inserted");
+
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("name", DataType::Utf8, true),
+        Field::new("description", DataType::Utf8, true),
+        Field::new("notes", DataType::Utf8, true),
+    ]));
+
+    let expected_record = RecordBatch::try_new(
+        Arc::clone(&schema),
+        vec![
+            Arc::new(StringArray::from(vec!["Alice", "Bob"])),
+            Arc::new(StringArray::from(vec![
+                "Software Engineer",
+                "Data Scientist",
+            ])),
+            Arc::new(StringArray::from(vec![None, Some("Likes Trino")])),
+        ],
+    )
+    .expect("Failed to create arrow record batch");
+
+    arrow_trino_one_way(port, "string_table", expected_record).await;
+}
+
+async fn test_trino_boolean_types(port: usize) {
+    let client = common::get_trino_client(port)
+        .await
+        .expect("Trino client should be created");
+
+    let create_table_sql = r#"
+        CREATE TABLE memory.default.boolean_table (
+            is_active BOOLEAN,
+            is_verified BOOLEAN,
+            is_premium BOOLEAN
+        )
+    "#;
+
+    client
+        .execute(create_table_sql)
+        .await
+        .expect("Table should be created");
+
+    let insert_sql = r#"
+        INSERT INTO memory.default.boolean_table VALUES
+        (true, false, NULL),
+        (false, true, true)
+    "#;
+
+    client
+        .execute(insert_sql)
+        .await
+        .expect("Data should be inserted");
+
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("is_active", DataType::Boolean, true),
+        Field::new("is_verified", DataType::Boolean, true),
+        Field::new("is_premium", DataType::Boolean, true),
+    ]));
+
+    let expected_record = RecordBatch::try_new(
+        Arc::clone(&schema),
+        vec![
+            Arc::new(BooleanArray::from(vec![true, false])),
+            Arc::new(BooleanArray::from(vec![false, true])),
+            Arc::new(BooleanArray::from(vec![None, Some(true)])),
+        ],
+    )
+    .expect("Failed to create arrow record batch");
+
+    arrow_trino_one_way(port, "boolean_table", expected_record).await;
+}
+
+async fn test_trino_binary_types(port: usize) {
+    let client = common::get_trino_client(port)
+        .await
+        .expect("Trino client should be created");
+
+    let create_table_sql = r#"
+        CREATE TABLE memory.default.binary_table (
+            binary_data VARBINARY,
+            file_content VARBINARY
+        )
+    "#;
+
+    client
+        .execute(create_table_sql)
+        .await
+        .expect("Table should be created");
+
+    let insert_sql = r#"
+        INSERT INTO memory.default.binary_table VALUES
+        (X'68656c6c6f20776f726c64', X'62696e6172792066696c6520636f6e74656e74')
+    "#;
+
+    client
+        .execute(insert_sql)
+        .await
+        .expect("Data should be inserted");
+
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("binary_data", DataType::Binary, true),
+        Field::new("file_content", DataType::Binary, true),
+    ]));
+
+    let expected_record = RecordBatch::try_new(
+        Arc::clone(&schema),
+        vec![
+            Arc::new(BinaryArray::from_vec(vec![b"hello world"])),
+            Arc::new(BinaryArray::from_vec(vec![b"binary file content"])),
+        ],
+    )
+    .expect("Failed to create arrow record batch");
+
+    arrow_trino_one_way(port, "binary_table", expected_record).await;
+}
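+
+// The array test below assumes Trino ARRAY(T) values arrive as Arrow `List` arrays
+// whose child is the conventional nullable "item" field (the default produced by
+// `ListBuilder`).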
+
+async fn test_trino_array_types(port: usize) {
+    let client = common::get_trino_client(port)
+        .await
+        .expect("Trino client should be created");
+
+    let create_table_sql = r#"
+        CREATE TABLE memory.default.array_table (
+            string_tags ARRAY(VARCHAR),
+            int_numbers ARRAY(INTEGER),
+            empty_array ARRAY(VARCHAR)
+        )
+    "#;
+
+    client
+        .execute(create_table_sql)
+        .await
+        .expect("Table should be created");
+
+    let insert_sql = r#"
+        INSERT INTO memory.default.array_table VALUES
+        (ARRAY['rust', 'trino', 'arrow'], ARRAY[1, 2, 3], ARRAY[]),
+        (ARRAY['python', 'sql'], ARRAY[4, 5], ARRAY[])
+    "#;
+
+    client
+        .execute(insert_sql)
+        .await
+        .expect("Data should be inserted");
+
+    let schema = Arc::new(Schema::new(vec![
+        Field::new(
+            "string_tags",
+            DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
+            true,
+        ),
+        Field::new(
+            "int_numbers",
+            DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
+            true,
+        ),
+        Field::new(
+            "empty_array",
+            DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
+            true,
+        ),
+    ]));
+
+    let mut string_tags_list = ListBuilder::new(StringBuilder::new());
+
+    string_tags_list.values().append_value("rust");
+    string_tags_list.values().append_value("trino");
+    string_tags_list.values().append_value("arrow");
+    string_tags_list.append(true);
+
+    string_tags_list.values().append_value("python");
+    string_tags_list.values().append_value("sql");
+    string_tags_list.append(true);
+
+    let string_tags_array = Arc::new(string_tags_list.finish());
+
+    let mut int_numbers_list = ListBuilder::new(Int32Builder::new());
+
+    int_numbers_list.values().append_value(1);
+    int_numbers_list.values().append_value(2);
+    int_numbers_list.values().append_value(3);
+    int_numbers_list.append(true);
+
+    int_numbers_list.values().append_value(4);
+    int_numbers_list.values().append_value(5);
+    int_numbers_list.append(true);
+
+    let int_numbers_array = Arc::new(int_numbers_list.finish());
+
+    // Two rows, each an empty (but non-null) list
+    let mut empty_array_list = ListBuilder::new(StringBuilder::new());
+
+    empty_array_list.append(true);
+    empty_array_list.append(true);
+
+    let empty_array_array = Arc::new(empty_array_list.finish());
+
+    let expected_record = RecordBatch::try_new(
+        Arc::clone(&schema),
+        vec![string_tags_array, int_numbers_array, empty_array_array],
+    )
+    .expect("Failed to create arrow record batch");
+
+    arrow_trino_one_way(port, "array_table", expected_record).await;
+}
+
+async fn test_trino_null_and_missing_fields(port: usize) {
+    let client = common::get_trino_client(port)
+        .await
+        .expect("Trino client should be created");
+
+    let create_table_sql = r#"
+        CREATE TABLE memory.default.null_fields_table (
+            name VARCHAR,
+            age INTEGER,
+            email VARCHAR,
+            phone VARCHAR
+        )
+    "#;
+
+    client
+        .execute(create_table_sql)
+        .await
+        .expect("Table should be created");
+
+    let insert_sql = r#"
+        INSERT INTO memory.default.null_fields_table VALUES
+        ('Alice', 30, 'alice@example.com', NULL),
+        ('Bob', NULL, NULL, '555-1234'),
+        ('Charlie', 25, NULL, NULL)
+    "#;
+
+    client
+        .execute(insert_sql)
+        .await
+        .expect("Data should be inserted");
+
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("name", DataType::Utf8, true),
+        Field::new("age", DataType::Int32, true),
+        Field::new("email", DataType::Utf8, true),
+        Field::new("phone", DataType::Utf8, true),
+    ]));
+
+    let expected_record = RecordBatch::try_new(
+        Arc::clone(&schema),
+        vec![
+            Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
+            Arc::new(Int32Array::from(vec![Some(30), None, Some(25)])),
+            Arc::new(StringArray::from(vec![
+                Some("alice@example.com"),
+                None,
+                None,
+            ])),
+            Arc::new(StringArray::from(vec![None, Some("555-1234"), None])),
+        ],
+    )
+    .expect("Failed to create arrow record batch");
+
+    arrow_trino_one_way(port, "null_fields_table", expected_record).await;
+}
+
+async fn arrow_trino_one_way(
+    port: usize,
+    table_name: &str,
+    expected_record: RecordBatch,
+) -> Vec<RecordBatch> {
+    tracing::debug!("Running Trino tests on {table_name}");
+
+    let ctx = SessionContext::new();
+
+    let trino_conn_pool = common::get_trino_connection_pool(port)
+        .await
+        .expect("Trino connection pool should be created");
+
+    let table = TrinoTable::new(
+        &Arc::new(trino_conn_pool),
+        format!("memory.default.{table_name}"),
+    )
+    .await
+    .expect("Table should be created");
+
+    ctx.register_table(table_name, Arc::new(table))
+        .expect("Table should be registered");
+
+    let schema_ref = expected_record.schema();
+    let expected_fields: Vec<&str> = schema_ref
+        .fields()
+        .iter()
+        .map(|f| f.name().as_str())
+        .collect();
+
+    let projection = expected_fields
+        .iter()
+        .map(|c| format!("\"{c}\""))
+        .collect::<Vec<_>>()
+        .join(", ");
+    let sql = format!("SELECT {projection} FROM {table_name}");
+
+    let df = ctx
+        .sql(&sql)
+        .await
+        .expect("DataFrame should be created from query");
+
+    let record_batches = df.collect().await.expect("RecordBatch should be collected");
+    assert_eq!(record_batches.len(), 1);
+
+    let actual_projected =
+        project_record_batch(&record_batches[0], &expected_fields).expect("Project actual");
+    let expected_projected =
+        project_record_batch(&expected_record, &expected_fields).expect("Project expected");
+
+    assert_eq!(actual_projected, expected_projected);
+
+    record_batches
+}
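+
+// Compare only the columns the expectation defines, in the expectation's order, so
+// the assertions stay independent of column ordering in the provider's output.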
&expected_fields).expect("Project expected"); + + assert_eq!(actual_projected, expected_projected); + + record_batches +} + +use datafusion::common::Result as DFResult; +use datafusion_table_providers::trino::sql_table::TrinoTable; + +fn project_record_batch(batch: &RecordBatch, columns: &[&str]) -> DFResult { + let schema = batch.schema(); + let indices: Vec = columns + .iter() + .map(|col| schema.index_of(col).expect("Column not found")) + .collect(); + let arrays = indices.iter().map(|&i| batch.column(i).clone()).collect(); + let fields = indices + .iter() + .map(|&i| schema.field(i).clone()) + .collect::>(); + let projected_schema = Arc::new(arrow::datatypes::Schema::new(fields)); + RecordBatch::try_new(projected_schema, arrays).map_err(|e| DataFusionError::ArrowError(Box::new(e), None)) +} + +async fn start_trino_container(port: usize) -> RunningContainer { + let running_container = common::start_trino_docker_container(port) + .await + .expect("Trino container to start"); + + tracing::debug!("Trino Container started"); + + running_container +} + +#[rstest] +#[test_log::test(tokio::test)] +async fn test_trino_arrow_oneway() { + let port = crate::get_random_port(); + let trino_container = start_trino_container(port).await; + + test_trino_datetime_types(port).await; + test_trino_numeric_types(port).await; + test_trino_string_types(port).await; + test_trino_boolean_types(port).await; + test_trino_binary_types(port).await; + test_trino_array_types(port).await; + test_trino_null_and_missing_fields(port).await; + + trino_container.remove().await.expect("container to stop"); +}