Skip to content

Commit 46d6875

Browse files
committed
feat: implement extension type support WIP
1 parent 7e0046b commit 46d6875

File tree

6 files changed

+193
-118
lines changed

6 files changed

+193
-118
lines changed

Cargo.lock

Lines changed: 74 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

arrow-pg/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@ rust-version.workspace = true
1616
default = ["arrow", "geo"]
1717
arrow = ["dep:arrow"]
1818
datafusion = ["dep:datafusion"]
19-
geo = ["postgres-types/with-geo-types-0_7", "dep:geoarrow-schema"]
19+
geo = ["postgres-types/with-geo-types-0_7", "dep:geoarrow", "dep:geoarrow-schema"]
2020

2121
[dependencies]
2222
arrow = { workspace = true, optional = true }
2323
arrow-schema = { workspace = true}
24+
geoarrow = { version = "0.7", optional = true }
2425
geoarrow-schema = { version = "0.7", optional = true }
2526
bytes.workspace = true
2627
chrono.workspace = true

arrow-pg/src/datatypes.rs

Lines changed: 99 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::sync::Arc;
22

33
#[cfg(not(feature = "datafusion"))]
44
use arrow::{datatypes::*, record_batch::RecordBatch};
5+
#[cfg(feature = "geo")]
56
use arrow_schema::extension::ExtensionType;
67
#[cfg(feature = "datafusion")]
78
use datafusion::arrow::{datatypes::*, record_batch::RecordBatch};
@@ -19,115 +20,114 @@ use crate::row_encoder::RowEncoder;
1920
#[cfg(feature = "datafusion")]
2021
pub mod df;
2122

22-
pub fn into_pg_type(field: &Arc<Field>) -> PgWireResult<Type> {
23-
let arrow_type = field.data_type();
24-
25-
match field.extension_type_name() {
26-
// As of arrow 56, there are additional extension logical type that is
27-
// defined using field metadata, for instance, json or geo.
28-
#[cfg(feature = "geo")]
29-
Some(geoarrow_schema::PointType::NAME) => Ok(Type::POINT),
30-
31-
_ => Ok(match arrow_type {
32-
DataType::Null => Type::UNKNOWN,
33-
DataType::Boolean => Type::BOOL,
34-
DataType::Int8 | DataType::UInt8 => Type::CHAR,
35-
DataType::Int16 | DataType::UInt16 => Type::INT2,
36-
DataType::Int32 | DataType::UInt32 => Type::INT4,
37-
DataType::Int64 | DataType::UInt64 => Type::INT8,
23+
pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult<Type> {
24+
let datatype = match arrow_type {
25+
DataType::Null => Type::UNKNOWN,
26+
DataType::Boolean => Type::BOOL,
27+
DataType::Int8 | DataType::UInt8 => Type::CHAR,
28+
DataType::Int16 | DataType::UInt16 => Type::INT2,
29+
DataType::Int32 | DataType::UInt32 => Type::INT4,
30+
DataType::Int64 | DataType::UInt64 => Type::INT8,
31+
DataType::Timestamp(_, tz) => {
32+
if tz.is_some() {
33+
Type::TIMESTAMPTZ
34+
} else {
35+
Type::TIMESTAMP
36+
}
37+
}
38+
DataType::Time32(_) | DataType::Time64(_) => Type::TIME,
39+
DataType::Date32 | DataType::Date64 => Type::DATE,
40+
DataType::Interval(_) => Type::INTERVAL,
41+
DataType::Binary
42+
| DataType::FixedSizeBinary(_)
43+
| DataType::LargeBinary
44+
| DataType::BinaryView => Type::BYTEA,
45+
DataType::Float16 | DataType::Float32 => Type::FLOAT4,
46+
DataType::Float64 => Type::FLOAT8,
47+
DataType::Decimal128(_, _) => Type::NUMERIC,
48+
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT,
49+
DataType::List(field)
50+
| DataType::FixedSizeList(field, _)
51+
| DataType::LargeList(field)
52+
| DataType::ListView(field)
53+
| DataType::LargeListView(field) => match field.data_type() {
54+
DataType::Boolean => Type::BOOL_ARRAY,
55+
DataType::Int8 | DataType::UInt8 => Type::CHAR_ARRAY,
56+
DataType::Int16 | DataType::UInt16 => Type::INT2_ARRAY,
57+
DataType::Int32 | DataType::UInt32 => Type::INT4_ARRAY,
58+
DataType::Int64 | DataType::UInt64 => Type::INT8_ARRAY,
3859
DataType::Timestamp(_, tz) => {
3960
if tz.is_some() {
40-
Type::TIMESTAMPTZ
61+
Type::TIMESTAMPTZ_ARRAY
4162
} else {
42-
Type::TIMESTAMP
63+
Type::TIMESTAMP_ARRAY
4364
}
4465
}
45-
DataType::Time32(_) | DataType::Time64(_) => Type::TIME,
46-
DataType::Date32 | DataType::Date64 => Type::DATE,
47-
DataType::Interval(_) => Type::INTERVAL,
48-
DataType::Binary
49-
| DataType::FixedSizeBinary(_)
66+
DataType::Time32(_) | DataType::Time64(_) => Type::TIME_ARRAY,
67+
DataType::Date32 | DataType::Date64 => Type::DATE_ARRAY,
68+
DataType::Interval(_) => Type::INTERVAL_ARRAY,
69+
DataType::FixedSizeBinary(_)
70+
| DataType::Binary
5071
| DataType::LargeBinary
51-
| DataType::BinaryView => Type::BYTEA,
52-
DataType::Float16 | DataType::Float32 => Type::FLOAT4,
53-
DataType::Float64 => Type::FLOAT8,
54-
DataType::Decimal128(_, _) => Type::NUMERIC,
55-
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT,
56-
DataType::List(field)
57-
| DataType::FixedSizeList(field, _)
58-
| DataType::LargeList(field)
59-
| DataType::ListView(field)
60-
| DataType::LargeListView(field) => match field.data_type() {
61-
DataType::Boolean => Type::BOOL_ARRAY,
62-
DataType::Int8 | DataType::UInt8 => Type::CHAR_ARRAY,
63-
DataType::Int16 | DataType::UInt16 => Type::INT2_ARRAY,
64-
DataType::Int32 | DataType::UInt32 => Type::INT4_ARRAY,
65-
DataType::Int64 | DataType::UInt64 => Type::INT8_ARRAY,
66-
DataType::Timestamp(_, tz) => {
67-
if tz.is_some() {
68-
Type::TIMESTAMPTZ_ARRAY
69-
} else {
70-
Type::TIMESTAMP_ARRAY
71-
}
72-
}
73-
DataType::Time32(_) | DataType::Time64(_) => Type::TIME_ARRAY,
74-
DataType::Date32 | DataType::Date64 => Type::DATE_ARRAY,
75-
DataType::Interval(_) => Type::INTERVAL_ARRAY,
76-
DataType::FixedSizeBinary(_)
77-
| DataType::Binary
78-
| DataType::LargeBinary
79-
| DataType::BinaryView => Type::BYTEA_ARRAY,
80-
DataType::Float16 | DataType::Float32 => Type::FLOAT4_ARRAY,
81-
DataType::Float64 => Type::FLOAT8_ARRAY,
82-
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT_ARRAY,
83-
DataType::Struct(_) => Type::new(
84-
Type::RECORD_ARRAY.name().into(),
85-
Type::RECORD_ARRAY.oid(),
86-
Kind::Array(into_pg_type(field)?),
87-
Type::RECORD_ARRAY.schema().into(),
88-
),
89-
list_type => {
90-
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
91-
"ERROR".to_owned(),
92-
"XX000".to_owned(),
93-
format!("Unsupported List Datatype {list_type}"),
94-
))));
95-
}
96-
},
97-
DataType::Dictionary(_, value_type) => {
98-
let field = Arc::new(Field::new(
99-
Field::LIST_FIELD_DEFAULT_NAME,
100-
*value_type.clone(),
101-
true,
102-
));
103-
into_pg_type(&field)?
104-
}
105-
DataType::Struct(fields) => {
106-
let name: String = fields
107-
.iter()
108-
.map(|x| x.name().clone())
109-
.reduce(|a, b| a + ", " + &b)
110-
.map(|x| format!("({x})"))
111-
.unwrap_or("()".to_string());
112-
let kind = Kind::Composite(
113-
fields
114-
.iter()
115-
.map(|x| {
116-
into_pg_type(x)
117-
.map(|_type| postgres_types::Field::new(x.name().clone(), _type))
118-
})
119-
.collect::<Result<Vec<_>, PgWireError>>()?,
120-
);
121-
Type::new(name, Type::RECORD.oid(), kind, Type::RECORD.schema().into())
122-
}
123-
_ => {
72+
| DataType::BinaryView => Type::BYTEA_ARRAY,
73+
DataType::Float16 | DataType::Float32 => Type::FLOAT4_ARRAY,
74+
DataType::Float64 => Type::FLOAT8_ARRAY,
75+
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT_ARRAY,
76+
DataType::Struct(_) => Type::new(
77+
Type::RECORD_ARRAY.name().into(),
78+
Type::RECORD_ARRAY.oid(),
79+
Kind::Array(field_into_pg_type(field)?),
80+
Type::RECORD_ARRAY.schema().into(),
81+
),
82+
list_type => {
12483
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
12584
"ERROR".to_owned(),
12685
"XX000".to_owned(),
127-
format!("Unsupported Datatype {arrow_type}"),
86+
format!("Unsupported List Datatype {list_type}"),
12887
))));
12988
}
130-
}),
89+
},
90+
DataType::Dictionary(_, value_type) => into_pg_type(value_type.as_ref())?,
91+
DataType::Struct(fields) => {
92+
let name: String = fields
93+
.iter()
94+
.map(|x| x.name().clone())
95+
.reduce(|a, b| a + ", " + &b)
96+
.map(|x| format!("({x})"))
97+
.unwrap_or("()".to_string());
98+
let kind = Kind::Composite(
99+
fields
100+
.iter()
101+
.map(|x| {
102+
field_into_pg_type(x)
103+
.map(|_type| postgres_types::Field::new(x.name().clone(), _type))
104+
})
105+
.collect::<Result<Vec<_>, PgWireError>>()?,
106+
);
107+
Type::new(name, Type::RECORD.oid(), kind, Type::RECORD.schema().into())
108+
}
109+
_ => {
110+
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
111+
"ERROR".to_owned(),
112+
"XX000".to_owned(),
113+
format!("Unsupported Datatype {arrow_type}"),
114+
))));
115+
}
116+
};
117+
118+
Ok(datatype)
119+
}
120+
121+
pub fn field_into_pg_type(field: &Arc<Field>) -> PgWireResult<Type> {
122+
let arrow_type = field.data_type();
123+
124+
match field.extension_type_name() {
125+
// As of arrow 56, there are additional extension logical type that is
126+
// defined using field metadata, for instance, json or geo.
127+
#[cfg(feature = "geo")]
128+
Some(geoarrow_schema::PointType::NAME) => Ok(Type::POINT),
129+
130+
_ => into_pg_type(arrow_type),
131131
}
132132
}
133133

@@ -142,7 +142,7 @@ pub fn arrow_schema_to_pg_fields(
142142
.iter()
143143
.enumerate()
144144
.map(|(idx, f)| {
145-
let pg_type = into_pg_type(f)?;
145+
let pg_type = field_into_pg_type(f)?;
146146
let mut field_info =
147147
FieldInfo::new(f.name().into(), None, None, pg_type, format.format_for(idx));
148148
if let Some(data_format_options) = &data_format_options {

arrow-pg/src/datatypes/df.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::iter;
22
use std::sync::Arc;
33

44
use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, Timelike};
5-
use datafusion::arrow::datatypes::{DataType, Date32Type, Field, TimeUnit};
5+
use datafusion::arrow::datatypes::{DataType, Date32Type, TimeUnit};
66
use datafusion::arrow::record_batch::RecordBatch;
77
use datafusion::common::ParamValues;
88
use datafusion::prelude::*;
@@ -70,7 +70,7 @@ where
7070
if let Some(ty) = pg_type_hint {
7171
Ok(ty.clone())
7272
} else if let Some(infer_type) = inferenced_type {
73-
into_pg_type(&Arc::new(Field::new("item", infer_type.clone(), true)))
73+
into_pg_type(infer_type)
7474
} else {
7575
Ok(Type::UNKNOWN)
7676
}

0 commit comments

Comments
 (0)