77#include " duckdb/parser/tableref/joinref.hpp"
88#include " duckdb/parser/tableref/subqueryref.hpp"
99#include " duckdb/main/extension_util.hpp"
10+ #include " duckdb/function/scalar/nested_functions.hpp"
11+
1012
1113namespace duckdb {
1214
@@ -140,32 +142,28 @@ static void ExtractTablesFromQueryNode(
140142 }
141143}
142144
145+ void ExtractTablesFromSQL (const std::string &sql, std::vector<TableRefResult> &results) {
146+ Parser parser;
147+ parser.ParseQuery (sql);
148+
149+ for (auto &stmt : parser.statements ) {
150+ if (stmt->type == StatementType::SELECT_STATEMENT) {
151+ auto &select_stmt = (SelectStatement &)*stmt;
152+ if (select_stmt.node ) {
153+ ExtractTablesFromQueryNode (*select_stmt.node , results);
154+ }
155+ }
156+ }
157+ }
158+
143159static void ParseTablesFunction (ClientContext &context,
144160 TableFunctionInput &data,
145161 DataChunk &output) {
146162 auto &state = (ParseTablesState &)*data.global_state ;
147163 auto &bind_data = (ParseTablesBindData &)*data.bind_data ;
148164
149165 if (state.results .empty () && state.row == 0 ) {
150- try {
151- Parser parser;
152- parser.ParseQuery (bind_data.sql );
153-
154- for (auto &stmt : parser.statements ) {
155- if (stmt->type != StatementType::SELECT_STATEMENT) {
156- throw InvalidInputException (" parse_tables only supports SELECT statements" );
157- }
158-
159- if (stmt->type == StatementType::SELECT_STATEMENT) {
160- auto &select_stmt = (SelectStatement &)*stmt;
161- if (select_stmt.node ) {
162- ExtractTablesFromQueryNode (*select_stmt.node , state.results );
163- }
164- }
165- }
166- } catch (const std::exception &ex) {
167- throw InvalidInputException (" Failed to parse SQL: %s" , ex.what ());
168- }
166+ ExtractTablesFromSQL (bind_data.sql , state.results );
169167 }
170168
171169 if (state.row >= state.results .size ()) {
@@ -181,6 +179,41 @@ static void ParseTablesFunction(ClientContext &context,
181179 state.row ++;
182180}
183181
182+ static void ParseTablesScalarFunction (DataChunk &args, ExpressionState &state, Vector &result) {
183+ // Execute does the heavy lifting of iterating over the input data
184+ // and calling the provided lambda function for each input value.
185+ // The lambda function is responsible for parsing the SQL query and
186+ // extracting the table names.
187+ UnaryExecutor::Execute<string_t , list_entry_t >(args.data [0 ], result, args.size (),
188+ [&result](string_t query) -> list_entry_t {
189+ // Parse the SQL query and extract table names
190+ auto query_string = query.GetString ();
191+ std::vector<TableRefResult> parsed_tables;
192+ ExtractTablesFromSQL (query_string, parsed_tables);
193+
194+ auto current_size = ListVector::GetListSize (result);
195+ auto number_of_tables = parsed_tables.size ();
196+ auto new_size = current_size + number_of_tables;
197+
198+ // grow list if needed
199+ if (ListVector::GetListCapacity (result) < new_size) {
200+ ListVector::Reserve (result, new_size);
201+ }
202+
203+ // Write the string into the child vector
204+ auto tables = FlatVector::GetData<string_t >(ListVector::GetEntry (result));
205+ for (size_t i = 0 ; i < parsed_tables.size (); i++) {
206+ auto &table = parsed_tables[i];
207+ tables[current_size + i] = StringVector::AddStringOrBlob (ListVector::GetEntry (result), table.table );
208+ }
209+
210+ // Update size
211+ ListVector::SetListSize (result, new_size);
212+
213+ return list_entry_t (current_size, number_of_tables);
214+ });
215+ }
216+
184217// Extension scaffolding
185218// ---------------------------------------------------
186219
@@ -189,4 +222,9 @@ void RegisterParseTablesFunction(DatabaseInstance &db) {
189222 ExtensionUtil::RegisterFunction (db, tf);
190223}
191224
225+ void RegisterParseTableScalarFunction (DatabaseInstance &db) {
226+ ScalarFunction sf ( " parse_tables" , {LogicalType::VARCHAR}, LogicalType::LIST (LogicalType::VARCHAR), ParseTablesScalarFunction);
227+ ExtensionUtil::RegisterFunction (db, sf);
228+ }
229+
192230} // namespace duckdb
0 commit comments