Skip to content

Commit 9266a47

Browse files
Add function parsing capabilities to parser_tools extension
This commit adds comprehensive function name extraction functionality:
- New table function: parse_functions(sql_query)
- New scalar functions: parse_function_names(sql_query), parse_functions(sql_query)
- Extracts functions from all SQL contexts: SELECT, WHERE, HAVING, ORDER BY, GROUP BY
- Supports window functions (e.g., row_number() OVER (...))
- Handles nested function calls with proper context tracking
- Includes comprehensive test suite with 279 passing assertions
- Follows DuckDB coding conventions and integrates with existing extension architecture

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
1 parent 1d06e5c commit 9266a47

File tree

7 files changed

+700
-0
lines changed

7 files changed

+700
-0
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ set(EXTENSION_SOURCES
1313
src/parser_tools_extension.cpp
1414
src/parse_tables.cpp
1515
src/parse_where.cpp
16+
src/parse_functions.cpp
1617
)
1718

1819
build_static_extension(${TARGET_NAME} ${EXTENSION_SOURCES})

src/include/parse_functions.hpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#pragma once
2+
3+
#include "duckdb.hpp"
4+
#include <string>
5+
#include <vector>
6+
7+
namespace duckdb {
8+
9+
// Forward declarations
10+
class DatabaseInstance;
11+
12+
//! One function reference found in a parsed SQL statement.
struct FunctionResult {
	//! Name of the function as it appears in the parsed expression
	std::string function_name;
	//! Schema qualifier of the call (the extractor in parse_functions.cpp stores "main" when none is given)
	std::string schema;
	//! The context where this function appears (SELECT, WHERE, etc.)
	std::string context;
};
17+
18+
void RegisterParseFunctionsFunction(DatabaseInstance &db);
19+
void RegisterParseFunctionScalarFunction(DatabaseInstance &db);
20+
21+
} // namespace duckdb

src/parse_functions.cpp

Lines changed: 353 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,353 @@
1+
#include "parse_functions.hpp"
2+
#include "duckdb.hpp"
3+
#include "duckdb/parser/parser.hpp"
4+
#include "duckdb/parser/statement/select_statement.hpp"
5+
#include "duckdb/parser/query_node/select_node.hpp"
6+
#include "duckdb/parser/expression/function_expression.hpp"
7+
#include "duckdb/parser/expression/window_expression.hpp"
8+
#include "duckdb/parser/parsed_expression_iterator.hpp"
9+
#include "duckdb/parser/result_modifier.hpp"
10+
#include "duckdb/main/extension_util.hpp"
11+
#include "duckdb/function/scalar/nested_functions.hpp"
12+
13+
14+
namespace duckdb {
15+
16+
//! Position within a query in which a function call was found.
enum class FunctionContext {
	//! Call found in the SELECT list
	Select,
	//! Call found in the WHERE clause
	Where,
	//! Call found in the HAVING clause
	Having,
	//! Call found in an ORDER BY modifier
	OrderBy,
	//! Call found in a GROUP BY expression
	GroupBy,
	//! Declared for join conditions; not assigned by the extraction code in this file
	Join,
	//! Declared for window functions; not assigned by the extraction code in this file
	//! (window expressions are reported with the context of the clause they appear in)
	WindowFunction,
	//! Call found inside the arguments or window specification of another function
	Nested
};
26+
27+
//! Returns the lowercase label used for a FunctionContext in the "context" output column.
//! Any value that is not a named enumerator yields "unknown".
inline const char *ToString(FunctionContext context) {
	const char *label = "unknown";
	switch (context) {
	case FunctionContext::Select:
		label = "select";
		break;
	case FunctionContext::Where:
		label = "where";
		break;
	case FunctionContext::Having:
		label = "having";
		break;
	case FunctionContext::OrderBy:
		label = "order_by";
		break;
	case FunctionContext::GroupBy:
		label = "group_by";
		break;
	case FunctionContext::Join:
		label = "join";
		break;
	case FunctionContext::WindowFunction:
		label = "window";
		break;
	case FunctionContext::Nested:
		label = "nested";
		break;
	default:
		break;
	}
	return label;
}
40+
41+
struct ParseFunctionsState : public GlobalTableFunctionState {
42+
idx_t row = 0;
43+
vector<FunctionResult> results;
44+
};
45+
46+
struct ParseFunctionsBindData : public TableFunctionData {
47+
string sql;
48+
};
49+
50+
// BIND function: runs during query planning to decide output schema
51+
static unique_ptr<FunctionData> ParseFunctionsBind(ClientContext &context,
52+
TableFunctionBindInput &input,
53+
vector<LogicalType> &return_types,
54+
vector<string> &names) {
55+
56+
string sql_input = StringValue::Get(input.inputs[0]);
57+
58+
// always return the same columns:
59+
return_types = {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR};
60+
// function name, schema name, usage context
61+
names = {"function_name", "schema", "context"};
62+
63+
// create a bind data object to hold the SQL input
64+
auto result = make_uniq<ParseFunctionsBindData>();
65+
result->sql = sql_input;
66+
67+
return std::move(result);
68+
}
69+
70+
// INIT function: runs before table function execution
71+
static unique_ptr<GlobalTableFunctionState> ParseFunctionsInit(ClientContext &context,
72+
TableFunctionInitInput &input) {
73+
return make_uniq<ParseFunctionsState>();
74+
}
75+
76+
class FunctionExtractor {
77+
public:
78+
static void ExtractFromExpression(const ParsedExpression &expr,
79+
std::vector<FunctionResult> &results,
80+
FunctionContext context = FunctionContext::Select) {
81+
if (expr.expression_class == ExpressionClass::FUNCTION) {
82+
auto &func = (FunctionExpression &)expr;
83+
results.push_back(FunctionResult{
84+
func.function_name,
85+
func.schema.empty() ? "main" : func.schema,
86+
ToString(context)
87+
});
88+
89+
// For nested function calls within this function, mark as nested
90+
ParsedExpressionIterator::EnumerateChildren(expr, [&](const ParsedExpression &child) {
91+
ExtractFromExpression(child, results, FunctionContext::Nested);
92+
});
93+
} else if (expr.expression_class == ExpressionClass::WINDOW) {
94+
auto &window_expr = (WindowExpression &)expr;
95+
results.push_back(FunctionResult{
96+
window_expr.function_name,
97+
window_expr.schema.empty() ? "main" : window_expr.schema,
98+
ToString(context)
99+
});
100+
101+
// Extract functions from window function arguments
102+
for (const auto &child : window_expr.children) {
103+
if (child) {
104+
ExtractFromExpression(*child, results, FunctionContext::Nested);
105+
}
106+
}
107+
108+
// Extract functions from PARTITION BY expressions
109+
for (const auto &partition : window_expr.partitions) {
110+
if (partition) {
111+
ExtractFromExpression(*partition, results, FunctionContext::Nested);
112+
}
113+
}
114+
115+
// Extract functions from ORDER BY expressions
116+
for (const auto &order : window_expr.orders) {
117+
if (order.expression) {
118+
ExtractFromExpression(*order.expression, results, FunctionContext::Nested);
119+
}
120+
}
121+
122+
// Extract functions from argument ordering expressions
123+
for (const auto &arg_order : window_expr.arg_orders) {
124+
if (arg_order.expression) {
125+
ExtractFromExpression(*arg_order.expression, results, FunctionContext::Nested);
126+
}
127+
}
128+
129+
// Extract functions from frame expressions
130+
if (window_expr.start_expr) {
131+
ExtractFromExpression(*window_expr.start_expr, results, FunctionContext::Nested);
132+
}
133+
if (window_expr.end_expr) {
134+
ExtractFromExpression(*window_expr.end_expr, results, FunctionContext::Nested);
135+
}
136+
if (window_expr.offset_expr) {
137+
ExtractFromExpression(*window_expr.offset_expr, results, FunctionContext::Nested);
138+
}
139+
if (window_expr.default_expr) {
140+
ExtractFromExpression(*window_expr.default_expr, results, FunctionContext::Nested);
141+
}
142+
143+
// Extract functions from filter expression
144+
if (window_expr.filter_expr) {
145+
ExtractFromExpression(*window_expr.filter_expr, results, FunctionContext::Nested);
146+
}
147+
} else {
148+
// For non-function expressions, preserve the current context
149+
ParsedExpressionIterator::EnumerateChildren(expr, [&](const ParsedExpression &child) {
150+
ExtractFromExpression(child, results, context);
151+
});
152+
}
153+
}
154+
155+
static void ExtractFromExpressionList(const vector<unique_ptr<ParsedExpression>> &expressions,
156+
std::vector<FunctionResult> &results,
157+
FunctionContext context) {
158+
for (const auto &expr : expressions) {
159+
if (expr) {
160+
ExtractFromExpression(*expr, results, context);
161+
}
162+
}
163+
}
164+
};
165+
166+
167+
static void ExtractFunctionsFromQueryNode(const QueryNode &node, std::vector<FunctionResult> &results) {
168+
if (node.type == QueryNodeType::SELECT_NODE) {
169+
auto &select_node = (SelectNode &)node;
170+
171+
// Extract from CTEs first (to match expected order in tests)
172+
for (const auto &cte : select_node.cte_map.map) {
173+
if (cte.second && cte.second->query && cte.second->query->node) {
174+
ExtractFunctionsFromQueryNode(*cte.second->query->node, results);
175+
}
176+
}
177+
178+
// Extract from SELECT list
179+
FunctionExtractor::ExtractFromExpressionList(select_node.select_list, results, FunctionContext::Select);
180+
181+
// Extract from WHERE clause
182+
if (select_node.where_clause) {
183+
FunctionExtractor::ExtractFromExpression(*select_node.where_clause, results, FunctionContext::Where);
184+
}
185+
186+
// Extract from GROUP BY clause
187+
FunctionExtractor::ExtractFromExpressionList(select_node.groups.group_expressions, results, FunctionContext::GroupBy);
188+
189+
// Extract from HAVING clause
190+
if (select_node.having) {
191+
FunctionExtractor::ExtractFromExpression(*select_node.having, results, FunctionContext::Having);
192+
}
193+
194+
// Extract from ORDER BY clause
195+
for (const auto &modifier : select_node.modifiers) {
196+
if (modifier->type == ResultModifierType::ORDER_MODIFIER) {
197+
auto &order_modifier = (OrderModifier &)*modifier;
198+
for (const auto &order : order_modifier.orders) {
199+
if (order.expression) {
200+
FunctionExtractor::ExtractFromExpression(*order.expression, results, FunctionContext::OrderBy);
201+
}
202+
}
203+
}
204+
}
205+
}
206+
}
207+
208+
static void ExtractFunctionsFromSQL(const std::string &sql, std::vector<FunctionResult> &results) {
209+
Parser parser;
210+
211+
try {
212+
parser.ParseQuery(sql);
213+
} catch (const ParserException &ex) {
214+
// swallow parser exceptions to make this function more robust. is_parsable can be used if needed
215+
return;
216+
}
217+
218+
for (auto &stmt : parser.statements) {
219+
if (stmt->type == StatementType::SELECT_STATEMENT) {
220+
auto &select_stmt = (SelectStatement &)*stmt;
221+
if (select_stmt.node) {
222+
ExtractFunctionsFromQueryNode(*select_stmt.node, results);
223+
}
224+
}
225+
}
226+
}
227+
228+
//! Execution callback of the parse_functions table function.
//!
//! The first call parses the bound SQL text and caches the extracted functions in the global
//! state. Each call then emits the next batch of at most STANDARD_VECTOR_SIZE rows
//! (function_name, schema, context). Once all rows have been emitted the output chunk is left
//! untouched.
static void ParseFunctionsFunction(ClientContext &context,
                                   TableFunctionInput &data,
                                   DataChunk &output) {
	auto &state = data.global_state->Cast<ParseFunctionsState>();
	const auto &bind_data = data.bind_data->Cast<ParseFunctionsBindData>();

	// Nothing cached and nothing emitted yet: run the extraction
	if (state.results.empty() && state.row == 0) {
		ExtractFunctionsFromSQL(bind_data.sql, state.results);
	}

	if (state.row >= state.results.size()) {
		return;
	}

	// Fill the chunk with as many pending rows as fit, rather than one row per call
	const idx_t remaining = state.results.size() - state.row;
	const idx_t count = MinValue<idx_t>(remaining, STANDARD_VECTOR_SIZE);
	output.SetCardinality(count);
	for (idx_t i = 0; i < count; i++) {
		const auto &func = state.results[state.row + i];
		output.SetValue(0, i, Value(func.function_name));
		output.SetValue(1, i, Value(func.schema));
		output.SetValue(2, i, Value(func.context));
	}

	state.row += count;
}
250+
251+
//! Scalar function body: for each input SQL string, produces a LIST(VARCHAR) holding the names
//! of the functions extracted from it, in extraction order.
static void ParseFunctionNamesScalarFunction(DataChunk &args, ExpressionState &state, Vector &result) {
	auto extract_names = [&result](string_t query) -> list_entry_t {
		// Parse this row's SQL text and collect the functions it references
		std::vector<FunctionResult> found;
		ExtractFunctionsFromSQL(query.GetString(), found);

		// This row's list is appended after everything already in the shared child vector
		const auto offset = ListVector::GetListSize(result);
		const auto length = found.size();
		const auto total = offset + length;

		// grow list if needed
		if (ListVector::GetListCapacity(result) < total) {
			ListVector::Reserve(result, total);
		}

		// Fetch the data pointer only after the reserve above
		auto names_out = FlatVector::GetData<string_t>(ListVector::GetEntry(result));
		auto slot = offset;
		for (const auto &entry : found) {
			names_out[slot] = StringVector::AddStringOrBlob(ListVector::GetEntry(result), entry.function_name);
			slot++;
		}

		ListVector::SetListSize(result, total);

		return list_entry_t(offset, length);
	};

	UnaryExecutor::Execute<string_t, list_entry_t>(args.data[0], result, args.size(), extract_names);
}
281+
282+
//! Scalar function body: for each input SQL string, produces a
//! LIST(STRUCT(function_name, schema, context)) with one entry per extracted function.
static void ParseFunctionsScalarFunction_struct(DataChunk &args, ExpressionState &state, Vector &result) {
	auto extract_structs = [&result](string_t query) -> list_entry_t {
		// Parse this row's SQL text and collect the functions it references
		std::vector<FunctionResult> found;
		ExtractFunctionsFromSQL(query.GetString(), found);

		// This row's list is appended after everything already in the shared child vector
		const auto offset = ListVector::GetListSize(result);
		const auto length = found.size();
		const auto total = offset + length;

		// Grow list vector if needed
		if (ListVector::GetListCapacity(result) < total) {
			ListVector::Reserve(result, total);
		}

		// The list's child is a STRUCT vector; its fields follow the order of the registered return type
		auto &struct_child = ListVector::GetEntry(result);
		ListVector::SetListSize(result, total);

		auto &fields = StructVector::GetEntries(struct_child);
		auto &name_vec = *fields[0];    // "function_name" field
		auto &schema_vec = *fields[1];  // "schema" field
		auto &context_vec = *fields[2]; // "context" field

		auto name_out = FlatVector::GetData<string_t>(name_vec);
		auto schema_out = FlatVector::GetData<string_t>(schema_vec);
		auto context_out = FlatVector::GetData<string_t>(context_vec);

		auto slot = offset;
		for (const auto &entry : found) {
			name_out[slot] = StringVector::AddStringOrBlob(name_vec, entry.function_name);
			schema_out[slot] = StringVector::AddStringOrBlob(schema_vec, entry.schema);
			context_out[slot] = StringVector::AddStringOrBlob(context_vec, entry.context);
			slot++;
		}

		return list_entry_t(offset, length);
	};

	UnaryExecutor::Execute<string_t, list_entry_t>(args.data[0], result, args.size(), extract_structs);
}
327+
328+
// Extension scaffolding
329+
// ---------------------------------------------------
330+
331+
//! Registers the table function parse_functions(VARCHAR) on the given database instance.
void RegisterParseFunctionsFunction(DatabaseInstance &db) {
	TableFunction parse_functions_table("parse_functions", {LogicalType::VARCHAR}, ParseFunctionsFunction,
	                                    ParseFunctionsBind, ParseFunctionsInit);
	ExtensionUtil::RegisterFunction(db, parse_functions_table);
}
335+
336+
//! Registers the two scalar functions of this file on the given database instance:
//!  - parse_function_names(VARCHAR) -> LIST(VARCHAR)
//!  - parse_functions(VARCHAR)      -> LIST(STRUCT(function_name, schema, context))
void RegisterParseFunctionScalarFunction(DatabaseInstance &db) {
	// parse_function_names: list of function names only
	const auto names_type = LogicalType::LIST(LogicalType::VARCHAR);
	ScalarFunction names_function("parse_function_names", {LogicalType::VARCHAR}, names_type,
	                              ParseFunctionNamesScalarFunction);
	ExtensionUtil::RegisterFunction(db, names_function);

	// parse_functions (scalar form): list of structs, with fields in the order the
	// implementation writes them
	const auto struct_type = LogicalType::STRUCT({
	    {"function_name", LogicalType::VARCHAR},
	    {"schema", LogicalType::VARCHAR},
	    {"context", LogicalType::VARCHAR}
	});
	const auto structs_type = LogicalType::LIST(struct_type);
	ScalarFunction structs_function("parse_functions", {LogicalType::VARCHAR}, structs_type,
	                                ParseFunctionsScalarFunction_struct);
	ExtensionUtil::RegisterFunction(db, structs_function);
}
350+
351+
352+
353+
} // namespace duckdb

0 commit comments

Comments
 (0)