Skip to content

Commit a35b7fe

Browse files
JasonShinCopilot
andauthored
lateral tables - basic select, batch insert (#236)
* support batch insert * Support batch insert * add error print * fmt * default to Boolean if placeholder type cannot be found * fix tests * rename tests * fmt * fix postgres lateral test * fix single placeholder boolean * fix * Update tests/postgres_lateral_tables.rs Co-authored-by: Copilot <[email protected]> * Update src/ts_generator/sql_parser/translate_query.rs Co-authored-by: Copilot <[email protected]> * Update src/ts_generator/sql_parser/expressions/translate_expr.rs Co-authored-by: Copilot <[email protected]> --------- Co-authored-by: Copilot <[email protected]>
1 parent d82457d commit a35b7fe

File tree

10 files changed

+492
-92
lines changed

10 files changed

+492
-92
lines changed

src/core/connection.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use crate::common::lazy::CONFIG;
2+
use crate::common::types::DatabaseType;
23
use crate::common::SQL;
34
use crate::core::mysql::prepare as mysql_explain;
45
use crate::core::postgres::prepare as postgres_explain;
@@ -34,6 +35,14 @@ impl DBConn {
3435

3536
Ok((explain_failed, ts_query))
3637
}
38+
39+
/// Get the database type for this connection
40+
pub fn get_db_type(&self) -> DatabaseType {
41+
match self {
42+
DBConn::MySQLPooledConn(_) => DatabaseType::Mysql,
43+
DBConn::PostgresConn(_) => DatabaseType::Postgres,
44+
}
45+
}
3746
}
3847

3948
pub struct DBConnections<'a> {

src/ts_generator/generator.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,15 @@ use crate::ts_generator::annotations::extract_result_annotations;
1111
use crate::ts_generator::sql_parser::translate_stmt::translate_stmt;
1212
use crate::ts_generator::types::ts_query::TsQuery;
1313

14+
use crate::common::types::DatabaseType;
1415
use color_eyre::eyre::eyre;
1516
use color_eyre::eyre::Result;
1617
use convert_case::{Case, Casing};
1718
use regex::Regex;
18-
use sqlparser::{dialect::GenericDialect, parser::Parser};
19+
use sqlparser::{
20+
dialect::{Dialect, MySqlDialect, PostgreSqlDialect},
21+
parser::Parser,
22+
};
1923

2024
use super::errors::TsGeneratorError;
2125

@@ -117,9 +121,13 @@ pub fn clear_single_ts_file_if_exists() -> Result<()> {
117121
}
118122

119123
pub async fn generate_ts_interface(sql: &SQL, db_conn: &DBConn) -> Result<TsQuery> {
120-
let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...
124+
// Use the appropriate SQL dialect based on the database type
125+
let dialect: Box<dyn Dialect> = match db_conn.get_db_type() {
126+
DatabaseType::Postgres => Box::new(PostgreSqlDialect {}),
127+
DatabaseType::Mysql => Box::new(MySqlDialect {}),
128+
};
121129

122-
let sql_ast = Parser::parse_sql(&dialect, &sql.query)?;
130+
let sql_ast = Parser::parse_sql(&*dialect, &sql.query)?;
123131
let mut ts_query = TsQuery::new(get_query_name(sql)?);
124132

125133
let annotated_result_types = extract_result_annotations(sql.query.as_str());

src/ts_generator/sql_parser/expressions/translate_expr.rs

Lines changed: 101 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use super::functions::{is_date_function, is_numeric_function, is_type_polymorphic_function};
22
use crate::common::lazy::DB_SCHEMA;
3-
use crate::common::logger::warning;
3+
use crate::common::logger::{error, warning};
44
use crate::core::connection::DBConn;
55
use crate::ts_generator::errors::TsGeneratorError;
66
use crate::ts_generator::sql_parser::expressions::translate_data_type::translate_value;
@@ -180,20 +180,46 @@ pub async fn translate_expr(
180180
Expr::Identifier(ident) => {
181181
let column_name = DisplayIndent(ident).to_string();
182182
let table_name = single_table_name.expect("Missing table name for identifier");
183+
184+
// First check if this is a table-valued function column
185+
if let Some(tvf_columns) = ts_query.table_valued_function_columns.get(table_name) {
186+
if let Some(ts_type) = tvf_columns.get(&column_name) {
187+
let field_name = alias.unwrap_or(column_name.as_str());
188+
ts_query.insert_result(
189+
Some(field_name),
190+
&[ts_type.to_owned()],
191+
is_selection,
192+
false, // Table-valued function columns are not nullable by default
193+
expr_for_logging,
194+
)?;
195+
return Ok(());
196+
}
197+
}
198+
199+
// Fall back to database schema
183200
let table_details = &DB_SCHEMA.lock().await.fetch_table(&vec![table_name], db_conn).await;
184201

185-
// TODO: We can also memoize this method
186202
if let Some(table_details) = table_details {
187-
let field = table_details.get(&column_name).unwrap();
188-
189-
let field_name = alias.unwrap_or(column_name.as_str());
190-
ts_query.insert_result(
191-
Some(field_name),
192-
&[field.field_type.to_owned()],
193-
is_selection,
194-
field.is_nullable,
195-
expr_for_logging,
196-
)?
203+
if let Some(field) = table_details.get(&column_name) {
204+
let field_name = alias.unwrap_or(column_name.as_str());
205+
ts_query.insert_result(
206+
Some(field_name),
207+
&[field.field_type.to_owned()],
208+
is_selection,
209+
field.is_nullable,
210+
expr_for_logging,
211+
)?
212+
} else {
213+
error!(
214+
"Column '{}' not found in table '{}'. If '{}' is a table-valued function, verify that the column is defined in its alias. Otherwise, the column may not exist in the table.",
215+
column_name, table_name, table_name
216+
);
217+
}
218+
} else {
219+
error!(
220+
"Table '{}' not found in schema. This may be a table-valued function.",
221+
table_name
222+
);
197223
}
198224
Ok(())
199225
}
@@ -203,32 +229,67 @@ pub async fn translate_expr(
203229

204230
let table_name = translate_table_from_expr(table_with_joins, expr)?;
205231

232+
// First check if this is a table-valued function column
233+
if let Some(tvf_columns) = ts_query.table_valued_function_columns.get(&table_name) {
234+
if let Some(ts_type) = tvf_columns.get(&ident) {
235+
// if the select item is a compound identifier and does not has an alias, we should use `table_name.ident` as the key name
236+
let key_name = format!("{table_name}_{ident}");
237+
let key_name = &alias.unwrap_or_else(|| {
238+
warning!(
239+
"Missing an alias for a compound identifier, using {} as the key name. Prefer adding an alias for example: `{} AS {}`",
240+
key_name, expr, ident
241+
);
242+
key_name.as_str()
243+
});
244+
245+
ts_query.insert_result(
246+
Some(key_name),
247+
&[ts_type.to_owned()],
248+
is_selection,
249+
false, // Table-valued function columns are not nullable by default
250+
expr_for_logging,
251+
)?;
252+
return Ok(());
253+
}
254+
}
255+
256+
// Fall back to database schema
206257
let table_details = &DB_SCHEMA
207258
.lock()
208259
.await
209260
.fetch_table(&vec![table_name.as_str()], db_conn)
210261
.await;
211262

212263
if let Some(table_details) = table_details {
213-
let field = table_details.get(&ident).unwrap();
214-
215-
// if the select item is a compound identifier and does not has an alias, we should use `table_name.ident` as the key name
216-
let key_name = format!("{table_name}_{ident}");
217-
let key_name = &alias.unwrap_or_else(|| {
218-
warning!(
219-
"Missing an alias for a compound identifier, using {} as the key name. Prefer adding an alias for example: `{} AS {}`",
220-
key_name, expr, ident
221-
);
222-
key_name.as_str()
223-
});
224-
225-
ts_query.insert_result(
226-
Some(key_name),
227-
&[field.field_type.to_owned()],
228-
is_selection,
229-
field.is_nullable,
230-
expr_for_logging,
231-
)?;
264+
if let Some(field) = table_details.get(&ident) {
265+
// if the select item is a compound identifier and does not has an alias, we should use `table_name.ident` as the key name
266+
let key_name = format!("{table_name}_{ident}");
267+
let key_name = &alias.unwrap_or_else(|| {
268+
warning!(
269+
"Missing an alias for a compound identifier, using {} as the key name. Prefer adding an alias for example: `{} AS {}`",
270+
key_name, expr, ident
271+
);
272+
key_name.as_str()
273+
});
274+
275+
ts_query.insert_result(
276+
Some(key_name),
277+
&[field.field_type.to_owned()],
278+
is_selection,
279+
field.is_nullable,
280+
expr_for_logging,
281+
)?;
282+
} else {
283+
error!(
284+
"Column '{}' not found in table '{}' for compound identifier '{}.{}'. This may be a table-valued function.",
285+
ident, table_name, table_name, ident
286+
);
287+
}
288+
} else {
289+
error!(
290+
"Table '{}' not found in schema for compound identifier '{}.{}'. This may be a table-valued function.",
291+
table_name, table_name, ident
292+
);
232293
}
233294
}
234295
Ok(())
@@ -359,7 +420,15 @@ pub async fn translate_expr(
359420
if let Some(ts_field_type) = ts_field_type {
360421
return ts_query.insert_result(alias, &[ts_field_type], is_selection, false, expr_for_logging);
361422
}
362-
ts_query.insert_param(&TsFieldType::Boolean, &false, &Some(placeholder.to_string()))
423+
// For placeholders where we can't infer the type:
424+
// - If we're in a WHERE clause (is_selection is false AND we have a table context), infer as Boolean
425+
// - Otherwise, use Any for flexibility (e.g., for table-valued function arguments)
426+
let inferred_type = if !is_selection && single_table_name.is_some() {
427+
TsFieldType::Boolean
428+
} else {
429+
TsFieldType::Any
430+
};
431+
ts_query.insert_param(&inferred_type, &false, &Some(placeholder.to_string()))
363432
}
364433
Expr::JsonAccess { value: _, path: _ } => {
365434
ts_query.insert_result(alias, &[TsFieldType::Any], is_selection, false, expr_for_logging)?;

src/ts_generator/sql_parser/expressions/translate_table_with_joins.rs

Lines changed: 71 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,56 @@ use crate::ts_generator::sql_parser::quoted_strings::*;
33
use color_eyre::eyre::Result;
44
use sqlparser::ast::{Assignment, AssignmentTarget, Expr, Join, SelectItem, TableFactor, TableWithJoins};
55

6+
/// Check if the given table name corresponds to a table-valued function alias
7+
/// by examining the table_with_joins to see if it's a TableFactor::Function
8+
#[allow(dead_code)]
9+
pub fn is_table_function(table_name: &str, table_with_joins: &[TableWithJoins]) -> bool {
10+
for twj in table_with_joins {
11+
if let TableFactor::Function { alias: Some(alias), .. } = &twj.relation {
12+
if DisplayTableAlias(alias).to_string() == table_name {
13+
return true;
14+
}
15+
}
16+
// Also check joins
17+
for join in &twj.joins {
18+
if let TableFactor::Function { alias: Some(alias), .. } = &join.relation {
19+
if DisplayTableAlias(alias).to_string() == table_name {
20+
return true;
21+
}
22+
}
23+
}
24+
}
25+
false
26+
}
27+
628
pub fn get_default_table(table_with_joins: &[TableWithJoins]) -> String {
729
table_with_joins
830
.first()
931
.and_then(|x| match &x.relation {
1032
TableFactor::Table {
1133
name,
12-
alias: _,
13-
args: _,
34+
alias,
35+
args,
1436
with_hints: _,
1537
version: _,
1638
partitions: _,
1739
with_ordinality: _,
1840
json_path: _,
1941
sample: _,
2042
index_hints: _,
21-
} => Some(DisplayObjectName(name).to_string()),
43+
} => {
44+
// If args is Some, it's a table-valued function (e.g., jsonb_to_recordset($1))
45+
// In that case, use the alias name if available
46+
if args.is_some() {
47+
alias.as_ref().map(|a| DisplayTableAlias(a).to_string())
48+
} else {
49+
Some(DisplayObjectName(name).to_string())
50+
}
51+
}
52+
TableFactor::Function { alias, .. } => {
53+
// For LATERAL functions, use the alias name as the table name
54+
alias.as_ref().map(|a| DisplayTableAlias(a).to_string())
55+
}
2256
_ => None,
2357
})
2458
.expect("The query does not have a default table, impossible to generate types")
@@ -41,7 +75,7 @@ pub fn find_table_name_from_identifier(
4175
TableFactor::Table {
4276
name,
4377
alias,
44-
args: _,
78+
args,
4579
with_hints: _,
4680
version: _,
4781
partitions: _,
@@ -50,11 +84,30 @@ pub fn find_table_name_from_identifier(
5084
sample: _,
5185
index_hints: _,
5286
} => {
53-
let alias = alias.clone().map(|alias| DisplayTableAlias(&alias).to_string());
54-
let name = DisplayObjectName(name).to_string();
55-
if Some(left.to_string()) == alias || left == name {
56-
// If the identifier matches the alias, then return the table name
57-
return Ok(name.to_owned());
87+
let alias_str = alias.clone().map(|alias| DisplayTableAlias(&alias).to_string());
88+
let name_str = DisplayObjectName(name).to_string();
89+
90+
// If this is a table-valued function (args is Some), use alias as the effective name
91+
if args.is_some() {
92+
if let Some(alias) = alias_str {
93+
if left == alias {
94+
return Ok(alias);
95+
}
96+
}
97+
} else {
98+
// Regular table
99+
if Some(left.to_string()) == alias_str || left == name_str {
100+
return Ok(name_str.to_owned());
101+
}
102+
}
103+
}
104+
TableFactor::Function { alias, .. } => {
105+
// For LATERAL functions, the alias is the effective table name
106+
if let Some(alias) = alias {
107+
let alias_name = DisplayTableAlias(alias).to_string();
108+
if left == alias_name {
109+
return Ok(alias_name);
110+
}
58111
}
59112
}
60113
_ => {
@@ -89,6 +142,15 @@ pub fn find_table_name_from_identifier(
89142
return Ok(name);
90143
}
91144
}
145+
TableFactor::Function { alias, .. } => {
146+
// For table-valued functions in joins, the alias is the effective table name
147+
if let Some(alias) = alias {
148+
let alias_name = DisplayTableAlias(alias).to_string();
149+
if left == alias_name {
150+
return Ok(alias_name);
151+
}
152+
}
153+
}
92154
_ => {
93155
return Err(TsGeneratorError::TableFactorWhileProcessingTableWithJoins(
94156
join.to_string(),

src/ts_generator/sql_parser/expressions/translate_wildcard_expr.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@ pub fn get_all_table_names_from_select(select: &Select) -> Result<Vec<String>, T
2222
let name = DisplayObjectName(&name).to_string();
2323
Ok(name)
2424
}
25+
TableFactor::Function { .. } => {
26+
// Wildcard queries with table-valued functions are not supported
27+
// because we cannot query the database schema for function result types
28+
Err(TsGeneratorError::WildcardStatementUnsupportedTableExpr(
29+
select.to_string(),
30+
))
31+
}
2532
_ => Err(TsGeneratorError::WildcardStatementUnsupportedTableExpr(
2633
select.to_string(),
2734
)),

0 commit comments

Comments
 (0)