Skip to content

Commit 3e23b09

Browse files
committed
refactor to support excluding CTE
1 parent 0282e54 commit 3e23b09

File tree

2 files changed

+59
-5
lines changed

2 files changed

+59
-5
lines changed

src/include/parse_tables.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ enum class TableContext {
1717
};
1818

1919
const char *ToString(TableContext context);
20+
// Maps a context name ("from", "join_left", "join_right", "from_cte", "cte",
// "subquery") to the corresponding TableContext enumerator.
// Throws InternalException when the name is not one of the recognised strings.
const TableContext FromString(const char *context);
2021

2122
struct TableRefResult {
2223
std::string schema;

src/parse_tables.cpp

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@ inline const char *ToString(TableContext context) {
2424
}
2525
}
2626

27+
// Maps the textual name of a table context to its TableContext enumerator.
//
// Recognised names: "from", "join_left", "join_right", "from_cte", "cte" and
// "subquery". Presumably this is the inverse of ToString(TableContext) --
// confirm the spellings stay in sync with that function.
//
// @param context  NUL-terminated context name.
// @return         The matching TableContext value.
// @throws InternalException if `context` is null or not a recognised name.
inline const TableContext FromString(const char *context) {
	// strcmp() and the "%s" formatting below are both undefined for a null
	// pointer, so reject it explicitly instead of crashing.
	if (context == nullptr) {
		throw InternalException("Unknown table context: (null)");
	}

	struct NamedContext {
		const char *name;
		TableContext value;
	};
	static const NamedContext known_contexts[] = {
	    {"from", TableContext::From},
	    {"join_left", TableContext::JoinLeft},
	    {"join_right", TableContext::JoinRight},
	    {"from_cte", TableContext::FromCTE},
	    {"cte", TableContext::CTE},
	    {"subquery", TableContext::Subquery},
	};

	for (const auto &entry : known_contexts) {
		if (strcmp(context, entry.name) == 0) {
			return entry.value;
		}
	}
	throw InternalException("Unknown table context: %s", context);
}
36+
2737
struct ParseTablesState : public GlobalTableFunctionState {
2838
idx_t row = 0;
2939
vector<TableRefResult> results;
@@ -156,6 +166,22 @@ void ExtractTablesFromSQL(const std::string &sql, std::vector<TableRefResult> &r
156166
}
157167
}
158168

169+
void ExtractTablesFromSQL(const std::string & sql, std::vector<TableRefResult> &result, std::unordered_set<std::string> excluded_types) {
170+
std::vector<TableRefResult> temp_result;
171+
ExtractTablesFromSQL(sql, temp_result);
172+
std::unordered_set<TableContext> e_types;
173+
174+
for (auto &type : excluded_types) {
175+
e_types.insert(FromString(type.c_str()));
176+
}
177+
178+
for (auto &table : temp_result) {
179+
if (e_types.count(table.context) == 0) {
180+
result.push_back(table);
181+
}
182+
}
183+
}
184+
159185
static void ParseTablesFunction(ClientContext &context,
160186
TableFunctionInput &data,
161187
DataChunk &output) {
@@ -180,16 +206,37 @@ static void ParseTablesFunction(ClientContext &context,
180206
}
181207

182208
static void ParseTablesScalarFunction(DataChunk &args, ExpressionState &state, Vector &result) {
209+
Vector flag(LogicalType::BOOLEAN);
210+
211+
// Allow for the optional boolean argument (exclude_cte). If it is not provided, it defaults to true, i.e. CTE entries are excluded.
212+
if (args.ColumnCount() == 1) {
213+
// create a default argument to pass below. we'll use a constant vector since all values are the same
214+
Vector c(LogicalType::BOOLEAN);
215+
c.Reference(Value::BOOLEAN(true));
216+
ConstantVector::Reference(flag, c, 0, args.size());
217+
} else if (args.ColumnCount() == 2) {
218+
flag.Reference(args.data[1]);
219+
} else {
220+
throw InvalidInputException("parse_tables() expects 1 or 2 arguments");
221+
}
222+
183223
// Execute does the heavy lifting of iterating over the input data
184224
// and calling the provided lambda function for each input value.
185225
// The lambda function is responsible for parsing the SQL query and
186226
// extracting the table names.
187-
UnaryExecutor::Execute<string_t, list_entry_t>(args.data[0], result, args.size(),
188-
[&result](string_t query) -> list_entry_t {
227+
BinaryExecutor::Execute<string_t, bool, list_entry_t>(args.data[0], flag, result, args.size(),
228+
[&result](string_t query, bool exclude_cte) -> list_entry_t {
189229
// Parse the SQL query and extract table names
190230
auto query_string = query.GetString();
191231
std::vector<TableRefResult> parsed_tables;
192-
ExtractTablesFromSQL(query_string, parsed_tables);
232+
if (exclude_cte) {
233+
std::unordered_set<std::string> excluded_types;
234+
excluded_types.insert("cte");
235+
ExtractTablesFromSQL(query_string, parsed_tables, excluded_types);
236+
} else {
237+
ExtractTablesFromSQL(query_string, parsed_tables);
238+
}
239+
193240

194241
auto current_size = ListVector::GetListSize(result);
195242
auto number_of_tables = parsed_tables.size();
@@ -223,8 +270,14 @@ void RegisterParseTablesFunction(DatabaseInstance &db) {
223270
}
224271

225272
// Registers the scalar function `parse_tables` on the given database instance.
//
// parse_tables is overloaded, allowing for an optional boolean argument:
//   parse_tables(sql_query)               -> LIST(VARCHAR)
//   parse_tables(sql_query, exclude_cte)  -> LIST(VARCHAR)
// Both overloads are served by ParseTablesScalarFunction. The boolean is an
// *exclude* flag: when true, table references whose context is "cte" are
// removed from the result. The one-argument form behaves as if true was passed.
void RegisterParseTableScalarFunction(DatabaseInstance &db) {
	const auto return_type = LogicalType::LIST(LogicalType::VARCHAR);

	ScalarFunctionSet set("parse_tables");
	set.AddFunction(ScalarFunction({LogicalType::VARCHAR}, return_type, ParseTablesScalarFunction));
	set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::BOOLEAN}, return_type, ParseTablesScalarFunction));

	ExtensionUtil::RegisterFunction(db, set);
}
229282

230283
} // namespace duckdb

0 commit comments

Comments
 (0)