Skip to content

Commit 97853ee

Browse files
committed
chore(cubesql): Support to_regtype and regtype casts
1 parent ebad0df commit 97853ee

14 files changed

+1416
-282
lines changed

rust/cubesql/cubesql/src/compile/engine/information_schema/postgres/pg_type.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ impl PgCatalogTypeProvider {
190190
builder.add_type(&PgType {
191191
oid: table.record_oid,
192192
typname: table.name.as_str(),
193+
regtype: table.name.as_str(),
193194
typnamespace: 2200,
194195
typowner: 10,
195196
typlen: -1,
@@ -215,6 +216,7 @@ impl PgCatalogTypeProvider {
215216
builder.add_type(&PgType {
216217
oid: table.array_handler_oid,
217218
typname: format!("_{}", table.name).as_str(),
219+
regtype: format!("{}[]", table.name).as_str(),
218220
typnamespace: 2200,
219221
typowner: 10,
220222
typlen: -1,

rust/cubesql/cubesql/src/compile/engine/udf.rs

Lines changed: 97 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1649,15 +1649,15 @@ pub fn create_format_type_udf() -> ScalarUDF {
16491649
(Some(oid), typemod) => Some(match PgTypeId::from_oid(oid) {
16501650
Some(type_id) => {
16511651
let typemod_str = || match type_id {
1652-
PgTypeId::BPCHAR | PgTypeId::VARCHAR => match typemod {
1652+
PgTypeId::BPCHAR
1653+
| PgTypeId::VARCHAR
1654+
| PgTypeId::ARRAYBPCHAR
1655+
| PgTypeId::ARRAYVARCHAR => match typemod {
16531656
Some(typemod) if typemod >= 5 => format!("({})", typemod - 4),
16541657
_ => "".to_string(),
16551658
},
1656-
PgTypeId::NUMERIC => match typemod {
1659+
PgTypeId::NUMERIC | PgTypeId::ARRAYNUMERIC => match typemod {
16571660
Some(typemod) if typemod >= 4 => format!("(0,{})", typemod - 4),
1658-
Some(typemod) if typemod >= 0 => {
1659-
format!("(65535,{})", 65532 + typemod)
1660-
}
16611661
_ => "".to_string(),
16621662
},
16631663
_ => match typemod {
@@ -1678,20 +1678,30 @@ pub fn create_format_type_udf() -> ScalarUDF {
16781678
PgTypeId::OID => format!("oid{}", typemod_str()),
16791679
PgTypeId::TID => format!("tid{}", typemod_str()),
16801680
PgTypeId::PGCLASS => format!("pg_class{}", typemod_str()),
1681+
PgTypeId::ARRAYPGCLASS => format!("pg_class{}[]", typemod_str()),
16811682
PgTypeId::FLOAT4 => "real".to_string(),
16821683
PgTypeId::FLOAT8 => "double precision".to_string(),
16831684
PgTypeId::MONEY => format!("money{}", typemod_str()),
1685+
PgTypeId::ARRAYMONEY => format!("money{}[]", typemod_str()),
16841686
PgTypeId::INET => format!("inet{}", typemod_str()),
16851687
PgTypeId::ARRAYBOOL => "boolean[]".to_string(),
16861688
PgTypeId::ARRAYBYTEA => format!("bytea{}[]", typemod_str()),
1689+
PgTypeId::ARRAYNAME => format!("name{}[]", typemod_str()),
16871690
PgTypeId::ARRAYINT2 => "smallint[]".to_string(),
16881691
PgTypeId::ARRAYINT4 => "integer[]".to_string(),
16891692
PgTypeId::ARRAYTEXT => format!("text{}[]", typemod_str()),
1693+
PgTypeId::ARRAYTID => format!("tid{}[]", typemod_str()),
1694+
PgTypeId::ARRAYBPCHAR => format!("character{}[]", typemod_str()),
1695+
PgTypeId::ARRAYVARCHAR => {
1696+
format!("character varying{}[]", typemod_str())
1697+
}
16901698
PgTypeId::ARRAYINT8 => "bigint[]".to_string(),
16911699
PgTypeId::ARRAYFLOAT4 => "real[]".to_string(),
16921700
PgTypeId::ARRAYFLOAT8 => "double precision[]".to_string(),
1701+
PgTypeId::ARRAYOID => format!("oid{}[]", typemod_str()),
16931702
PgTypeId::ACLITEM => format!("aclitem{}", typemod_str()),
16941703
PgTypeId::ARRAYACLITEM => format!("aclitem{}[]", typemod_str()),
1704+
PgTypeId::ARRAYINET => format!("inet{}[]", typemod_str()),
16951705
PgTypeId::BPCHAR => match typemod {
16961706
Some(typemod) if typemod < 0 => "bpchar".to_string(),
16971707
_ => format!("character{}", typemod_str()),
@@ -1702,38 +1712,83 @@ pub fn create_format_type_udf() -> ScalarUDF {
17021712
PgTypeId::TIMESTAMP => {
17031713
format!("timestamp{} without time zone", typemod_str())
17041714
}
1715+
PgTypeId::ARRAYTIMESTAMP => {
1716+
format!("timestamp{} without time zone[]", typemod_str())
1717+
}
1718+
PgTypeId::ARRAYDATE => format!("date{}[]", typemod_str()),
1719+
PgTypeId::ARRAYTIME => {
1720+
format!("time{} without time zone[]", typemod_str())
1721+
}
17051722
PgTypeId::TIMESTAMPTZ => {
17061723
format!("timestamp{} with time zone", typemod_str())
17071724
}
1725+
PgTypeId::ARRAYTIMESTAMPTZ => {
1726+
format!("timestamp{} with time zone[]", typemod_str())
1727+
}
17081728
PgTypeId::INTERVAL => match typemod {
17091729
Some(typemod) if typemod >= 0 => "-".to_string(),
17101730
_ => "interval".to_string(),
17111731
},
1732+
PgTypeId::ARRAYINTERVAL => match typemod {
1733+
Some(typemod) if typemod >= 0 => "-".to_string(),
1734+
_ => "interval[]".to_string(),
1735+
},
1736+
PgTypeId::ARRAYNUMERIC => format!("numeric{}[]", typemod_str()),
17121737
PgTypeId::TIMETZ => format!("time{} with time zone", typemod_str()),
1738+
PgTypeId::ARRAYTIMETZ => {
1739+
format!("time{} with time zone[]", typemod_str())
1740+
}
17131741
PgTypeId::NUMERIC => format!("numeric{}", typemod_str()),
17141742
PgTypeId::RECORD => format!("record{}", typemod_str()),
17151743
PgTypeId::ANYARRAY => format!("anyarray{}", typemod_str()),
17161744
PgTypeId::ANYELEMENT => format!("anyelement{}", typemod_str()),
1745+
PgTypeId::ARRAYRECORD => format!("record{}[]", typemod_str()),
17171746
PgTypeId::PGLSN => format!("pg_lsn{}", typemod_str()),
1747+
PgTypeId::ARRAYPGLSN => format!("pg_lsn{}[]", typemod_str()),
17181748
PgTypeId::ANYENUM => format!("anyenum{}", typemod_str()),
17191749
PgTypeId::ANYRANGE => format!("anyrange{}", typemod_str()),
17201750
PgTypeId::INT4RANGE => format!("int4range{}", typemod_str()),
1751+
PgTypeId::ARRAYINT4RANGE => format!("int4range{}[]", typemod_str()),
17211752
PgTypeId::NUMRANGE => format!("numrange{}", typemod_str()),
1753+
PgTypeId::ARRAYNUMRANGE => format!("numrange{}[]", typemod_str()),
17221754
PgTypeId::TSRANGE => format!("tsrange{}", typemod_str()),
1755+
PgTypeId::ARRAYTSRANGE => format!("tsrange{}[]", typemod_str()),
17231756
PgTypeId::TSTZRANGE => format!("tstzrange{}", typemod_str()),
1757+
PgTypeId::ARRAYTSTZRANGE => format!("tstzrange{}[]", typemod_str()),
17241758
PgTypeId::DATERANGE => format!("daterange{}", typemod_str()),
1759+
PgTypeId::ARRAYDATERANGE => format!("daterange{}[]", typemod_str()),
17251760
PgTypeId::INT8RANGE => format!("int8range{}", typemod_str()),
1761+
PgTypeId::ARRAYINT8RANGE => format!("int8range{}[]", typemod_str()),
17261762
PgTypeId::INT4MULTIRANGE => format!("int4multirange{}", typemod_str()),
17271763
PgTypeId::NUMMULTIRANGE => format!("nummultirange{}", typemod_str()),
17281764
PgTypeId::TSMULTIRANGE => format!("tsmultirange{}", typemod_str()),
17291765
PgTypeId::DATEMULTIRANGE => format!("datemultirange{}", typemod_str()),
17301766
PgTypeId::INT8MULTIRANGE => format!("int8multirange{}", typemod_str()),
1731-
PgTypeId::CHARACTERDATA => {
1732-
format!("information_schema.character_data{}", typemod_str())
1767+
PgTypeId::ARRAYINT4MULTIRANGE => {
1768+
format!("int4multirange{}[]", typemod_str())
1769+
}
1770+
PgTypeId::ARRAYNUMMULTIRANGE => {
1771+
format!("nummultirange{}[]", typemod_str())
1772+
}
1773+
PgTypeId::ARRAYTSMULTIRANGE => {
1774+
format!("tsmultirange{}[]", typemod_str())
1775+
}
1776+
PgTypeId::ARRAYDATEMULTIRANGE => {
1777+
format!("datemultirange{}[]", typemod_str())
1778+
}
1779+
PgTypeId::ARRAYINT8MULTIRANGE => {
1780+
format!("int8multirange{}[]", typemod_str())
1781+
}
1782+
PgTypeId::ARRAYPGCONSTRAINT => {
1783+
format!("pg_constraint{}[]", typemod_str())
17331784
}
17341785
PgTypeId::PGCONSTRAINT => format!("pg_constraint{}", typemod_str()),
1735-
PgTypeId::PGNAMESPACE => {
1736-
format!("pg_namespace{}", typemod_str())
1786+
PgTypeId::ARRAYPGNAMESPACE => {
1787+
format!("pg_namespace{}[]", typemod_str())
1788+
}
1789+
PgTypeId::PGNAMESPACE => format!("pg_namespace{}", typemod_str()),
1790+
PgTypeId::CHARACTERDATA => {
1791+
format!("information_schema.character_data{}", typemod_str())
17371792
}
17381793
PgTypeId::SQLIDENTIFIER => {
17391794
format!("information_schema.sql_identifier{}", typemod_str())
@@ -3183,3 +3238,36 @@ pub fn create_mod_udf() -> ScalarUDF {
31833238
&fun,
31843239
)
31853240
}
3241+
3242+
pub fn create_to_regtype_udf() -> ScalarUDF {
3243+
let fun = make_scalar_function(move |args: &[ArrayRef]| {
3244+
assert!(args.len() == 1);
3245+
3246+
let regtype_arr = downcast_string_arg!(args[0], "regtype", i32);
3247+
3248+
let pg_types = PgType::get_all();
3249+
3250+
let result = regtype_arr
3251+
.iter()
3252+
.map(|regtype| match regtype {
3253+
Some(regtype) => pg_types
3254+
.iter()
3255+
.find(|typ| typ.typname == regtype || typ.regtype == regtype)
3256+
.map(|typ| typ.oid as i32),
3257+
None => None,
3258+
})
3259+
.collect::<PrimitiveArray<Int32Type>>();
3260+
3261+
Ok(Arc::new(result) as ArrayRef)
3262+
});
3263+
3264+
// TODO: `to_regtype` should return regtype but we use `oid` since it's used for comparison with oids
3265+
let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(DataType::Int32)));
3266+
3267+
ScalarUDF::new(
3268+
"to_regtype",
3269+
&Signature::exact(vec![DataType::Utf8], Volatility::Immutable),
3270+
&return_type,
3271+
&fun,
3272+
)
3273+
}

rust/cubesql/cubesql/src/compile/mod.rs

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ use self::{
6464
create_position_udf, create_quarter_udf, create_quote_ident_udf,
6565
create_regexp_substr_udf, create_second_udf, create_session_user_udf, create_sha1_udf,
6666
create_str_to_date_udf, create_time_format_udf, create_timediff_udf,
67-
create_to_char_udf, create_to_date_udf, create_ucase_udf, create_unnest_udtf,
68-
create_user_udf, create_version_udf, create_year_udf,
67+
create_to_char_udf, create_to_date_udf, create_to_regtype_udf, create_ucase_udf,
68+
create_unnest_udtf, create_user_udf, create_version_udf, create_year_udf,
6969
},
7070
},
7171
parser::parse_sql_to_statement,
@@ -1183,6 +1183,7 @@ WHERE `TABLE_SCHEMA` = '{}'",
11831183
ctx.register_udf(create_array_to_string_udf());
11841184
ctx.register_udf(create_charindex_udf());
11851185
ctx.register_udf(create_mod_udf());
1186+
ctx.register_udf(create_to_regtype_udf());
11861187

11871188
// udaf
11881189
ctx.register_udaf(create_measure_udaf());
@@ -7883,6 +7884,26 @@ ORDER BY \"COUNT(count)\" DESC"
78837884
Ok(())
78847885
}
78857886

7887+
#[tokio::test]
7888+
async fn test_pg_to_regtype_pid() -> Result<(), CubeError> {
7889+
insta::assert_snapshot!(
7890+
"pg_to_regtype",
7891+
execute_query(
7892+
"select
7893+
to_regtype('bool') b,
7894+
to_regtype('name') n,
7895+
to_regtype('_int4') ai,
7896+
to_regtype('unknown') u
7897+
;"
7898+
.to_string(),
7899+
DatabaseProtocol::PostgreSQL
7900+
)
7901+
.await?
7902+
);
7903+
7904+
Ok(())
7905+
}
7906+
78867907
#[tokio::test]
78877908
async fn test_date_part_quarter() -> Result<(), CubeError> {
78887909
insta::assert_snapshot!(
@@ -8526,6 +8547,33 @@ ORDER BY \"COUNT(count)\" DESC"
85268547
Ok(())
85278548
}
85288549

8550+
#[tokio::test]
8551+
async fn test_sqlalchemy_regtype() -> Result<(), CubeError> {
8552+
insta::assert_snapshot!(
8553+
"sqlalchemy_regtype",
8554+
execute_query(
8555+
"SELECT
8556+
typname AS name,
8557+
oid,
8558+
typarray AS array_oid,
8559+
CAST(CAST(oid AS regtype) AS TEXT) AS regtype,
8560+
typdelim AS delimiter
8561+
FROM
8562+
pg_type AS t
8563+
WHERE
8564+
t.oid = to_regtype('boolean')
8565+
ORDER BY
8566+
t.oid
8567+
;"
8568+
.to_string(),
8569+
DatabaseProtocol::PostgreSQL
8570+
)
8571+
.await?
8572+
);
8573+
8574+
Ok(())
8575+
}
8576+
85298577
#[tokio::test]
85308578
async fn pgcli_queries() -> Result<(), CubeError> {
85318579
init_logger();

0 commit comments

Comments
 (0)