1+ #include " parse_columns.hpp"
2+ #include " duckdb.hpp"
3+ #include " duckdb/parser/parser.hpp"
4+ #include " duckdb/parser/statement/select_statement.hpp"
5+ #include " duckdb/parser/query_node/select_node.hpp"
6+ #include " duckdb/parser/expression/columnref_expression.hpp"
7+ #include " duckdb/parser/parsed_expression_iterator.hpp"
8+ #include " duckdb/parser/result_modifier.hpp"
9+ #include " duckdb/main/extension_util.hpp"
10+
11+ namespace duckdb {
12+
// Position inside a query in which a column reference was found.
enum class ColumnContext {
	Select,      // SELECT list
	Where,       // WHERE clause
	Having,      // HAVING clause
	OrderBy,     // ORDER BY terms
	GroupBy,     // GROUP BY expressions
	Join,        // declared; not produced by the extraction code in this file
	FunctionArg, // reference nested inside another expression
	Window,      // declared; not produced by the extraction code in this file
	Nested       // declared; not produced by the extraction code in this file
};

// Returns the label written to the `context` output column for `context`.
//
// The labels are exact identifiers without surrounding whitespace so that
// consumers can compare them directly (e.g. context = 'order_by').
// A value that matches no enumerator yields "unknown".
inline const char *ToString(ColumnContext context) {
	// No `default:` label on purpose: the compiler can then warn when a new
	// enumerator is added without a matching case.
	switch (context) {
	case ColumnContext::Select:
		return "select";
	case ColumnContext::Where:
		return "where";
	case ColumnContext::Having:
		return "having";
	case ColumnContext::OrderBy:
		return "order_by";
	case ColumnContext::GroupBy:
		return "group_by";
	case ColumnContext::Join:
		return "join";
	case ColumnContext::FunctionArg:
		return "function_arg";
	case ColumnContext::Window:
		return "window";
	case ColumnContext::Nested:
		return "nested";
	}
	return "unknown";
}
39+
40+ struct ParseColumnsState : public GlobalTableFunctionState {
41+ idx_t row = 0 ;
42+ vector<ColumnResult> results;
43+ };
44+
45+ struct ParseColumnsBindData : public TableFunctionData {
46+ string sql;
47+ };
48+
// Splits the parts of a (possibly qualified) column reference into schema,
// table and column name.
//
// Heuristic, based only on the number of parts:
//   - 3 or more parts: schema.table.column (parts after the third are ignored)
//   - 2 parts:         table.column, schema defaults to "main"
//   - 1 part:          unqualified column; schema and table are left empty
//   - 0 parts:         all three outputs are left empty
//
// An empty output string means "not known"; the scan function turns empty
// strings into NULL, so no placeholder text may be stored here.
//
// @param column_names  name parts of the column reference, outermost first
// @param table_schema  out: schema name, or empty
// @param table_name    out: table name, or empty
// @param column_name   out: column name, or empty when column_names is empty
static void ExtractTableInfo(const vector<string> &column_names, string &table_schema, string &table_name,
                             string &column_name) {
	// Reset first so that a reused output variable never keeps a stale value.
	table_schema.clear();
	table_name.clear();
	column_name.clear();

	const auto part_count = column_names.size();
	if (part_count == 0) {
		return;
	}

	if (part_count == 1) {
		// Unqualified column - could be a table column or an alias
		column_name = column_names[0];
		return;
	}

	if (part_count == 2) {
		// Assume table.column format
		table_schema = "main"; // Default schema
		table_name = column_names[0];
		column_name = column_names[1];
		return;
	}

	// Assume schema.table.column format
	table_schema = column_names[0];
	table_name = column_names[1];
	column_name = column_names[2];
}
78+
// Joins the parts of a qualified name with '.', e.g. {"s", "t", "c"} -> "s.t.c".
//
// No whitespace is inserted around the dot, and an empty input yields an
// empty string.
//
// @param vec  name parts, outermost first
// @return     the dotted name
static string VectorToString(const vector<string> &vec) {
	string result;
	for (size_t i = 0; i < vec.size(); i++) {
		// Keyed on the index, not on result.empty(), so an empty first part
		// still gets its separator.
		if (i > 0) {
			result += '.';
		}
		result += vec[i];
	}
	return result;
}
90+
// Appends `text` to `out` as a double-quoted JSON string.
// Quotes, backslashes and control characters are escaped so that the result
// stays parseable whatever the identifier contains.
static void AppendJsonEscaped(string &out, const string &text) {
	static const char *const HEX_DIGITS = "0123456789abcdef";
	out += '"';
	for (const char raw : text) {
		const auto ch = static_cast<unsigned char>(raw);
		switch (ch) {
		case '"':
			out += "\\\"";
			break;
		case '\\':
			out += "\\\\";
			break;
		case '\n':
			out += "\\n";
			break;
		case '\r':
			out += "\\r";
			break;
		case '\t':
			out += "\\t";
			break;
		default:
			if (ch < 0x20) {
				// Remaining control characters must be written as \u00XX.
				out += "\\u00";
				out += HEX_DIGITS[ch >> 4];
				out += HEX_DIGITS[ch & 0x0F];
			} else {
				out += raw;
			}
			break;
		}
	}
	out += '"';
}

// Serializes a list of qualified names as a compact JSON array of arrays of
// strings, e.g. {{"t", "a"}, {"b"}} -> [["t","a"],["b"]].
//
// An empty list yields "[]". Identifier text is JSON-escaped.
//
// @param identifiers  one inner vector of name parts per column reference
// @return             the JSON text
static string SerializeExpressionIdentifiers(const vector<vector<string>> &identifiers) {
	string result = "[";
	for (size_t i = 0; i < identifiers.size(); i++) {
		if (i > 0) {
			result += ',';
		}
		result += '[';
		const auto &parts = identifiers[i];
		for (size_t j = 0; j < parts.size(); j++) {
			if (j > 0) {
				result += ',';
			}
			AppendJsonEscaped(result, parts[j]);
		}
		result += ']';
	}
	result += ']';
	return result;
}
110+
111+ // Recursive function to extract column references from expressions
112+ static void ExtractFromExpression (const ParsedExpression &expr,
113+ vector<ColumnResult> &results,
114+ ColumnContext context,
115+ const string &selected_name = " " ) {
116+
117+ if (expr.expression_class == ExpressionClass::COLUMN_REF) {
118+ auto &col_ref = (ColumnRefExpression &)expr;
119+
120+ string table_schema, table_name, column_name;
121+ ExtractTableInfo (col_ref.column_names , table_schema, table_name, column_name);
122+
123+ // Convert empty strings to NULLs for consistency
124+ if (table_schema.empty ()) table_schema = " " ;
125+ if (table_name.empty ()) table_name = " " ;
126+
127+ vector<vector<string>> expr_ids = {col_ref.column_names };
128+ results.push_back (ColumnResult{
129+ expr_ids, // expression_identifiers
130+ table_schema.empty () ? " " : table_schema,
131+ table_name.empty () ? " " : table_name,
132+ column_name,
133+ ToString (context),
134+ VectorToString (col_ref.column_names ),
135+ selected_name.empty () ? " " : selected_name
136+ });
137+ } else {
138+ // For non-column expressions, continue traversing to find nested column references
139+ ParsedExpressionIterator::EnumerateChildren (expr, [&](const ParsedExpression &child) {
140+ ExtractFromExpression (child, results, ColumnContext::FunctionArg);
141+ });
142+ }
143+ }
144+
145+ // Helper function to collect all identifiers from an expression recursively
146+ static void CollectExpressionIdentifiers (const ParsedExpression &expr,
147+ vector<vector<string>> &all_identifiers) {
148+ if (expr.expression_class == ExpressionClass::COLUMN_REF) {
149+ auto &col_ref = (ColumnRefExpression &)expr;
150+ all_identifiers.push_back (col_ref.column_names );
151+ } else {
152+ ParsedExpressionIterator::EnumerateChildren (expr, [&](const ParsedExpression &child) {
153+ CollectExpressionIdentifiers (child, all_identifiers);
154+ });
155+ }
156+ }
157+
158+ // Extract columns from SELECT node
159+ static void ExtractFromSelectNode (const SelectNode &select_node, vector<ColumnResult> &results) {
160+
161+ // Extract from SELECT list (output columns)
162+ for (const auto &select_item : select_node.select_list ) {
163+ string selected_name = select_item->alias .empty () ? " " : select_item->alias ;
164+
165+ // If no explicit alias, derive from expression
166+ if (selected_name.empty () && select_item->expression_class == ExpressionClass::COLUMN_REF) {
167+ auto &col_ref = (ColumnRefExpression &)*select_item;
168+ selected_name = col_ref.GetColumnName ();
169+ }
170+
171+ // First extract individual column references
172+ ExtractFromExpression (*select_item, results, ColumnContext::Select);
173+
174+ // Then add the output column entry if it's a complex expression
175+ vector<vector<string>> all_identifiers;
176+ CollectExpressionIdentifiers (*select_item, all_identifiers);
177+
178+ if (all_identifiers.size () > 1 || (all_identifiers.size () == 1 && !select_item->alias .empty ())) {
179+ // Complex expression or aliased column - add output entry
180+ results.push_back (ColumnResult{
181+ all_identifiers,
182+ " " , // table_schema
183+ " " , // table_name
184+ " " , // column_name
185+ ToString (ColumnContext::Select),
186+ select_item->ToString (),
187+ selected_name.empty () ? " " : selected_name
188+ });
189+ }
190+ }
191+
192+ // Extract from WHERE clause
193+ if (select_node.where_clause ) {
194+ ExtractFromExpression (*select_node.where_clause , results, ColumnContext::Where);
195+ }
196+
197+ // Extract from GROUP BY clause
198+ for (const auto &group_expr : select_node.groups .group_expressions ) {
199+ ExtractFromExpression (*group_expr, results, ColumnContext::GroupBy);
200+ }
201+
202+ // Extract from HAVING clause
203+ if (select_node.having ) {
204+ ExtractFromExpression (*select_node.having , results, ColumnContext::Having);
205+ }
206+
207+ // Extract from ORDER BY clause
208+ for (const auto &modifier : select_node.modifiers ) {
209+ if (modifier->type == ResultModifierType::ORDER_MODIFIER) {
210+ auto &order_modifier = (OrderModifier &)*modifier;
211+ for (const auto &order_term : order_modifier.orders ) {
212+ ExtractFromExpression (*order_term.expression , results, ColumnContext::OrderBy);
213+ }
214+ }
215+ }
216+ }
217+
218+ // BIND function: runs during query planning to decide output schema
219+ static unique_ptr<FunctionData> ParseColumnsBind (ClientContext &context, TableFunctionBindInput &input,
220+ vector<LogicalType> &return_types, vector<string> &names) {
221+
222+ string sql_input = StringValue::Get (input.inputs [0 ]);
223+
224+ // Define output schema - simplified for initial implementation
225+ return_types = {
226+ LogicalType::VARCHAR, // expression_identifiers (as JSON-like string for now)
227+ LogicalType::VARCHAR, // table_schema
228+ LogicalType::VARCHAR, // table_name
229+ LogicalType::VARCHAR, // column_name
230+ LogicalType::VARCHAR, // context
231+ LogicalType::VARCHAR, // expression
232+ LogicalType::VARCHAR // selected_name
233+ };
234+
235+ names = {" expression_identifiers" , " table_schema" , " table_name" , " column_name" ,
236+ " context" , " expression" , " selected_name" };
237+
238+ auto result = make_uniq<ParseColumnsBindData>();
239+ result->sql = sql_input;
240+ return std::move (result);
241+ }
242+
243+ // INIT function: runs before table function execution
244+ static unique_ptr<GlobalTableFunctionState> ParseColumnsInit (ClientContext &context,
245+ TableFunctionInitInput &input) {
246+ return make_uniq<ParseColumnsState>();
247+ }
248+
249+ // Main parsing function
250+ static void ParseColumnsFunction (ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {
251+ auto &bind_data = (ParseColumnsBindData &)*data_p.bind_data ;
252+ auto &state = (ParseColumnsState &)*data_p.global_state ;
253+
254+ if (state.row == 0 ) {
255+ // Parse the SQL statement
256+ Parser parser;
257+ parser.ParseQuery (bind_data.sql );
258+
259+ if (parser.statements .empty ()) {
260+ return ;
261+ }
262+
263+ // Process each statement
264+ for (const auto &statement : parser.statements ) {
265+ if (statement->type == StatementType::SELECT_STATEMENT) {
266+ auto &select_stmt = (SelectStatement &)*statement;
267+ auto &select_node = (SelectNode &)*select_stmt.node ;
268+ ExtractFromSelectNode (select_node, state.results );
269+ }
270+ }
271+ }
272+
273+ // Output results
274+ idx_t count = 0 ;
275+ while (state.row < state.results .size () && count < STANDARD_VECTOR_SIZE) {
276+ const auto &result = state.results [state.row ];
277+
278+ output.data [0 ].SetValue (count, Value (SerializeExpressionIdentifiers (result.expression_identifiers )));
279+ output.data [1 ].SetValue (count, result.table_schema .empty () ? Value () : Value (result.table_schema ));
280+ output.data [2 ].SetValue (count, result.table_name .empty () ? Value () : Value (result.table_name ));
281+ output.data [3 ].SetValue (count, result.column_name .empty () ? Value () : Value (result.column_name ));
282+ output.data [4 ].SetValue (count, Value (result.context ));
283+ output.data [5 ].SetValue (count, Value (result.expression ));
284+ output.data [6 ].SetValue (count, result.selected_name .empty () ? Value () : Value (result.selected_name ));
285+
286+ state.row ++;
287+ count++;
288+ }
289+
290+ output.SetCardinality (count);
291+ }
292+
293+ void RegisterParseColumnsFunction (DatabaseInstance &db) {
294+ TableFunction parse_columns (" parse_columns" , {LogicalType::VARCHAR}, ParseColumnsFunction, ParseColumnsBind, ParseColumnsInit);
295+ ExtensionUtil::RegisterFunction (db, parse_columns);
296+ }
297+
298+ void RegisterParseColumnScalarFunction (DatabaseInstance &db) {
299+ // TODO: Implement scalar version similar to parse_function_names
300+ }
301+
302+ } // namespace duckdb
0 commit comments