Skip to content

Commit 6e7d603

Browse files
committed
fix(cubesql): properly convert type from arrow to dataframe
1 parent 0604df1 commit 6e7d603

File tree

1 file changed

+231
-17
lines changed

1 file changed

+231
-17
lines changed

rust/cubesql/cubesql/src/sql/dataframe.rs

Lines changed: 231 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -409,14 +409,15 @@ pub fn arrow_to_column_type(arrow_type: DataType) -> Result<ColumnType, CubeErro
409409
DataType::Float16 | DataType::Float32 | DataType::Float64 => Ok(ColumnType::Double),
410410
DataType::Boolean => Ok(ColumnType::Boolean),
411411
DataType::List(field) => Ok(ColumnType::List(field)),
412-
DataType::Int32 | DataType::UInt32 => Ok(ColumnType::Int32),
413412
DataType::Decimal(_, _) => Ok(ColumnType::Int32),
414-
DataType::Int8
415-
| DataType::Int16
416-
| DataType::Int64
417-
| DataType::UInt8
413+
DataType::Int8 //we are missing TableValue::Int8 type to use ColumnType:Int8
414+
| DataType::UInt8 //we are missing ColumnType::Int16 type
415+
| DataType::Int16 //we are missing ColumnType::Int16 type
418416
| DataType::UInt16
419-
| DataType::UInt64 => Ok(ColumnType::Int64),
417+
| DataType::Int32 => Ok(ColumnType::Int32),
418+
DataType::UInt32
419+
| DataType::Int64 => Ok(ColumnType::Int64),
420+
DataType::UInt64 => Ok(ColumnType::Decimal(128, 0)),
420421
DataType::Null => Ok(ColumnType::String),
421422
x => Err(CubeError::internal(format!("unsupported type {:?}", x))),
422423
}
@@ -452,12 +453,23 @@ pub fn batches_to_dataframe(
452453
let array = batch.column(column_index);
453454
let num_rows = batch.num_rows();
454455
match array.data_type() {
455-
DataType::UInt16 => convert_array!(array, num_rows, rows, UInt16Array, Int16, i16),
456+
DataType::Int8 => convert_array!(array, num_rows, rows, Int8Array, Int16, i16),
457+
DataType::UInt8 => convert_array!(array, num_rows, rows, UInt8Array, Int16, i16),
456458
DataType::Int16 => convert_array!(array, num_rows, rows, Int16Array, Int16, i16),
457-
DataType::UInt32 => convert_array!(array, num_rows, rows, UInt32Array, Int32, i32),
459+
DataType::UInt16 => convert_array!(array, num_rows, rows, UInt16Array, Int32, i32),
458460
DataType::Int32 => convert_array!(array, num_rows, rows, Int32Array, Int32, i32),
459-
DataType::UInt64 => convert_array!(array, num_rows, rows, UInt64Array, Int64, i64),
461+
DataType::UInt32 => convert_array!(array, num_rows, rows, UInt32Array, Int64, i64),
460462
DataType::Int64 => convert_array!(array, num_rows, rows, Int64Array, Int64, i64),
463+
DataType::UInt64 => {
464+
let a = array.as_any().downcast_ref::<UInt64Array>().unwrap();
465+
for i in 0..num_rows {
466+
rows[i].push(if a.is_null(i) {
467+
TableValue::Null
468+
} else {
469+
TableValue::Decimal128(Decimal128Value::new(a.value(i) as i128, 0))
470+
});
471+
}
472+
}
461473
DataType::Boolean => {
462474
convert_array!(array, num_rows, rows, BooleanArray, Boolean, bool)
463475
}
@@ -685,7 +697,20 @@ pub fn batches_to_dataframe(
685697

686698
#[cfg(test)]
687699
mod tests {
700+
use std::sync::Arc;
701+
702+
use datafusion::{
703+
arrow::{array::PrimitiveArray, datatypes::TimestampMicrosecondType},
704+
scalar::ScalarValue::TimestampMicrosecond,
705+
};
706+
use itertools::Itertools;
707+
688708
use super::*;
709+
use crate::compile::arrow::{
710+
datatypes::{ArrowPrimitiveType, Field, TimestampMillisecondType, TimestampNanosecondType},
711+
error::ArrowError,
712+
record_batch::RecordBatchOptions,
713+
};
689714

690715
#[test]
691716
fn test_dataframe_print() {
@@ -803,20 +828,26 @@ mod tests {
803828
(DataType::LargeUtf8, ColumnType::String),
804829
(DataType::Date32, ColumnType::Date(false)),
805830
(DataType::Date64, ColumnType::Date(true)),
806-
(DataType::Timestamp(TimeUnit::Second, None), ColumnType::Timestamp),
807-
(DataType::Interval(IntervalUnit::YearMonth), ColumnType::Interval(IntervalUnit::YearMonth)),
831+
(
832+
DataType::Timestamp(TimeUnit::Second, None),
833+
ColumnType::Timestamp,
834+
),
835+
(
836+
DataType::Interval(IntervalUnit::YearMonth),
837+
ColumnType::Interval(IntervalUnit::YearMonth),
838+
),
808839
(DataType::Float16, ColumnType::Double),
809840
(DataType::Float32, ColumnType::Double),
810841
(DataType::Float64, ColumnType::Double),
811842
(DataType::Boolean, ColumnType::Boolean),
812843
(DataType::Int32, ColumnType::Int32),
813-
(DataType::UInt32, ColumnType::Int32),
814-
(DataType::Int8, ColumnType::Int64),
815-
(DataType::Int16, ColumnType::Int64),
844+
(DataType::UInt32, ColumnType::Int64),
845+
(DataType::Int8, ColumnType::Int32),
846+
(DataType::Int16, ColumnType::Int32),
816847
(DataType::Int64, ColumnType::Int64),
817-
(DataType::UInt8, ColumnType::Int64),
818-
(DataType::UInt16, ColumnType::Int64),
819-
(DataType::UInt64, ColumnType::Int64),
848+
(DataType::UInt8, ColumnType::Int32),
849+
(DataType::UInt16, ColumnType::Int32),
850+
(DataType::UInt64, ColumnType::Decimal(128, 0)),
820851
(DataType::Null, ColumnType::String),
821852
];
822853

@@ -825,4 +856,187 @@ mod tests {
825856
assert_eq!(result, expected_column_type, "Failed for {:?}", arrow_type);
826857
}
827858
}
859+
860+
fn create_record_batch<T: ArrowPrimitiveType>(
861+
data_type: DataType,
862+
value: PrimitiveArray<T>,
863+
expected_data_type: ColumnType,
864+
expected_data: Vec<TableValue>,
865+
) -> Result<(), CubeError> {
866+
let batch = RecordBatch::try_new_with_options(
867+
Arc::new(Schema::new(vec![Field::new("data", data_type, false)])),
868+
vec![Arc::new(value)],
869+
&RecordBatchOptions::default(),
870+
)
871+
.map_err(|e| CubeError::from(e))?;
872+
873+
let df = batches_to_dataframe(&batch.schema(), vec![batch.clone()])?;
874+
let colums = df.get_columns().clone();
875+
let data = df.data;
876+
assert_eq!(
877+
colums.len(),
878+
1,
879+
"Expecting one column in DF, but: {:?}",
880+
colums
881+
);
882+
assert_eq!(expected_data_type, colums.get(0).unwrap().column_type);
883+
assert_eq!(
884+
data.len(),
885+
expected_data.len(),
886+
"Expecting {} columns in DF data, but: {:?}",
887+
expected_data.len(),
888+
data
889+
);
890+
let vec1 = data.into_iter().map(|r| r.values).flatten().collect_vec();
891+
assert_eq!(
892+
vec1.len(),
893+
expected_data.len(),
894+
"Data len {} != {}",
895+
vec1.len(),
896+
expected_data.len()
897+
);
898+
assert_eq!(vec1, expected_data);
899+
Ok(())
900+
}
901+
902+
#[test]
903+
fn test_timestamp_conversion() -> Result<(), CubeError> {
904+
let data_nano = vec![TimestampNanosecondType::default_value()];
905+
create_record_batch(
906+
DataType::Timestamp(TimeUnit::Nanosecond, None),
907+
TimestampNanosecondArray::from(data_nano.clone()),
908+
ColumnType::Timestamp, //here we are missing TableValue::Int8 to use ColumnType::Int32
909+
data_nano
910+
.into_iter()
911+
.map(|e| TableValue::Timestamp(TimestampValue::new(e, None)))
912+
.collect::<Vec<_>>(),
913+
)?;
914+
915+
let data_micro = vec![TimestampMicrosecondType::default_value()];
916+
create_record_batch(
917+
DataType::Timestamp(TimeUnit::Microsecond, None),
918+
TimestampMicrosecondArray::from(data_micro.clone()),
919+
ColumnType::Timestamp, //here we are missing TableValue::Int8 to use ColumnType::Int32
920+
data_micro
921+
.into_iter()
922+
.map(|e| TableValue::Timestamp(TimestampValue::new(e * 1000, None)))
923+
.collect::<Vec<_>>(),
924+
)
925+
}
926+
927+
#[test]
928+
fn test_signed_conversion() -> Result<(), CubeError> {
929+
let data8 = vec![i8::MIN, -1, 0, 1, 2, 3, 4, i8::MAX];
930+
create_record_batch(
931+
DataType::Int8,
932+
Int8Array::from(data8.clone()),
933+
ColumnType::Int32, //here we are missing TableValue::Int8 to use ColumnType::Int32
934+
data8
935+
.into_iter()
936+
.map(|e| TableValue::Int16(e as i16))
937+
.collect::<Vec<_>>(),
938+
)?;
939+
940+
let data16 = vec![i16::MIN, -1, 0, 1, 2, 3, 4, i16::MAX];
941+
create_record_batch(
942+
DataType::Int16,
943+
Int16Array::from(data16.clone()),
944+
ColumnType::Int32, //here we are missing ColumnType::Int16
945+
data16
946+
.into_iter()
947+
.map(|e| TableValue::Int16(e))
948+
.collect::<Vec<_>>(),
949+
)?;
950+
951+
let data32 = vec![i32::MIN, -1, 0, 1, 2, 3, 4, i32::MAX];
952+
create_record_batch(
953+
DataType::Int32,
954+
Int32Array::from(data32.clone()),
955+
ColumnType::Int32,
956+
data32
957+
.into_iter()
958+
.map(|e| TableValue::Int32(e))
959+
.collect::<Vec<_>>(),
960+
)?;
961+
962+
let data64 = vec![i64::MIN, -1, 0, 1, 2, 3, 4, i64::MAX];
963+
create_record_batch(
964+
DataType::Int64,
965+
Int64Array::from(data64.clone()),
966+
ColumnType::Int64,
967+
data64
968+
.into_iter()
969+
.map(|e| TableValue::Int64(e))
970+
.collect::<Vec<_>>(),
971+
)
972+
}
973+
974+
#[test]
975+
fn test_unsigned_conversion() -> Result<(), CubeError> {
976+
let data8 = vec![0, 1, 2, 3, 4, u8::MAX];
977+
create_record_batch(
978+
DataType::UInt8,
979+
UInt8Array::from(data8.clone()),
980+
ColumnType::Int32, //here we are missing ColumnType::Int16
981+
data8
982+
.into_iter()
983+
.map(|e| TableValue::Int16(e as i16))
984+
.collect::<Vec<_>>(),
985+
)?;
986+
987+
let data16 = vec![0, 1, 2, 3, 4, u16::MAX];
988+
create_record_batch(
989+
DataType::UInt16,
990+
UInt16Array::from(data16.clone()),
991+
ColumnType::Int32,
992+
data16
993+
.into_iter()
994+
.map(|e| TableValue::Int32(e as i32))
995+
.collect::<Vec<_>>(),
996+
)?;
997+
998+
let data32 = vec![0, 1, 2, 3, 4, u32::MAX];
999+
create_record_batch(
1000+
DataType::UInt32,
1001+
UInt32Array::from(data32.clone()),
1002+
ColumnType::Int64,
1003+
data32
1004+
.into_iter()
1005+
.map(|e| TableValue::Int64(e as i64))
1006+
.collect::<Vec<_>>(),
1007+
)?;
1008+
1009+
// TODO: uncomment u64::MAX when we will support u64 properly
1010+
let data64 = vec![0, 1, 2, 3, 4 /*, u64::MAX*/];
1011+
create_record_batch(
1012+
DataType::UInt64,
1013+
UInt64Array::from(data64.clone()),
1014+
ColumnType::Decimal(128, 0),
1015+
data64
1016+
.into_iter()
1017+
.map(|e| TableValue::Decimal128(Decimal128Value::new(e as i128, 0)))
1018+
.collect::<Vec<_>>(),
1019+
)
1020+
}
1021+
1022+
impl PartialEq for TableValue {
1023+
fn eq(&self, other: &Self) -> bool {
1024+
match (self, other) {
1025+
(TableValue::Null, TableValue::Null) => true,
1026+
(TableValue::String(a), TableValue::String(b)) => a == b,
1027+
(TableValue::Int16(a), TableValue::Int16(b)) => a == b,
1028+
(TableValue::Int32(a), TableValue::Int32(b)) => a == b,
1029+
(TableValue::Int64(a), TableValue::Int64(b)) => a == b,
1030+
(TableValue::Boolean(a), TableValue::Boolean(b)) => a == b,
1031+
(TableValue::Float32(a), TableValue::Float32(b)) => a == b,
1032+
(TableValue::Float64(a), TableValue::Float64(b)) => a == b,
1033+
(TableValue::List(a), TableValue::List(b)) => panic!("unsupported"),
1034+
(TableValue::Decimal128(a), TableValue::Decimal128(b)) => a == b,
1035+
(TableValue::Date(a), TableValue::Date(b)) => a == b,
1036+
(TableValue::Timestamp(a), TableValue::Timestamp(b)) => a == b,
1037+
(TableValue::Interval(a), TableValue::Interval(b)) => panic!("unsupported"),
1038+
_ => false,
1039+
}
1040+
}
1041+
}
8281042
}

0 commit comments

Comments
 (0)