Skip to content

Commit 1178205

Browse files
committed
add support for edge cases
1 parent 43354d1 commit 1178205

File tree

1 file changed

+49
-13
lines changed

1 file changed

+49
-13
lines changed

src/parse_tables_extension.cpp

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
#include "duckdb/function/scalar_function.hpp"
88
#include "duckdb/main/extension_util.hpp"
99
#include <duckdb/parser/parsed_data/create_scalar_function_info.hpp>
10-
1110
#include "duckdb/parser/parser.hpp"
1211
#include "duckdb/parser/statement/select_statement.hpp"
1312
#include "duckdb/parser/query_node/select_node.hpp"
1413
#include "duckdb/parser/tableref/basetableref.hpp"
1514
#include "duckdb/parser/tableref/joinref.hpp"
1615
#include "duckdb/parser/tableref/subqueryref.hpp"
16+
#include "duckdb/parser/statement/insert_statement.hpp"
1717

1818
namespace duckdb {
1919

@@ -60,29 +60,50 @@ static unique_ptr<GlobalTableFunctionState> MyInit(ClientContext &context,
6060
return make_uniq<ParseTablesState>();
6161
}
6262

63-
static void ExtractTablesFromQueryNode(const QueryNode &node, vector<TableRefResult> &results);
63+
static void ExtractTablesFromQueryNode(
64+
const duckdb::QueryNode &node,
65+
std::vector<TableRefResult> &results,
66+
const std::string &context = "from",
67+
const duckdb::CommonTableExpressionMap *cte_map = nullptr
68+
);
69+
70+
static void ExtractTablesFromRef(
71+
const duckdb::TableRef &ref,
72+
std::vector<TableRefResult> &results,
73+
const std::string &context = "from",
74+
bool is_top_level = false,
75+
const duckdb::CommonTableExpressionMap *cte_map = nullptr
76+
) {
77+
using namespace duckdb;
6478

65-
static void ExtractTablesFromRef(const TableRef &ref, vector<TableRefResult> &results, const string &context = "from", bool is_top_level = false) {
6679
switch (ref.type) {
6780
case TableReferenceType::BASE_TABLE: {
6881
auto &base = (BaseTableRef &)ref;
82+
std::string context_label = context;
83+
84+
if (cte_map && cte_map->map.find(base.table_name) != cte_map->map.end()) {
85+
context_label = "from_cte";
86+
} else if (is_top_level) {
87+
context_label = "from";
88+
}
89+
6990
results.push_back(TableRefResult{
7091
base.schema_name.empty() ? "main" : base.schema_name,
7192
base.table_name,
72-
is_top_level ? "from" : context
93+
context_label
7394
});
7495
break;
7596
}
7697
case TableReferenceType::JOIN: {
7798
auto &join = (JoinRef &)ref;
78-
ExtractTablesFromRef(*join.left, results, "join_left", is_top_level);
79-
ExtractTablesFromRef(*join.right, results, "join_right");
99+
ExtractTablesFromRef(*join.left, results, "join_left", is_top_level, cte_map);
100+
ExtractTablesFromRef(*join.right, results, "join_right", false, cte_map);
80101
break;
81102
}
82103
case TableReferenceType::SUBQUERY: {
83104
auto &subquery = (SubqueryRef &)ref;
84105
if (subquery.subquery && subquery.subquery->node) {
85-
ExtractTablesFromQueryNode(*subquery.subquery->node, results);
106+
ExtractTablesFromQueryNode(*subquery.subquery->node, results, "subquery", cte_map);
86107
}
87108
break;
88109
}
@@ -91,20 +112,31 @@ static void ExtractTablesFromRef(const TableRef &ref, vector<TableRefResult> &re
91112
}
92113
}
93114

94-
static void ExtractTablesFromQueryNode(const QueryNode &node, vector<TableRefResult> &results) {
115+
116+
static void ExtractTablesFromQueryNode(
117+
const duckdb::QueryNode &node,
118+
std::vector<TableRefResult> &results,
119+
const std::string &context,
120+
const duckdb::CommonTableExpressionMap *cte_map
121+
) {
122+
using namespace duckdb;
123+
95124
if (node.type == QueryNodeType::SELECT_NODE) {
96125
auto &select_node = (SelectNode &)node;
97126

98-
// Handle CTEs
127+
// Emit CTE definitions
99128
for (const auto &entry : select_node.cte_map.map) {
100-
if (entry.second && entry.second->query) {
101-
ExtractTablesFromQueryNode(*entry.second->query->node, results);
129+
results.push_back(TableRefResult{
130+
"", entry.first, "cte"
131+
});
132+
133+
if (entry.second && entry.second->query && entry.second->query->node) {
134+
ExtractTablesFromQueryNode(*entry.second->query->node, results, "from", &select_node.cte_map);
102135
}
103136
}
104-
105137

106138
if (select_node.from_table) {
107-
ExtractTablesFromRef(*select_node.from_table, results, "from", true);
139+
ExtractTablesFromRef(*select_node.from_table, results, context, true, &select_node.cte_map);
108140
}
109141
}
110142
}
@@ -121,6 +153,10 @@ static void MyFunc(ClientContext &context,
121153
parser.ParseQuery(bind_data.sql);
122154

123155
for (auto &stmt : parser.statements) {
156+
if (stmt->type != StatementType::SELECT_STATEMENT) {
157+
throw InvalidInputException("parse_tables only supports SELECT statements");
158+
}
159+
124160
if (stmt->type == StatementType::SELECT_STATEMENT) {
125161
auto &select_stmt = (SelectStatement &)*stmt;
126162
if (select_stmt.node) {

0 commit comments

Comments
 (0)