|
| 1 | +#include "parse_functions.hpp" |
| 2 | +#include "duckdb.hpp" |
| 3 | +#include "duckdb/parser/parser.hpp" |
| 4 | +#include "duckdb/parser/statement/select_statement.hpp" |
| 5 | +#include "duckdb/parser/query_node/select_node.hpp" |
| 6 | +#include "duckdb/parser/expression/function_expression.hpp" |
| 7 | +#include "duckdb/parser/expression/window_expression.hpp" |
| 8 | +#include "duckdb/parser/parsed_expression_iterator.hpp" |
| 9 | +#include "duckdb/parser/result_modifier.hpp" |
| 10 | +#include "duckdb/main/extension_util.hpp" |
| 11 | +#include "duckdb/function/scalar/nested_functions.hpp" |
| 12 | + |
| 13 | + |
| 14 | +namespace duckdb { |
| 15 | + |
| 16 | +enum class FunctionContext { |
| 17 | + Select, |
| 18 | + Where, |
| 19 | + Having, |
| 20 | + OrderBy, |
| 21 | + GroupBy, |
| 22 | + Join, |
| 23 | + WindowFunction, |
| 24 | + Nested |
| 25 | +}; |
| 26 | + |
| 27 | +inline const char *ToString(FunctionContext context) { |
| 28 | + switch (context) { |
| 29 | + case FunctionContext::Select: return "select"; |
| 30 | + case FunctionContext::Where: return "where"; |
| 31 | + case FunctionContext::Having: return "having"; |
| 32 | + case FunctionContext::OrderBy: return "order_by"; |
| 33 | + case FunctionContext::GroupBy: return "group_by"; |
| 34 | + case FunctionContext::Join: return "join"; |
| 35 | + case FunctionContext::WindowFunction: return "window"; |
| 36 | + case FunctionContext::Nested: return "nested"; |
| 37 | + default: return "unknown"; |
| 38 | + } |
| 39 | +} |
| 40 | + |
| 41 | +struct ParseFunctionsState : public GlobalTableFunctionState { |
| 42 | + idx_t row = 0; |
| 43 | + vector<FunctionResult> results; |
| 44 | +}; |
| 45 | + |
| 46 | +struct ParseFunctionsBindData : public TableFunctionData { |
| 47 | + string sql; |
| 48 | +}; |
| 49 | + |
| 50 | +// BIND function: runs during query planning to decide output schema |
| 51 | +static unique_ptr<FunctionData> ParseFunctionsBind(ClientContext &context, |
| 52 | + TableFunctionBindInput &input, |
| 53 | + vector<LogicalType> &return_types, |
| 54 | + vector<string> &names) { |
| 55 | + |
| 56 | + string sql_input = StringValue::Get(input.inputs[0]); |
| 57 | + |
| 58 | + // always return the same columns: |
| 59 | + return_types = {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}; |
| 60 | + // function name, schema name, usage context |
| 61 | + names = {"function_name", "schema", "context"}; |
| 62 | + |
| 63 | + // create a bind data object to hold the SQL input |
| 64 | + auto result = make_uniq<ParseFunctionsBindData>(); |
| 65 | + result->sql = sql_input; |
| 66 | + |
| 67 | + return std::move(result); |
| 68 | +} |
| 69 | + |
| 70 | +// INIT function: runs before table function execution |
| 71 | +static unique_ptr<GlobalTableFunctionState> ParseFunctionsInit(ClientContext &context, |
| 72 | + TableFunctionInitInput &input) { |
| 73 | + return make_uniq<ParseFunctionsState>(); |
| 74 | +} |
| 75 | + |
| 76 | +class FunctionExtractor { |
| 77 | +public: |
| 78 | + static void ExtractFromExpression(const ParsedExpression &expr, |
| 79 | + std::vector<FunctionResult> &results, |
| 80 | + FunctionContext context = FunctionContext::Select) { |
| 81 | + if (expr.expression_class == ExpressionClass::FUNCTION) { |
| 82 | + auto &func = (FunctionExpression &)expr; |
| 83 | + results.push_back(FunctionResult{ |
| 84 | + func.function_name, |
| 85 | + func.schema.empty() ? "main" : func.schema, |
| 86 | + ToString(context) |
| 87 | + }); |
| 88 | + |
| 89 | + // For nested function calls within this function, mark as nested |
| 90 | + ParsedExpressionIterator::EnumerateChildren(expr, [&](const ParsedExpression &child) { |
| 91 | + ExtractFromExpression(child, results, FunctionContext::Nested); |
| 92 | + }); |
| 93 | + } else if (expr.expression_class == ExpressionClass::WINDOW) { |
| 94 | + auto &window_expr = (WindowExpression &)expr; |
| 95 | + results.push_back(FunctionResult{ |
| 96 | + window_expr.function_name, |
| 97 | + window_expr.schema.empty() ? "main" : window_expr.schema, |
| 98 | + ToString(context) |
| 99 | + }); |
| 100 | + |
| 101 | + // Extract functions from window function arguments |
| 102 | + for (const auto &child : window_expr.children) { |
| 103 | + if (child) { |
| 104 | + ExtractFromExpression(*child, results, FunctionContext::Nested); |
| 105 | + } |
| 106 | + } |
| 107 | + |
| 108 | + // Extract functions from PARTITION BY expressions |
| 109 | + for (const auto &partition : window_expr.partitions) { |
| 110 | + if (partition) { |
| 111 | + ExtractFromExpression(*partition, results, FunctionContext::Nested); |
| 112 | + } |
| 113 | + } |
| 114 | + |
| 115 | + // Extract functions from ORDER BY expressions |
| 116 | + for (const auto &order : window_expr.orders) { |
| 117 | + if (order.expression) { |
| 118 | + ExtractFromExpression(*order.expression, results, FunctionContext::Nested); |
| 119 | + } |
| 120 | + } |
| 121 | + |
| 122 | + // Extract functions from argument ordering expressions |
| 123 | + for (const auto &arg_order : window_expr.arg_orders) { |
| 124 | + if (arg_order.expression) { |
| 125 | + ExtractFromExpression(*arg_order.expression, results, FunctionContext::Nested); |
| 126 | + } |
| 127 | + } |
| 128 | + |
| 129 | + // Extract functions from frame expressions |
| 130 | + if (window_expr.start_expr) { |
| 131 | + ExtractFromExpression(*window_expr.start_expr, results, FunctionContext::Nested); |
| 132 | + } |
| 133 | + if (window_expr.end_expr) { |
| 134 | + ExtractFromExpression(*window_expr.end_expr, results, FunctionContext::Nested); |
| 135 | + } |
| 136 | + if (window_expr.offset_expr) { |
| 137 | + ExtractFromExpression(*window_expr.offset_expr, results, FunctionContext::Nested); |
| 138 | + } |
| 139 | + if (window_expr.default_expr) { |
| 140 | + ExtractFromExpression(*window_expr.default_expr, results, FunctionContext::Nested); |
| 141 | + } |
| 142 | + |
| 143 | + // Extract functions from filter expression |
| 144 | + if (window_expr.filter_expr) { |
| 145 | + ExtractFromExpression(*window_expr.filter_expr, results, FunctionContext::Nested); |
| 146 | + } |
| 147 | + } else { |
| 148 | + // For non-function expressions, preserve the current context |
| 149 | + ParsedExpressionIterator::EnumerateChildren(expr, [&](const ParsedExpression &child) { |
| 150 | + ExtractFromExpression(child, results, context); |
| 151 | + }); |
| 152 | + } |
| 153 | + } |
| 154 | + |
| 155 | + static void ExtractFromExpressionList(const vector<unique_ptr<ParsedExpression>> &expressions, |
| 156 | + std::vector<FunctionResult> &results, |
| 157 | + FunctionContext context) { |
| 158 | + for (const auto &expr : expressions) { |
| 159 | + if (expr) { |
| 160 | + ExtractFromExpression(*expr, results, context); |
| 161 | + } |
| 162 | + } |
| 163 | + } |
| 164 | +}; |
| 165 | + |
| 166 | + |
| 167 | +static void ExtractFunctionsFromQueryNode(const QueryNode &node, std::vector<FunctionResult> &results) { |
| 168 | + if (node.type == QueryNodeType::SELECT_NODE) { |
| 169 | + auto &select_node = (SelectNode &)node; |
| 170 | + |
| 171 | + // Extract from CTEs first (to match expected order in tests) |
| 172 | + for (const auto &cte : select_node.cte_map.map) { |
| 173 | + if (cte.second && cte.second->query && cte.second->query->node) { |
| 174 | + ExtractFunctionsFromQueryNode(*cte.second->query->node, results); |
| 175 | + } |
| 176 | + } |
| 177 | + |
| 178 | + // Extract from SELECT list |
| 179 | + FunctionExtractor::ExtractFromExpressionList(select_node.select_list, results, FunctionContext::Select); |
| 180 | + |
| 181 | + // Extract from WHERE clause |
| 182 | + if (select_node.where_clause) { |
| 183 | + FunctionExtractor::ExtractFromExpression(*select_node.where_clause, results, FunctionContext::Where); |
| 184 | + } |
| 185 | + |
| 186 | + // Extract from GROUP BY clause |
| 187 | + FunctionExtractor::ExtractFromExpressionList(select_node.groups.group_expressions, results, FunctionContext::GroupBy); |
| 188 | + |
| 189 | + // Extract from HAVING clause |
| 190 | + if (select_node.having) { |
| 191 | + FunctionExtractor::ExtractFromExpression(*select_node.having, results, FunctionContext::Having); |
| 192 | + } |
| 193 | + |
| 194 | + // Extract from ORDER BY clause |
| 195 | + for (const auto &modifier : select_node.modifiers) { |
| 196 | + if (modifier->type == ResultModifierType::ORDER_MODIFIER) { |
| 197 | + auto &order_modifier = (OrderModifier &)*modifier; |
| 198 | + for (const auto &order : order_modifier.orders) { |
| 199 | + if (order.expression) { |
| 200 | + FunctionExtractor::ExtractFromExpression(*order.expression, results, FunctionContext::OrderBy); |
| 201 | + } |
| 202 | + } |
| 203 | + } |
| 204 | + } |
| 205 | + } |
| 206 | +} |
| 207 | + |
| 208 | +static void ExtractFunctionsFromSQL(const std::string &sql, std::vector<FunctionResult> &results) { |
| 209 | + Parser parser; |
| 210 | + |
| 211 | + try { |
| 212 | + parser.ParseQuery(sql); |
| 213 | + } catch (const ParserException &ex) { |
| 214 | + // swallow parser exceptions to make this function more robust. is_parsable can be used if needed |
| 215 | + return; |
| 216 | + } |
| 217 | + |
| 218 | + for (auto &stmt : parser.statements) { |
| 219 | + if (stmt->type == StatementType::SELECT_STATEMENT) { |
| 220 | + auto &select_stmt = (SelectStatement &)*stmt; |
| 221 | + if (select_stmt.node) { |
| 222 | + ExtractFunctionsFromQueryNode(*select_stmt.node, results); |
| 223 | + } |
| 224 | + } |
| 225 | + } |
| 226 | +} |
| 227 | + |
| 228 | +static void ParseFunctionsFunction(ClientContext &context, |
| 229 | + TableFunctionInput &data, |
| 230 | + DataChunk &output) { |
| 231 | + auto &state = (ParseFunctionsState &)*data.global_state; |
| 232 | + auto &bind_data = (ParseFunctionsBindData &)*data.bind_data; |
| 233 | + |
| 234 | + if (state.results.empty() && state.row == 0) { |
| 235 | + ExtractFunctionsFromSQL(bind_data.sql, state.results); |
| 236 | + } |
| 237 | + |
| 238 | + if (state.row >= state.results.size()) { |
| 239 | + return; |
| 240 | + } |
| 241 | + |
| 242 | + auto &func = state.results[state.row]; |
| 243 | + output.SetCardinality(1); |
| 244 | + output.SetValue(0, 0, Value(func.function_name)); |
| 245 | + output.SetValue(1, 0, Value(func.schema)); |
| 246 | + output.SetValue(2, 0, Value(func.context)); |
| 247 | + |
| 248 | + state.row++; |
| 249 | +} |
| 250 | + |
| 251 | +static void ParseFunctionNamesScalarFunction(DataChunk &args, ExpressionState &state, Vector &result) { |
| 252 | + UnaryExecutor::Execute<string_t, list_entry_t>(args.data[0], result, args.size(), |
| 253 | + [&result](string_t query) -> list_entry_t { |
| 254 | + // Parse the SQL query and extract function names |
| 255 | + auto query_string = query.GetString(); |
| 256 | + std::vector<FunctionResult> parsed_functions; |
| 257 | + ExtractFunctionsFromSQL(query_string, parsed_functions); |
| 258 | + |
| 259 | + auto current_size = ListVector::GetListSize(result); |
| 260 | + auto number_of_functions = parsed_functions.size(); |
| 261 | + auto new_size = current_size + number_of_functions; |
| 262 | + |
| 263 | + // grow list if needed |
| 264 | + if (ListVector::GetListCapacity(result) < new_size) { |
| 265 | + ListVector::Reserve(result, new_size); |
| 266 | + } |
| 267 | + |
| 268 | + // Write the function names into the child vector |
| 269 | + auto functions = FlatVector::GetData<string_t>(ListVector::GetEntry(result)); |
| 270 | + for (size_t i = 0; i < parsed_functions.size(); i++) { |
| 271 | + auto &func = parsed_functions[i]; |
| 272 | + functions[current_size + i] = StringVector::AddStringOrBlob(ListVector::GetEntry(result), func.function_name); |
| 273 | + } |
| 274 | + |
| 275 | + // Update size |
| 276 | + ListVector::SetListSize(result, new_size); |
| 277 | + |
| 278 | + return list_entry_t(current_size, number_of_functions); |
| 279 | + }); |
| 280 | +} |
| 281 | + |
| 282 | +static void ParseFunctionsScalarFunction_struct(DataChunk &args, ExpressionState &state, Vector &result) { |
| 283 | + UnaryExecutor::Execute<string_t, list_entry_t>(args.data[0], result, args.size(), |
| 284 | + [&result](string_t query) -> list_entry_t { |
| 285 | + // Parse the SQL query and extract function names |
| 286 | + auto query_string = query.GetString(); |
| 287 | + std::vector<FunctionResult> parsed_functions; |
| 288 | + ExtractFunctionsFromSQL(query_string, parsed_functions); |
| 289 | + |
| 290 | + auto current_size = ListVector::GetListSize(result); |
| 291 | + auto number_of_functions = parsed_functions.size(); |
| 292 | + auto new_size = current_size + number_of_functions; |
| 293 | + |
| 294 | + // Grow list vector if needed |
| 295 | + if (ListVector::GetListCapacity(result) < new_size) { |
| 296 | + ListVector::Reserve(result, new_size); |
| 297 | + } |
| 298 | + |
| 299 | + // Get the struct child vector of the list |
| 300 | + auto &struct_vector = ListVector::GetEntry(result); |
| 301 | + |
| 302 | + // Ensure list size is updated |
| 303 | + ListVector::SetListSize(result, new_size); |
| 304 | + |
| 305 | + // Get the fields in the STRUCT |
| 306 | + auto &entries = StructVector::GetEntries(struct_vector); |
| 307 | + auto &function_name_entry = *entries[0]; // "function_name" field |
| 308 | + auto &schema_entry = *entries[1]; // "schema" field |
| 309 | + auto &context_entry = *entries[2]; // "context" field |
| 310 | + |
| 311 | + auto function_name_data = FlatVector::GetData<string_t>(function_name_entry); |
| 312 | + auto schema_data = FlatVector::GetData<string_t>(schema_entry); |
| 313 | + auto context_data = FlatVector::GetData<string_t>(context_entry); |
| 314 | + |
| 315 | + for (size_t i = 0; i < number_of_functions; i++) { |
| 316 | + const auto &func = parsed_functions[i]; |
| 317 | + auto idx = current_size + i; |
| 318 | + |
| 319 | + function_name_data[idx] = StringVector::AddStringOrBlob(function_name_entry, func.function_name); |
| 320 | + schema_data[idx] = StringVector::AddStringOrBlob(schema_entry, func.schema); |
| 321 | + context_data[idx] = StringVector::AddStringOrBlob(context_entry, func.context); |
| 322 | + } |
| 323 | + |
| 324 | + return list_entry_t(current_size, number_of_functions); |
| 325 | + }); |
| 326 | +} |
| 327 | + |
| 328 | +// Extension scaffolding |
| 329 | +// --------------------------------------------------- |
| 330 | + |
| 331 | +void RegisterParseFunctionsFunction(DatabaseInstance &db) { |
| 332 | + TableFunction tf("parse_functions", {LogicalType::VARCHAR}, ParseFunctionsFunction, ParseFunctionsBind, ParseFunctionsInit); |
| 333 | + ExtensionUtil::RegisterFunction(db, tf); |
| 334 | +} |
| 335 | + |
| 336 | +void RegisterParseFunctionScalarFunction(DatabaseInstance &db) { |
| 337 | + // parse_function_names is a scalar function that returns a list of function names |
| 338 | + ScalarFunction sf("parse_function_names", {LogicalType::VARCHAR}, LogicalType::LIST(LogicalType::VARCHAR), ParseFunctionNamesScalarFunction); |
| 339 | + ExtensionUtil::RegisterFunction(db, sf); |
| 340 | + |
| 341 | + // parse_functions_struct is a scalar function that returns a list of structs |
| 342 | + auto return_type = LogicalType::LIST(LogicalType::STRUCT({ |
| 343 | + {"function_name", LogicalType::VARCHAR}, |
| 344 | + {"schema", LogicalType::VARCHAR}, |
| 345 | + {"context", LogicalType::VARCHAR} |
| 346 | + })); |
| 347 | + ScalarFunction sf_struct("parse_functions", {LogicalType::VARCHAR}, return_type, ParseFunctionsScalarFunction_struct); |
| 348 | + ExtensionUtil::RegisterFunction(db, sf_struct); |
| 349 | +} |
| 350 | + |
| 351 | + |
| 352 | + |
| 353 | +} // namespace duckdb |
0 commit comments