|
19 | 19 |
|
20 | 20 | use std::fmt::Display; |
21 | 21 | use std::hash::Hash; |
| 22 | +use std::sync::Arc; |
22 | 23 |
|
23 | 24 | use crate::type_coercion::aggregates::NUMERICS; |
24 | | -use arrow::datatypes::{DataType, Decimal128Type, DecimalType, IntervalUnit, TimeUnit}; |
| 25 | +use arrow::datatypes::{ |
| 26 | + DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, DECIMAL128_MAX_PRECISION, DataType, |
| 27 | + Decimal128Type, DecimalType, Field, IntervalUnit, TimeUnit, |
| 28 | +}; |
25 | 29 | use datafusion_common::types::{LogicalType, LogicalTypeRef, NativeType}; |
26 | 30 | use datafusion_common::utils::ListCoercion; |
27 | 31 | use datafusion_common::{Result, internal_err, plan_err}; |
@@ -328,14 +332,23 @@ impl TypeSignature { |
328 | 332 | /// arguments that can be coerced to a particular class of types. |
329 | 333 | #[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Hash)] |
330 | 334 | pub enum TypeSignatureClass { |
| 335 | + /// Timestamps, allowing arbitrary (or no) timezones |
331 | 336 | Timestamp, |
| 337 | + /// All time types |
332 | 338 | Time, |
| 339 | + /// All interval types |
333 | 340 | Interval, |
| 341 | + /// All duration types |
334 | 342 | Duration, |
| 343 | + /// A specific native type |
335 | 344 | Native(LogicalTypeRef), |
| 345 | + /// Signed and unsigned integers |
336 | 346 | Integer, |
| 347 | + /// All float types |
337 | 348 | Float, |
| 349 | + /// All decimal types, allowing arbitrary precision & scale |
338 | 350 | Decimal, |
| 351 | + /// Integers, floats and decimals |
339 | 352 | Numeric, |
340 | 353 | /// Encompasses both the native Binary/LargeBinary types as well as arbitrarily sized FixedSizeBinary types |
341 | 354 | Binary, |
@@ -888,8 +901,56 @@ fn get_data_types(native_type: &NativeType) -> Vec<DataType> { |
888 | 901 | NativeType::String => { |
889 | 902 | vec![DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View] |
890 | 903 | } |
891 | | - // TODO: support other native types |
892 | | - _ => vec![], |
| 904 | + NativeType::Decimal(precision, scale) => { |
| 905 | + // We assume incoming NativeType is valid already, in terms of precision & scale |
| 906 | + let mut types = vec![DataType::Decimal256(*precision, *scale)]; |
| 907 | + if *precision <= DECIMAL32_MAX_PRECISION { |
| 908 | + types.push(DataType::Decimal32(*precision, *scale)); |
| 909 | + } |
| 910 | + if *precision <= DECIMAL64_MAX_PRECISION { |
| 911 | + types.push(DataType::Decimal64(*precision, *scale)); |
| 912 | + } |
| 913 | + if *precision <= DECIMAL128_MAX_PRECISION { |
| 914 | + types.push(DataType::Decimal128(*precision, *scale)); |
| 915 | + } |
| 916 | + types |
| 917 | + } |
| 918 | + NativeType::Timestamp(time_unit, timezone) => { |
| 919 | + vec![DataType::Timestamp(*time_unit, timezone.to_owned())] |
| 920 | + } |
| 921 | + NativeType::Time(TimeUnit::Second) => vec![DataType::Time32(TimeUnit::Second)], |
| 922 | + NativeType::Time(TimeUnit::Millisecond) => { |
| 923 | + vec![DataType::Time32(TimeUnit::Millisecond)] |
| 924 | + } |
| 925 | + NativeType::Time(TimeUnit::Microsecond) => { |
| 926 | + vec![DataType::Time64(TimeUnit::Microsecond)] |
| 927 | + } |
| 928 | + NativeType::Time(TimeUnit::Nanosecond) => { |
| 929 | + vec![DataType::Time64(TimeUnit::Nanosecond)] |
| 930 | + } |
| 931 | + NativeType::Duration(time_unit) => vec![DataType::Duration(*time_unit)], |
| 932 | + NativeType::Interval(interval_unit) => vec![DataType::Interval(*interval_unit)], |
| 933 | + NativeType::FixedSizeBinary(size) => vec![DataType::FixedSizeBinary(*size)], |
| 934 | + NativeType::FixedSizeList(logical_field, size) => { |
| 935 | + get_data_types(logical_field.logical_type.native()) |
| 936 | + .iter() |
| 937 | + .map(|child_dt| { |
| 938 | + let field = Field::new( |
| 939 | + logical_field.name.clone(), |
| 940 | + child_dt.clone(), |
| 941 | + logical_field.nullable, |
| 942 | + ); |
| 943 | + DataType::FixedSizeList(Arc::new(field), *size) |
| 944 | + }) |
| 945 | + .collect() |
| 946 | + } |
| 947 | + // TODO: implement for nested types |
| 948 | + NativeType::List(_) |
| 949 | + | NativeType::Struct(_) |
| 950 | + | NativeType::Union(_) |
| 951 | + | NativeType::Map(_) => { |
| 952 | + vec![] |
| 953 | + } |
893 | 954 | } |
894 | 955 | } |
895 | 956 |
|
|
0 commit comments