|
8 | 8 | #include "duckdb/main/extension_util.hpp" |
9 | 9 | #include <duckdb/parser/parsed_data/create_scalar_function_info.hpp> |
10 | 10 |
|
| 11 | +#include "duckdb/parser/parser.hpp" |
| 12 | +#include "duckdb/parser/statement/select_statement.hpp" |
| 13 | +#include "duckdb/parser/query_node/select_node.hpp" |
| 14 | +#include "duckdb/parser/tableref/basetableref.hpp" |
| 15 | +#include "duckdb/parser/tableref/joinref.hpp" |
| 16 | +#include "duckdb/parser/tableref/subqueryref.hpp" |
11 | 17 |
|
12 | 18 | namespace duckdb { |
13 | 19 |
|
@@ -54,31 +60,107 @@ static unique_ptr<GlobalTableFunctionState> MyInit(ClientContext &context, |
54 | 60 | return make_uniq<ParseTablesState>(); |
55 | 61 | } |
56 | 62 |
|
57 | | -// EXECUTE function: produces rows |
58 | | -static void MyFunc(ClientContext &context, |
59 | | - TableFunctionInput &data, |
60 | | - DataChunk &output) { |
| 63 | +static void ExtractTablesFromQueryNode(const QueryNode &node, vector<TableRefResult> &results); |
61 | 64 |
|
62 | | - auto &state = (ParseTablesState &)*data.global_state; |
| 65 | +static void ExtractTablesFromRef(const TableRef &ref, vector<TableRefResult> &results, const string &context = "from") { |
| 66 | + std::cout << "Ref type: " << (int)ref.type << std::endl; |
| 67 | + if (ref.type == TableReferenceType::BASE_TABLE) { |
| 68 | + auto &base = (BaseTableRef &)ref; |
| 69 | + std::cout << "Found base table: " << base.schema_name << "." << base.table_name << std::endl; |
| 70 | + } |
| 71 | + |
| 72 | + |
| 73 | + switch (ref.type) { |
| 74 | + case TableReferenceType::BASE_TABLE: { |
| 75 | + auto &base = (BaseTableRef &)ref; |
| 76 | + results.push_back(TableRefResult{ |
| 77 | + base.schema_name.empty() ? "main" : base.schema_name, |
| 78 | + base.table_name, |
| 79 | + context |
| 80 | + }); |
| 81 | + break; |
| 82 | + } |
| 83 | + case TableReferenceType::JOIN: { |
| 84 | + auto &join = (JoinRef &)ref; |
| 85 | + ExtractTablesFromRef(*join.left, results, "join_left"); |
| 86 | + ExtractTablesFromRef(*join.right, results, "join_right"); |
| 87 | + break; |
| 88 | + } |
| 89 | + case TableReferenceType::SUBQUERY: { |
| 90 | + auto &subquery = (SubqueryRef &)ref; |
| 91 | + if (subquery.subquery && subquery.subquery->node) { |
| 92 | + ExtractTablesFromQueryNode(*subquery.subquery->node, results); |
| 93 | + } |
| 94 | + break; |
| 95 | + } |
| 96 | + default: |
| 97 | + break; |
| 98 | + } |
| 99 | +} |
| 100 | + |
| 101 | +static void ExtractTablesFromQueryNode(const QueryNode &node, vector<TableRefResult> &results) { |
| 102 | + if (node.type == QueryNodeType::SELECT_NODE) { |
| 103 | + auto &select_node = (SelectNode &)node; |
| 104 | + |
| 105 | + std::cout << "Extracting from query node" << std::endl; |
63 | 106 |
|
64 | | - std::cout << "row: " << state.row << std::endl; |
65 | 107 |
|
| 108 | + // Handle CTEs |
| 109 | + |
| 110 | + for (const auto &entry : select_node.cte_map.map) { |
| 111 | + if (entry.second && entry.second->query) { |
| 112 | + ExtractTablesFromQueryNode(*entry.second->query->node, results); |
| 113 | + } |
| 114 | + } |
| 115 | + |
| 116 | + |
| 117 | + if (select_node.from_table) { |
| 118 | + ExtractTablesFromRef(*select_node.from_table, results, "from"); |
| 119 | + } |
| 120 | + } |
| 121 | +} |
| 122 | + |
| 123 | +static void MyFunc(ClientContext &context, |
| 124 | + TableFunctionInput &data, |
| 125 | + DataChunk &output) { |
| 126 | + auto &state = (ParseTablesState &)*data.global_state; |
66 | 127 | auto &bind_data = (ParseTablesBindData &)*data.bind_data; |
67 | 128 |
|
68 | | - std::cout << "Executing for SQL: " << bind_data.sql << std::endl; |
| 129 | + static vector<TableRefResult> results; |
| 130 | + static bool parsed = false; |
| 131 | + |
| 132 | + if (!parsed) { |
| 133 | + try { |
| 134 | + Parser parser; |
| 135 | + parser.ParseQuery(bind_data.sql); |
| 136 | + |
| 137 | + std::cout << "Parsed " << parser.statements.size() << " statements" << std::endl; |
| 138 | + |
| 139 | + |
| 140 | + for (auto &stmt : parser.statements) { |
| 141 | + if (stmt->type == StatementType::SELECT_STATEMENT) { |
| 142 | + auto &select_stmt = (SelectStatement &)*stmt; |
| 143 | + if (select_stmt.node) { |
| 144 | + ExtractTablesFromQueryNode(*select_stmt.node, results); |
| 145 | + } |
| 146 | + } |
| 147 | + } |
| 148 | + parsed = true; |
| 149 | + } catch (const std::exception &ex) { |
| 150 | + throw InvalidInputException("Failed to parse SQL: %s", ex.what()); |
| 151 | + } |
| 152 | + } |
69 | 153 |
|
70 | | - if (state.row >= 1) { |
71 | | - return; // no more rows to produce |
| 154 | + if (state.row >= results.size()) { |
| 155 | + return; |
72 | 156 | } |
73 | 157 |
|
74 | | - // Example: single string column with 1 row |
75 | | - // auto row_count = 1; |
| 158 | + auto &ref = results[state.row]; |
76 | 159 | output.SetCardinality(1); |
| 160 | + output.SetValue(0, 0, Value(ref.schema)); |
| 161 | + output.SetValue(1, 0, Value(ref.table)); |
| 162 | + output.SetValue(2, 0, Value(ref.context)); |
77 | 163 |
|
78 | | - output.SetValue(0, 0, Value("my_schema")); |
79 | | - output.SetValue(1, 0, Value("my_table")); |
80 | | - output.SetValue(2, 0, Value("from")); |
81 | | - |
82 | 164 | state.row++; |
83 | 165 | } |
84 | 166 |
|
|
0 commit comments