@@ -24,6 +24,16 @@ inline const char *ToString(TableContext context) {
2424 }
2525}
2626
27+ inline const TableContext FromString (const char *context) {
28+ if (strcmp (context, " from" ) == 0 ) return TableContext::From;
29+ if (strcmp (context, " join_left" ) == 0 ) return TableContext::JoinLeft;
30+ if (strcmp (context, " join_right" ) == 0 ) return TableContext::JoinRight;
31+ if (strcmp (context, " from_cte" ) == 0 ) return TableContext::FromCTE;
32+ if (strcmp (context, " cte" ) == 0 ) return TableContext::CTE;
33+ if (strcmp (context, " subquery" ) == 0 ) return TableContext::Subquery;
34+ throw InternalException (" Unknown table context: %s" , context);
35+ }
36+
2737struct ParseTablesState : public GlobalTableFunctionState {
2838 idx_t row = 0 ;
2939 vector<TableRefResult> results;
@@ -156,6 +166,22 @@ void ExtractTablesFromSQL(const std::string &sql, std::vector<TableRefResult> &r
156166 }
157167}
158168
169+ void ExtractTablesFromSQL (const std::string & sql, std::vector<TableRefResult> &result, std::unordered_set<std::string> excluded_types) {
170+ std::vector<TableRefResult> temp_result;
171+ ExtractTablesFromSQL (sql, temp_result);
172+ std::unordered_set<TableContext> e_types;
173+
174+ for (auto &type : excluded_types) {
175+ e_types.insert (FromString (type.c_str ()));
176+ }
177+
178+ for (auto &table : temp_result) {
179+ if (e_types.count (table.context ) == 0 ) {
180+ result.push_back (table);
181+ }
182+ }
183+ }
184+
159185static void ParseTablesFunction (ClientContext &context,
160186 TableFunctionInput &data,
161187 DataChunk &output) {
@@ -180,16 +206,37 @@ static void ParseTablesFunction(ClientContext &context,
180206}
181207
182208static void ParseTablesScalarFunction (DataChunk &args, ExpressionState &state, Vector &result) {
209+ Vector flag (LogicalType::BOOLEAN);
210+
211+ // Allow for the optional boolean argument. if not provided, default to true
212+ if (args.ColumnCount () == 1 ) {
213+ // create a default argument to pass below. we'll use a constant vector since all values are the same
214+ Vector c (LogicalType::BOOLEAN);
215+ c.Reference (Value::BOOLEAN (true ));
216+ ConstantVector::Reference (flag, c, 0 , args.size ());
217+ } else if (args.ColumnCount () == 2 ) {
218+ flag.Reference (args.data [1 ]);
219+ } else {
220+ throw InvalidInputException (" parse_tables() expects 1 or 2 arguments" );
221+ }
222+
183223 // Execute does the heavy lifting of iterating over the input data
184224 // and calling the provided lambda function for each input value.
185225 // The lambda function is responsible for parsing the SQL query and
186226 // extracting the table names.
187- UnaryExecutor ::Execute<string_t , list_entry_t >(args.data [0 ], result, args.size (),
188- [&result](string_t query) -> list_entry_t {
227+ BinaryExecutor ::Execute<string_t , bool , list_entry_t >(args.data [0 ], flag , result, args.size (),
228+ [&result](string_t query, bool exclude_cte ) -> list_entry_t {
189229 // Parse the SQL query and extract table names
190230 auto query_string = query.GetString ();
191231 std::vector<TableRefResult> parsed_tables;
192- ExtractTablesFromSQL (query_string, parsed_tables);
232+ if (exclude_cte) {
233+ std::unordered_set<std::string> excluded_types;
234+ excluded_types.insert (" cte" );
235+ ExtractTablesFromSQL (query_string, parsed_tables, excluded_types);
236+ } else {
237+ ExtractTablesFromSQL (query_string, parsed_tables);
238+ }
239+
193240
194241 auto current_size = ListVector::GetListSize (result);
195242 auto number_of_tables = parsed_tables.size ();
@@ -223,8 +270,14 @@ void RegisterParseTablesFunction(DatabaseInstance &db) {
223270}
224271
225272void RegisterParseTableScalarFunction (DatabaseInstance &db) {
226- ScalarFunction sf ( " parse_tables" , {LogicalType::VARCHAR}, LogicalType::LIST (LogicalType::VARCHAR), ParseTablesScalarFunction);
227- ExtensionUtil::RegisterFunction (db, sf);
273+ // parse tables is overloaded, allowing for an optional boolean argument
274+ // that indicates whether to include CTEs in the result
275+ // usage: parse_tables(sql_query [, include_cte])
276+ ScalarFunctionSet set (" parse_tables" );
277+ set.AddFunction (ScalarFunction ({LogicalType::VARCHAR}, LogicalType::LIST (LogicalType::VARCHAR), ParseTablesScalarFunction));
278+ set.AddFunction (ScalarFunction ({LogicalType::VARCHAR, LogicalType::BOOLEAN}, LogicalType::LIST (LogicalType::VARCHAR), ParseTablesScalarFunction));
279+
280+ ExtensionUtil::RegisterFunction (db, set);
228281}
229282
230283} // namespace duckdb
0 commit comments