Skip to content

Commit f85893f

Browse files
authored
Support more time types to arrow vtab (#289)
* add more time types to arrow vtab * clippy * properly support non-tz timestamps * dont compare timezones
1 parent b82db39 commit f85893f

File tree

2 files changed

+207
-17
lines changed

2 files changed

+207
-17
lines changed

src/vtab/arrow.rs

Lines changed: 204 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ use super::{
66
use crate::vtab::vector::Inserter;
77
use arrow::array::{
88
as_boolean_array, as_large_list_array, as_list_array, as_primitive_array, as_string_array, Array, ArrayData,
9-
BooleanArray, Decimal128Array, FixedSizeListArray, GenericListArray, OffsetSizeTrait, PrimitiveArray, StringArray,
10-
StructArray,
9+
AsArray, BooleanArray, Decimal128Array, FixedSizeListArray, GenericListArray, OffsetSizeTrait, PrimitiveArray,
10+
StringArray, StructArray,
1111
};
1212

1313
use arrow::{
@@ -138,9 +138,15 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result<LogicalTypeId, Box<dyn
138138
DataType::UInt64 => UBigint,
139139
DataType::Float32 => Float,
140140
DataType::Float64 => Double,
141-
DataType::Timestamp(_, _) => Timestamp,
142-
DataType::Date32 => Time,
143-
DataType::Date64 => Time,
141+
DataType::Timestamp(unit, None) => match unit {
142+
TimeUnit::Second => TimestampS,
143+
TimeUnit::Millisecond => TimestampMs,
144+
TimeUnit::Microsecond => Timestamp,
145+
TimeUnit::Nanosecond => TimestampNs,
146+
},
147+
DataType::Timestamp(_, Some(_)) => TimestampTZ,
148+
DataType::Date32 => Date,
149+
DataType::Date64 => Date,
144150
DataType::Time32(_) => Time,
145151
DataType::Time64(_) => Time,
146152
DataType::Duration(_) => Interval,
@@ -250,6 +256,16 @@ fn primitive_array_to_flat_vector<T: ArrowPrimitiveType>(array: &PrimitiveArray<
250256
out_vector.copy::<T::Native>(array.values());
251257
}
252258

259+
fn primitive_array_to_flat_vector_cast<T: ArrowPrimitiveType>(
260+
data_type: DataType,
261+
array: &dyn Array,
262+
out_vector: &mut dyn Vector,
263+
) {
264+
let array = arrow::compute::kernels::cast::cast(array, &data_type).unwrap();
265+
let out_vector: &mut FlatVector = out_vector.as_mut_any().downcast_mut().unwrap();
266+
out_vector.copy::<T::Native>(array.as_primitive::<T>().values());
267+
}
268+
253269
fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) {
254270
match array.data_type() {
255271
DataType::Boolean => {
@@ -303,6 +319,7 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) {
303319
out.as_mut_any().downcast_mut().unwrap(),
304320
);
305321
}
322+
DataType::Float16 => todo!("Float16 is not supported yet"),
306323
DataType::Float32 => {
307324
primitive_array_to_flat_vector::<Float32Type>(
308325
as_primitive_array(array),
@@ -324,22 +341,55 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) {
324341
out.as_mut_any().downcast_mut().unwrap(),
325342
);
326343
}
327-
// DataType::Decimal256(_, _) => {
328-
// primitive_array_to_flat_vector::<Decimal256Type>(
329-
// as_primitive_array(array),
330-
// out.as_mut_any().downcast_mut().unwrap(),
331-
// );
332-
// }
333-
_ => {
334-
todo!()
344+
DataType::Decimal256(_, _) => todo!("Decimal256 is not supported yet"),
345+
346+
// DuckDB Only supports timetamp_tz in microsecond precision
347+
DataType::Timestamp(_, Some(tz)) => primitive_array_to_flat_vector_cast::<TimestampMicrosecondType>(
348+
DataType::Timestamp(TimeUnit::Microsecond, Some(tz.clone())),
349+
array,
350+
out,
351+
),
352+
DataType::Timestamp(unit, None) => match unit {
353+
TimeUnit::Second => primitive_array_to_flat_vector::<TimestampSecondType>(
354+
as_primitive_array(array),
355+
out.as_mut_any().downcast_mut().unwrap(),
356+
),
357+
TimeUnit::Millisecond => primitive_array_to_flat_vector::<TimestampMillisecondType>(
358+
as_primitive_array(array),
359+
out.as_mut_any().downcast_mut().unwrap(),
360+
),
361+
TimeUnit::Microsecond => primitive_array_to_flat_vector::<TimestampMicrosecondType>(
362+
as_primitive_array(array),
363+
out.as_mut_any().downcast_mut().unwrap(),
364+
),
365+
TimeUnit::Nanosecond => primitive_array_to_flat_vector::<TimestampNanosecondType>(
366+
as_primitive_array(array),
367+
out.as_mut_any().downcast_mut().unwrap(),
368+
),
369+
},
370+
DataType::Date32 => {
371+
primitive_array_to_flat_vector::<Date32Type>(
372+
as_primitive_array(array),
373+
out.as_mut_any().downcast_mut().unwrap(),
374+
);
375+
}
376+
DataType::Date64 => primitive_array_to_flat_vector_cast::<Date32Type>(Date32Type::DATA_TYPE, array, out),
377+
DataType::Time32(_) => {
378+
primitive_array_to_flat_vector_cast::<Time64MicrosecondType>(Time64MicrosecondType::DATA_TYPE, array, out)
335379
}
380+
DataType::Time64(_) => {
381+
primitive_array_to_flat_vector_cast::<Time64MicrosecondType>(Time64MicrosecondType::DATA_TYPE, array, out)
382+
}
383+
_ => todo!(
384+
"Converting '{dtype:#?}' to primitive flat vector is not supported",
385+
dtype = array.data_type()
386+
),
336387
}
337388
}
338389

339-
/// Convert Arrow [BooleanArray] to a duckdb vector.
390+
/// Convert Arrow [Decimal128Array] to a duckdb vector.
340391
fn decimal_array_to_vector(array: &Decimal128Array, out: &mut FlatVector) {
341392
assert!(array.len() <= out.capacity());
342-
343393
for i in 0..array.len() {
344394
out.as_mut_slice()[i] = array.value_as_string(i).parse::<f64>().unwrap();
345395
}
@@ -488,8 +538,12 @@ mod test {
488538
use super::{arrow_recordbatch_to_query_params, ArrowVTab};
489539
use crate::{Connection, Result};
490540
use arrow::{
491-
array::{Float64Array, Int32Array},
492-
datatypes::{DataType, Field, Schema},
541+
array::{
542+
Array, AsArray, Date32Array, Date64Array, Float64Array, Int32Array, PrimitiveArray, StringArray,
543+
Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
544+
TimestampNanosecondArray, TimestampSecondArray,
545+
},
546+
datatypes::{ArrowPrimitiveType, DataType, Field, Schema},
493547
record_batch::RecordBatch,
494548
};
495549
use std::{error::Error, sync::Arc};
@@ -534,4 +588,137 @@ mod test {
534588
assert_eq!(column.value(0), 15);
535589
Ok(())
536590
}
591+
592+
fn check_rust_primitive_array_roundtrip<T1, T2>(
593+
input_array: PrimitiveArray<T1>,
594+
expected_array: PrimitiveArray<T2>,
595+
) -> Result<(), Box<dyn Error>>
596+
where
597+
T1: ArrowPrimitiveType,
598+
T2: ArrowPrimitiveType,
599+
{
600+
let db = Connection::open_in_memory()?;
601+
db.register_table_function::<ArrowVTab>("arrow")?;
602+
603+
// Roundtrip a record batch from Rust to DuckDB and back to Rust
604+
let schema = Schema::new(vec![Field::new("a", input_array.data_type().clone(), false)]);
605+
606+
let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(input_array.clone())])?;
607+
let param = arrow_recordbatch_to_query_params(rb);
608+
let mut stmt = db.prepare("select a from arrow(?, ?)")?;
609+
let rb = stmt.query_arrow(param)?.next().expect("no record batch");
610+
611+
let output_any_array = rb.column(0);
612+
match (output_any_array.data_type(), expected_array.data_type()) {
613+
// TODO: DuckDB doesnt return timestamp_tz properly yet, so we just check that the units are the same
614+
(DataType::Timestamp(unit_a, _), DataType::Timestamp(unit_b, _)) => assert_eq!(unit_a, unit_b),
615+
(a, b) => assert_eq!(a, b),
616+
}
617+
618+
let maybe_output_array = output_any_array.as_primitive_opt::<T2>();
619+
620+
match maybe_output_array {
621+
Some(output_array) => {
622+
// Check that the output array is the same as the input array
623+
assert_eq!(output_array.len(), expected_array.len());
624+
for i in 0..output_array.len() {
625+
assert_eq!(output_array.is_valid(i), expected_array.is_valid(i));
626+
if output_array.is_valid(i) {
627+
assert_eq!(output_array.value(i), expected_array.value(i));
628+
}
629+
}
630+
}
631+
None => {
632+
panic!("Output array is not a PrimitiveArray {:?}", rb.column(0).data_type());
633+
}
634+
}
635+
636+
Ok(())
637+
}
638+
639+
#[test]
640+
fn test_timestamp_roundtrip() -> Result<(), Box<dyn Error>> {
641+
check_rust_primitive_array_roundtrip(Int32Array::from(vec![1, 2, 3]), Int32Array::from(vec![1, 2, 3]))?;
642+
643+
check_rust_primitive_array_roundtrip(
644+
TimestampMicrosecondArray::from(vec![1, 2, 3]),
645+
TimestampMicrosecondArray::from(vec![1, 2, 3]),
646+
)?;
647+
648+
check_rust_primitive_array_roundtrip(
649+
TimestampNanosecondArray::from(vec![1, 2, 3]),
650+
TimestampNanosecondArray::from(vec![1, 2, 3]),
651+
)?;
652+
653+
check_rust_primitive_array_roundtrip(
654+
TimestampSecondArray::from(vec![1, 2, 3]),
655+
TimestampSecondArray::from(vec![1, 2, 3]),
656+
)?;
657+
658+
check_rust_primitive_array_roundtrip(
659+
TimestampMillisecondArray::from(vec![1, 2, 3]),
660+
TimestampMillisecondArray::from(vec![1, 2, 3]),
661+
)?;
662+
663+
// DuckDB can only return timestamp_tz in microseconds
664+
// Note: DuckDB by default returns timestamp_tz with UTC because the rust
665+
// driver doesnt support timestamp_tz properly when reading. In the
666+
// future we should be able to roundtrip timestamp_tz with other timezones too
667+
check_rust_primitive_array_roundtrip(
668+
TimestampNanosecondArray::from(vec![1000, 2000, 3000]).with_timezone_utc(),
669+
TimestampMicrosecondArray::from(vec![1, 2, 3]).with_timezone_utc(),
670+
)?;
671+
672+
check_rust_primitive_array_roundtrip(
673+
TimestampMillisecondArray::from(vec![1, 2, 3]).with_timezone_utc(),
674+
TimestampMicrosecondArray::from(vec![1000, 2000, 3000]).with_timezone_utc(),
675+
)?;
676+
677+
check_rust_primitive_array_roundtrip(
678+
TimestampSecondArray::from(vec![1, 2, 3]).with_timezone_utc(),
679+
TimestampMicrosecondArray::from(vec![1_000_000, 2_000_000, 3_000_000]).with_timezone_utc(),
680+
)?;
681+
682+
check_rust_primitive_array_roundtrip(
683+
TimestampMicrosecondArray::from(vec![1, 2, 3]).with_timezone_utc(),
684+
TimestampMicrosecondArray::from(vec![1, 2, 3]).with_timezone_utc(),
685+
)?;
686+
687+
check_rust_primitive_array_roundtrip(Date32Array::from(vec![1, 2, 3]), Date32Array::from(vec![1, 2, 3]))?;
688+
689+
let mid = arrow::temporal_conversions::MILLISECONDS_IN_DAY;
690+
check_rust_primitive_array_roundtrip(
691+
Date64Array::from(vec![mid, 2 * mid, 3 * mid]),
692+
Date32Array::from(vec![1, 2, 3]),
693+
)?;
694+
695+
check_rust_primitive_array_roundtrip(
696+
Time32SecondArray::from(vec![1, 2, 3]),
697+
Time64MicrosecondArray::from(vec![1_000_000, 2_000_000, 3_000_000]),
698+
)?;
699+
700+
Ok(())
701+
}
702+
703+
#[test]
704+
fn test_timestamp_tz_insert() -> Result<(), Box<dyn Error>> {
705+
// TODO: This test should be reworked once we support TIMESTAMP_TZ properly
706+
707+
let db = Connection::open_in_memory()?;
708+
db.register_table_function::<ArrowVTab>("arrow")?;
709+
710+
let array = TimestampMicrosecondArray::from(vec![1]).with_timezone("+05:00");
711+
let schema = Schema::new(vec![Field::new("a", array.data_type().clone(), false)]);
712+
713+
// Since we cant get TIMESTAMP_TZ from the rust client yet, we just check that we can insert it properly here.
714+
let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).expect("failed to create record batch");
715+
let param = arrow_recordbatch_to_query_params(rb);
716+
let mut stmt = db.prepare("select typeof(a)::VARCHAR from arrow(?, ?)")?;
717+
let mut arr = stmt.query_arrow(param)?;
718+
let rb = arr.next().expect("no record batch");
719+
assert_eq!(rb.num_columns(), 1);
720+
let column = rb.column(0).as_any().downcast_ref::<StringArray>().unwrap();
721+
assert_eq!(column.value(0), "TIMESTAMP WITH TIME ZONE");
722+
Ok(())
723+
}
537724
}

src/vtab/logical_type.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ pub enum LogicalTypeId {
6666
Uuid = DUCKDB_TYPE_DUCKDB_TYPE_UUID,
6767
/// Union
6868
Union = DUCKDB_TYPE_DUCKDB_TYPE_UNION,
69+
/// Timestamp TZ
70+
TimestampTZ = DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP_TZ,
6971
}
7072

7173
impl From<u32> for LogicalTypeId {
@@ -100,6 +102,7 @@ impl From<u32> for LogicalTypeId {
100102
DUCKDB_TYPE_DUCKDB_TYPE_MAP => Self::Map,
101103
DUCKDB_TYPE_DUCKDB_TYPE_UUID => Self::Uuid,
102104
DUCKDB_TYPE_DUCKDB_TYPE_UNION => Self::Union,
105+
DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP_TZ => Self::TimestampTZ,
103106
_ => panic!(),
104107
}
105108
}

0 commit comments

Comments
 (0)