Skip to content

Commit 0282e54

Browse files
committed
Add scalar form of parse_tables
1 parent c15a216 commit 0282e54

File tree

3 files changed

+59
-19
lines changed

3 files changed

+59
-19
lines changed

src/include/parse_tables.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,6 @@ struct TableRefResult {
2727
void ExtractTablesFromSQL(const std::string &sql, std::vector<TableRefResult> &results);
2828

2929
void RegisterParseTablesFunction(duckdb::DatabaseInstance &db);
30+
void RegisterParseTableScalarFunction(DatabaseInstance &db);
3031

3132
} // namespace duckdb

src/parse_tables.cpp

Lines changed: 57 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#include "duckdb/parser/tableref/joinref.hpp"
88
#include "duckdb/parser/tableref/subqueryref.hpp"
99
#include "duckdb/main/extension_util.hpp"
10+
#include "duckdb/function/scalar/nested_functions.hpp"
11+
1012

1113
namespace duckdb {
1214

@@ -140,32 +142,28 @@ static void ExtractTablesFromQueryNode(
140142
}
141143
}
142144

145+
void ExtractTablesFromSQL(const std::string &sql, std::vector<TableRefResult> &results) {
146+
Parser parser;
147+
parser.ParseQuery(sql);
148+
149+
for (auto &stmt : parser.statements) {
150+
if (stmt->type == StatementType::SELECT_STATEMENT) {
151+
auto &select_stmt = (SelectStatement &)*stmt;
152+
if (select_stmt.node) {
153+
ExtractTablesFromQueryNode(*select_stmt.node, results);
154+
}
155+
}
156+
}
157+
}
158+
143159
static void ParseTablesFunction(ClientContext &context,
144160
TableFunctionInput &data,
145161
DataChunk &output) {
146162
auto &state = (ParseTablesState &)*data.global_state;
147163
auto &bind_data = (ParseTablesBindData &)*data.bind_data;
148164

149165
if (state.results.empty() && state.row == 0) {
150-
try {
151-
Parser parser;
152-
parser.ParseQuery(bind_data.sql);
153-
154-
for (auto &stmt : parser.statements) {
155-
if (stmt->type != StatementType::SELECT_STATEMENT) {
156-
throw InvalidInputException("parse_tables only supports SELECT statements");
157-
}
158-
159-
if (stmt->type == StatementType::SELECT_STATEMENT) {
160-
auto &select_stmt = (SelectStatement &)*stmt;
161-
if (select_stmt.node) {
162-
ExtractTablesFromQueryNode(*select_stmt.node, state.results);
163-
}
164-
}
165-
}
166-
} catch (const std::exception &ex) {
167-
throw InvalidInputException("Failed to parse SQL: %s", ex.what());
168-
}
166+
ExtractTablesFromSQL(bind_data.sql, state.results);
169167
}
170168

171169
if (state.row >= state.results.size()) {
@@ -181,6 +179,41 @@ static void ParseTablesFunction(ClientContext &context,
181179
state.row++;
182180
}
183181

182+
static void ParseTablesScalarFunction(DataChunk &args, ExpressionState &state, Vector &result) {
183+
// Execute does the heavy lifting of iterating over the input data
184+
// and calling the provided lambda function for each input value.
185+
// The lambda function is responsible for parsing the SQL query and
186+
// extracting the table names.
187+
UnaryExecutor::Execute<string_t, list_entry_t>(args.data[0], result, args.size(),
188+
[&result](string_t query) -> list_entry_t {
189+
// Parse the SQL query and extract table names
190+
auto query_string = query.GetString();
191+
std::vector<TableRefResult> parsed_tables;
192+
ExtractTablesFromSQL(query_string, parsed_tables);
193+
194+
auto current_size = ListVector::GetListSize(result);
195+
auto number_of_tables = parsed_tables.size();
196+
auto new_size = current_size + number_of_tables;
197+
198+
// grow list if needed
199+
if (ListVector::GetListCapacity(result) < new_size) {
200+
ListVector::Reserve(result, new_size);
201+
}
202+
203+
// Write the string into the child vector
204+
auto tables = FlatVector::GetData<string_t>(ListVector::GetEntry(result));
205+
for (size_t i = 0; i < parsed_tables.size(); i++) {
206+
auto &table = parsed_tables[i];
207+
tables[current_size + i] = StringVector::AddStringOrBlob(ListVector::GetEntry(result), table.table);
208+
}
209+
210+
// Update size
211+
ListVector::SetListSize(result, new_size);
212+
213+
return list_entry_t(current_size, number_of_tables);
214+
});
215+
}
216+
184217
// Extension scaffolding
185218
// ---------------------------------------------------
186219

@@ -189,4 +222,9 @@ void RegisterParseTablesFunction(DatabaseInstance &db) {
189222
ExtensionUtil::RegisterFunction(db, tf);
190223
}
191224

225+
void RegisterParseTableScalarFunction(DatabaseInstance &db) {
226+
ScalarFunction sf( "parse_tables", {LogicalType::VARCHAR}, LogicalType::LIST(LogicalType::VARCHAR), ParseTablesScalarFunction);
227+
ExtensionUtil::RegisterFunction(db, sf);
228+
}
229+
192230
} // namespace duckdb

src/parser_tools_extension.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ namespace duckdb {
2222

2323
static void LoadInternal(DatabaseInstance &instance) {
2424
RegisterParseTablesFunction(instance);
25+
RegisterParseTableScalarFunction(instance);
2526
}
2627

2728
void ParserToolsExtension::Load(DuckDB &db) {

0 commit comments

Comments
 (0)