77#include " duckdb/function/scalar_function.hpp"
88#include " duckdb/main/extension_util.hpp"
99#include < duckdb/parser/parsed_data/create_scalar_function_info.hpp>
10-
1110#include " duckdb/parser/parser.hpp"
1211#include " duckdb/parser/statement/select_statement.hpp"
1312#include " duckdb/parser/query_node/select_node.hpp"
1413#include " duckdb/parser/tableref/basetableref.hpp"
1514#include " duckdb/parser/tableref/joinref.hpp"
1615#include " duckdb/parser/tableref/subqueryref.hpp"
16+ #include " duckdb/parser/statement/insert_statement.hpp"
1717
1818namespace duckdb {
1919
@@ -60,29 +60,50 @@ static unique_ptr<GlobalTableFunctionState> MyInit(ClientContext &context,
6060 return make_uniq<ParseTablesState>();
6161}
6262
63- static void ExtractTablesFromQueryNode (const QueryNode &node, vector<TableRefResult> &results);
63+ static void ExtractTablesFromQueryNode (
64+ const duckdb::QueryNode &node,
65+ std::vector<TableRefResult> &results,
66+ const std::string &context = " from" ,
67+ const duckdb::CommonTableExpressionMap *cte_map = nullptr
68+ );
69+
70+ static void ExtractTablesFromRef (
71+ const duckdb::TableRef &ref,
72+ std::vector<TableRefResult> &results,
73+ const std::string &context = " from" ,
74+ bool is_top_level = false ,
75+ const duckdb::CommonTableExpressionMap *cte_map = nullptr
76+ ) {
77+ using namespace duckdb ;
6478
65- static void ExtractTablesFromRef (const TableRef &ref, vector<TableRefResult> &results, const string &context = " from" , bool is_top_level = false ) {
6679 switch (ref.type ) {
6780 case TableReferenceType::BASE_TABLE: {
6881 auto &base = (BaseTableRef &)ref;
82+ std::string context_label = context;
83+
84+ if (cte_map && cte_map->map .find (base.table_name ) != cte_map->map .end ()) {
85+ context_label = " from_cte" ;
86+ } else if (is_top_level) {
87+ context_label = " from" ;
88+ }
89+
6990 results.push_back (TableRefResult{
7091 base.schema_name .empty () ? " main" : base.schema_name ,
7192 base.table_name ,
72- is_top_level ? " from " : context
93+ context_label
7394 });
7495 break ;
7596 }
7697 case TableReferenceType::JOIN: {
7798 auto &join = (JoinRef &)ref;
78- ExtractTablesFromRef (*join.left , results, " join_left" , is_top_level);
79- ExtractTablesFromRef (*join.right , results, " join_right" );
99+ ExtractTablesFromRef (*join.left , results, " join_left" , is_top_level, cte_map );
100+ ExtractTablesFromRef (*join.right , results, " join_right" , false , cte_map );
80101 break ;
81102 }
82103 case TableReferenceType::SUBQUERY: {
83104 auto &subquery = (SubqueryRef &)ref;
84105 if (subquery.subquery && subquery.subquery ->node ) {
85- ExtractTablesFromQueryNode (*subquery.subquery ->node , results);
106+ ExtractTablesFromQueryNode (*subquery.subquery ->node , results, " subquery " , cte_map );
86107 }
87108 break ;
88109 }
@@ -91,20 +112,31 @@ static void ExtractTablesFromRef(const TableRef &ref, vector<TableRefResult> &re
91112 }
92113}
93114
94- static void ExtractTablesFromQueryNode (const QueryNode &node, vector<TableRefResult> &results) {
115+
116+ static void ExtractTablesFromQueryNode (
117+ const duckdb::QueryNode &node,
118+ std::vector<TableRefResult> &results,
119+ const std::string &context,
120+ const duckdb::CommonTableExpressionMap *cte_map
121+ ) {
122+ using namespace duckdb ;
123+
95124 if (node.type == QueryNodeType::SELECT_NODE) {
96125 auto &select_node = (SelectNode &)node;
97126
98- // Handle CTEs
127+ // Emit CTE definitions
99128 for (const auto &entry : select_node.cte_map .map ) {
100- if (entry.second && entry.second ->query ) {
101- ExtractTablesFromQueryNode (*entry.second ->query ->node , results);
129+ results.push_back (TableRefResult{
130+ " " , entry.first , " cte"
131+ });
132+
133+ if (entry.second && entry.second ->query && entry.second ->query ->node ) {
134+ ExtractTablesFromQueryNode (*entry.second ->query ->node , results, " from" , &select_node.cte_map );
102135 }
103136 }
104-
105137
106138 if (select_node.from_table ) {
107- ExtractTablesFromRef (*select_node.from_table , results, " from " , true );
139+ ExtractTablesFromRef (*select_node.from_table , results, context , true , &select_node. cte_map );
108140 }
109141 }
110142}
@@ -121,6 +153,10 @@ static void MyFunc(ClientContext &context,
121153 parser.ParseQuery (bind_data.sql );
122154
123155 for (auto &stmt : parser.statements ) {
156+ if (stmt->type != StatementType::SELECT_STATEMENT) {
157+ throw InvalidInputException (" parse_tables only supports SELECT statements" );
158+ }
159+
124160 if (stmt->type == StatementType::SELECT_STATEMENT) {
125161 auto &select_stmt = (SelectStatement &)*stmt;
126162 if (select_stmt.node ) {
0 commit comments