Skip to content

Commit 3c9308b

Browse files
committed
fix(cubesql): properly convert type from arrow to dataframe
1 parent 0604df1 commit 3c9308b

File tree

1 file changed

+226
-17
lines changed

1 file changed

+226
-17
lines changed

rust/cubesql/cubesql/src/sql/dataframe.rs

Lines changed: 226 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -409,14 +409,15 @@ pub fn arrow_to_column_type(arrow_type: DataType) -> Result<ColumnType, CubeErro
409409
DataType::Float16 | DataType::Float32 | DataType::Float64 => Ok(ColumnType::Double),
410410
DataType::Boolean => Ok(ColumnType::Boolean),
411411
DataType::List(field) => Ok(ColumnType::List(field)),
412-
DataType::Int32 | DataType::UInt32 => Ok(ColumnType::Int32),
413412
DataType::Decimal(_, _) => Ok(ColumnType::Int32),
414-
DataType::Int8
415-
| DataType::Int16
416-
| DataType::Int64
417-
| DataType::UInt8
413+
DataType::Int8 //we are missing TableValue::Int8 type to use ColumnType:Int8
414+
| DataType::UInt8 //we are missing ColumnType::Int16 type
415+
| DataType::Int16 //we are missing ColumnType::Int16 type
418416
| DataType::UInt16
419-
| DataType::UInt64 => Ok(ColumnType::Int64),
417+
| DataType::Int32 => Ok(ColumnType::Int32),
418+
DataType::UInt32
419+
| DataType::Int64 => Ok(ColumnType::Int64),
420+
DataType::UInt64 => Ok(ColumnType::Decimal(128, 0)),
420421
DataType::Null => Ok(ColumnType::String),
421422
x => Err(CubeError::internal(format!("unsupported type {:?}", x))),
422423
}
@@ -452,12 +453,23 @@ pub fn batches_to_dataframe(
452453
let array = batch.column(column_index);
453454
let num_rows = batch.num_rows();
454455
match array.data_type() {
455-
DataType::UInt16 => convert_array!(array, num_rows, rows, UInt16Array, Int16, i16),
456+
DataType::Int8 => convert_array!(array, num_rows, rows, Int8Array, Int16, i16),
457+
DataType::UInt8 => convert_array!(array, num_rows, rows, UInt8Array, Int16, i16),
456458
DataType::Int16 => convert_array!(array, num_rows, rows, Int16Array, Int16, i16),
457-
DataType::UInt32 => convert_array!(array, num_rows, rows, UInt32Array, Int32, i32),
459+
DataType::UInt16 => convert_array!(array, num_rows, rows, UInt16Array, Int32, i32),
458460
DataType::Int32 => convert_array!(array, num_rows, rows, Int32Array, Int32, i32),
459-
DataType::UInt64 => convert_array!(array, num_rows, rows, UInt64Array, Int64, i64),
461+
DataType::UInt32 => convert_array!(array, num_rows, rows, UInt32Array, Int64, i64),
460462
DataType::Int64 => convert_array!(array, num_rows, rows, Int64Array, Int64, i64),
463+
DataType::UInt64 => {
464+
let a = array.as_any().downcast_ref::<UInt64Array>().unwrap();
465+
for i in 0..num_rows {
466+
rows[i].push(if a.is_null(i) {
467+
TableValue::Null
468+
} else {
469+
TableValue::Decimal128(Decimal128Value::new(a.value(i) as i128, 0))
470+
});
471+
}
472+
}
461473
DataType::Boolean => {
462474
convert_array!(array, num_rows, rows, BooleanArray, Boolean, bool)
463475
}
@@ -685,7 +697,16 @@ pub fn batches_to_dataframe(
685697

686698
#[cfg(test)]
687699
mod tests {
700+
use std::sync::Arc;
701+
702+
use datafusion::arrow::{array::PrimitiveArray, datatypes::TimestampMicrosecondType};
703+
use itertools::Itertools;
704+
688705
use super::*;
706+
use crate::compile::arrow::{
707+
datatypes::{ArrowPrimitiveType, Field, TimestampNanosecondType},
708+
record_batch::RecordBatchOptions,
709+
};
689710

690711
#[test]
691712
fn test_dataframe_print() {
@@ -803,20 +824,26 @@ mod tests {
803824
(DataType::LargeUtf8, ColumnType::String),
804825
(DataType::Date32, ColumnType::Date(false)),
805826
(DataType::Date64, ColumnType::Date(true)),
806-
(DataType::Timestamp(TimeUnit::Second, None), ColumnType::Timestamp),
807-
(DataType::Interval(IntervalUnit::YearMonth), ColumnType::Interval(IntervalUnit::YearMonth)),
827+
(
828+
DataType::Timestamp(TimeUnit::Second, None),
829+
ColumnType::Timestamp,
830+
),
831+
(
832+
DataType::Interval(IntervalUnit::YearMonth),
833+
ColumnType::Interval(IntervalUnit::YearMonth),
834+
),
808835
(DataType::Float16, ColumnType::Double),
809836
(DataType::Float32, ColumnType::Double),
810837
(DataType::Float64, ColumnType::Double),
811838
(DataType::Boolean, ColumnType::Boolean),
812839
(DataType::Int32, ColumnType::Int32),
813-
(DataType::UInt32, ColumnType::Int32),
814-
(DataType::Int8, ColumnType::Int64),
815-
(DataType::Int16, ColumnType::Int64),
840+
(DataType::UInt32, ColumnType::Int64),
841+
(DataType::Int8, ColumnType::Int32),
842+
(DataType::Int16, ColumnType::Int32),
816843
(DataType::Int64, ColumnType::Int64),
817-
(DataType::UInt8, ColumnType::Int64),
818-
(DataType::UInt16, ColumnType::Int64),
819-
(DataType::UInt64, ColumnType::Int64),
844+
(DataType::UInt8, ColumnType::Int32),
845+
(DataType::UInt16, ColumnType::Int32),
846+
(DataType::UInt64, ColumnType::Decimal(128, 0)),
820847
(DataType::Null, ColumnType::String),
821848
];
822849

@@ -825,4 +852,186 @@ mod tests {
825852
assert_eq!(result, expected_column_type, "Failed for {:?}", arrow_type);
826853
}
827854
}
855+
856+
fn create_record_batch<T: ArrowPrimitiveType>(
857+
data_type: DataType,
858+
value: PrimitiveArray<T>,
859+
expected_data_type: ColumnType,
860+
expected_data: Vec<TableValue>,
861+
) -> Result<(), CubeError> {
862+
let batch = RecordBatch::try_new_with_options(
863+
Arc::new(Schema::new(vec![Field::new("data", data_type, false)])),
864+
vec![Arc::new(value)],
865+
&RecordBatchOptions::default(),
866+
)
867+
.map_err(|e| CubeError::from(e))?;
868+
869+
let df = batches_to_dataframe(&batch.schema(), vec![batch.clone()])?;
870+
let colums = df.get_columns().clone();
871+
let data = df.data;
872+
assert_eq!(
873+
colums.len(),
874+
1,
875+
"Expecting one column in DF, but: {:?}",
876+
colums
877+
);
878+
assert_eq!(expected_data_type, colums.get(0).unwrap().column_type);
879+
assert_eq!(
880+
data.len(),
881+
expected_data.len(),
882+
"Expecting {} columns in DF data, but: {:?}",
883+
expected_data.len(),
884+
data
885+
);
886+
let vec1 = data.into_iter().map(|r| r.values).flatten().collect_vec();
887+
assert_eq!(
888+
vec1.len(),
889+
expected_data.len(),
890+
"Data len {} != {}",
891+
vec1.len(),
892+
expected_data.len()
893+
);
894+
assert_eq!(vec1, expected_data);
895+
Ok(())
896+
}
897+
898+
#[test]
899+
fn test_timestamp_conversion() -> Result<(), CubeError> {
900+
let data_nano = vec![TimestampNanosecondType::default_value()];
901+
create_record_batch(
902+
DataType::Timestamp(TimeUnit::Nanosecond, None),
903+
TimestampNanosecondArray::from(data_nano.clone()),
904+
ColumnType::Timestamp, //here we are missing TableValue::Int8 to use ColumnType::Int32
905+
data_nano
906+
.into_iter()
907+
.map(|e| TableValue::Timestamp(TimestampValue::new(e, None)))
908+
.collect::<Vec<_>>(),
909+
)?;
910+
911+
let data_micro = vec![TimestampMicrosecondType::default_value()];
912+
create_record_batch(
913+
DataType::Timestamp(TimeUnit::Microsecond, None),
914+
TimestampMicrosecondArray::from(data_micro.clone()),
915+
ColumnType::Timestamp, //here we are missing TableValue::Int8 to use ColumnType::Int32
916+
data_micro
917+
.into_iter()
918+
.map(|e| TableValue::Timestamp(TimestampValue::new(e * 1000, None)))
919+
.collect::<Vec<_>>(),
920+
)
921+
}
922+
923+
#[test]
924+
fn test_signed_conversion() -> Result<(), CubeError> {
925+
let data8 = vec![i8::MIN, -1, 0, 1, 2, 3, 4, i8::MAX];
926+
create_record_batch(
927+
DataType::Int8,
928+
Int8Array::from(data8.clone()),
929+
ColumnType::Int32, //here we are missing TableValue::Int8 to use ColumnType::Int32
930+
data8
931+
.into_iter()
932+
.map(|e| TableValue::Int16(e as i16))
933+
.collect::<Vec<_>>(),
934+
)?;
935+
936+
let data16 = vec![i16::MIN, -1, 0, 1, 2, 3, 4, i16::MAX];
937+
create_record_batch(
938+
DataType::Int16,
939+
Int16Array::from(data16.clone()),
940+
ColumnType::Int32, //here we are missing ColumnType::Int16
941+
data16
942+
.into_iter()
943+
.map(|e| TableValue::Int16(e))
944+
.collect::<Vec<_>>(),
945+
)?;
946+
947+
let data32 = vec![i32::MIN, -1, 0, 1, 2, 3, 4, i32::MAX];
948+
create_record_batch(
949+
DataType::Int32,
950+
Int32Array::from(data32.clone()),
951+
ColumnType::Int32,
952+
data32
953+
.into_iter()
954+
.map(|e| TableValue::Int32(e))
955+
.collect::<Vec<_>>(),
956+
)?;
957+
958+
let data64 = vec![i64::MIN, -1, 0, 1, 2, 3, 4, i64::MAX];
959+
create_record_batch(
960+
DataType::Int64,
961+
Int64Array::from(data64.clone()),
962+
ColumnType::Int64,
963+
data64
964+
.into_iter()
965+
.map(|e| TableValue::Int64(e))
966+
.collect::<Vec<_>>(),
967+
)
968+
}
969+
970+
#[test]
971+
fn test_unsigned_conversion() -> Result<(), CubeError> {
972+
let data8 = vec![0, 1, 2, 3, 4, u8::MAX];
973+
create_record_batch(
974+
DataType::UInt8,
975+
UInt8Array::from(data8.clone()),
976+
ColumnType::Int32, //here we are missing ColumnType::Int16
977+
data8
978+
.into_iter()
979+
.map(|e| TableValue::Int16(e as i16))
980+
.collect::<Vec<_>>(),
981+
)?;
982+
983+
let data16 = vec![0, 1, 2, 3, 4, u16::MAX];
984+
create_record_batch(
985+
DataType::UInt16,
986+
UInt16Array::from(data16.clone()),
987+
ColumnType::Int32,
988+
data16
989+
.into_iter()
990+
.map(|e| TableValue::Int32(e as i32))
991+
.collect::<Vec<_>>(),
992+
)?;
993+
994+
let data32 = vec![0, 1, 2, 3, 4, u32::MAX];
995+
create_record_batch(
996+
DataType::UInt32,
997+
UInt32Array::from(data32.clone()),
998+
ColumnType::Int64,
999+
data32
1000+
.into_iter()
1001+
.map(|e| TableValue::Int64(e as i64))
1002+
.collect::<Vec<_>>(),
1003+
)?;
1004+
1005+
let data64 = vec![0, 1, 2, 3, 4, u64::MAX];
1006+
create_record_batch(
1007+
DataType::UInt64,
1008+
UInt64Array::from(data64.clone()),
1009+
ColumnType::Decimal(128, 0),
1010+
data64
1011+
.into_iter()
1012+
.map(|e| TableValue::Decimal128(Decimal128Value::new(e as i128, 0)))
1013+
.collect::<Vec<_>>(),
1014+
)
1015+
}
1016+
1017+
impl PartialEq for TableValue {
1018+
fn eq(&self, other: &Self) -> bool {
1019+
match (self, other) {
1020+
(TableValue::Null, TableValue::Null) => true,
1021+
(TableValue::String(a), TableValue::String(b)) => a == b,
1022+
(TableValue::Int16(a), TableValue::Int16(b)) => a == b,
1023+
(TableValue::Int32(a), TableValue::Int32(b)) => a == b,
1024+
(TableValue::Int64(a), TableValue::Int64(b)) => a == b,
1025+
(TableValue::Boolean(a), TableValue::Boolean(b)) => a == b,
1026+
(TableValue::Float32(a), TableValue::Float32(b)) => a == b,
1027+
(TableValue::Float64(a), TableValue::Float64(b)) => a == b,
1028+
(TableValue::List(_), TableValue::List(_)) => panic!("unsupported"),
1029+
(TableValue::Decimal128(a), TableValue::Decimal128(b)) => a == b,
1030+
(TableValue::Date(a), TableValue::Date(b)) => a == b,
1031+
(TableValue::Timestamp(a), TableValue::Timestamp(b)) => a == b,
1032+
(TableValue::Interval(_), TableValue::Interval(_)) => panic!("unsupported"),
1033+
_ => false,
1034+
}
1035+
}
1036+
}
8281037
}

0 commit comments

Comments
 (0)