diff --git a/.github/workflows/MainDistributionPipeline.yml b/.github/workflows/MainDistributionPipeline.yml index 78990fa..9abe192 100644 --- a/.github/workflows/MainDistributionPipeline.yml +++ b/.github/workflows/MainDistributionPipeline.yml @@ -12,18 +12,18 @@ concurrency: cancel-in-progress: true jobs: - duckdb-next-build: - name: Build extension binaries - uses: duckdb/extension-ci-tools/.github/workflows/_extension_distribution.yml@main - with: - duckdb_version: main - ci_tools_version: main - extension_name: parser_tools +# duckdb-next-build: +# name: Build extension binaries +# uses: duckdb/extension-ci-tools/.github/workflows/_extension_distribution.yml@main +# with: +# duckdb_version: main +# ci_tools_version: main +# extension_name: parser_tools duckdb-stable-build: name: Build extension binaries - uses: duckdb/extension-ci-tools/.github/workflows/_extension_distribution.yml@v1.2.1 + uses: duckdb/extension-ci-tools/.github/workflows/_extension_distribution.yml@v1.3.0 with: - duckdb_version: v1.2.1 - ci_tools_version: v1.2.1 - extension_name: parser_tools \ No newline at end of file + duckdb_version: v1.3.0 + ci_tools_version: v1.3.0 + extension_name: parser_tools diff --git a/.gitmodules b/.gitmodules index 01b02cb..86e8a7e 100644 --- a/.gitmodules +++ b/.gitmodules @@ -5,4 +5,4 @@ [submodule "extension-ci-tools"] path = extension-ci-tools url = https://github.com/duckdb/extension-ci-tools - branch = main \ No newline at end of file + branch = main diff --git a/CMakeLists.txt b/CMakeLists.txt index ab816e8..fc02269 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,6 +12,7 @@ include_directories(src/include) set(EXTENSION_SOURCES src/parser_tools_extension.cpp src/parse_tables.cpp + src/parse_where.cpp ) build_static_extension(${TARGET_NAME} ${EXTENSION_SOURCES}) diff --git a/duckdb b/duckdb index 8e52ec4..71c5c07 160000 --- a/duckdb +++ b/duckdb @@ -1 +1 @@ -Subproject commit 8e52ec43959ab363643d63cb78ee214577111da4 +Subproject commit 71c5c07cdd295e9409c0505885033ae9eb6b5ddd diff --git a/extension-ci-tools b/extension-ci-tools index 58970c5..71d2002 160000 --- a/extension-ci-tools +++ b/extension-ci-tools @@ -1 +1 @@ -Subproject commit 58970c538d35919db875096460c05806056f4de0 +Subproject commit 71d20029c5314dfc34f3bbdab808b9bce03b8003 diff --git a/src/include/parse_where.hpp b/src/include/parse_where.hpp new file mode 100644 index 0000000..bb32aa3 --- /dev/null +++ b/src/include/parse_where.hpp @@ -0,0 +1,30 @@ +#pragma once + +#include "duckdb.hpp" +#include +#include + +namespace duckdb { + +// Forward declarations +class DatabaseInstance; + +struct WhereConditionResult { + std::string condition; + std::string table_name; // The table this condition applies to (if determinable) + std::string context; // The context where this condition appears (WHERE, HAVING, etc.) +}; + +struct DetailedWhereConditionResult { + std::string column_name; // The column being compared + std::string operator_type; // The comparison operator (>, <, =, etc.) + std::string value; // The value being compared against + std::string table_name; // The table this condition applies to (if determinable) + std::string context; // The context where this condition appears (WHERE, HAVING, etc.) +}; + +void RegisterParseWhereFunction(DatabaseInstance &db); +void RegisterParseWhereScalarFunction(DatabaseInstance &db); +void RegisterParseWhereDetailedFunction(DatabaseInstance &db); + +} // namespace duckdb \ No newline at end of file diff --git a/src/parse_where.cpp b/src/parse_where.cpp new file mode 100644 index 0000000..7f5ea41 --- /dev/null +++ b/src/parse_where.cpp @@ -0,0 +1,484 @@ +#include "parse_where.hpp" +#include "duckdb.hpp" +#include "duckdb/parser/parser.hpp" +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/expression/comparison_expression.hpp" +#include "duckdb/parser/expression/conjunction_expression.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/parser/expression/operator_expression.hpp" +#include "duckdb/parser/expression/star_expression.hpp" +#include "duckdb/parser/expression/subquery_expression.hpp" +#include "duckdb/parser/expression/window_expression.hpp" +#include "duckdb/parser/expression/case_expression.hpp" +#include "duckdb/parser/expression/cast_expression.hpp" +#include "duckdb/parser/expression/between_expression.hpp" +#include "duckdb/parser/expression/lambda_expression.hpp" +#include "duckdb/parser/expression/positional_reference_expression.hpp" +#include "duckdb/parser/expression/parameter_expression.hpp" +#include "duckdb/parser/tableref/basetableref.hpp" +#include "duckdb/main/extension_util.hpp" + +namespace duckdb { + +struct ParseWhereState : public GlobalTableFunctionState { + idx_t row = 0; + vector results; +}; + +struct ParseWhereBindData : public TableFunctionData { + string sql; +}; + +static unique_ptr ParseWhereBind(ClientContext &context, + TableFunctionBindInput &input, + vector &return_types, + vector &names) { + string sql_input = StringValue::Get(input.inputs[0]); + + return_types = { + LogicalType::VARCHAR, // condition + LogicalType::VARCHAR, // table_name + LogicalType::VARCHAR // context + }; + + names = {"condition", "table_name", "context"}; + + auto result = make_uniq(); + result->sql = sql_input; + + return std::move(result); +} + +static unique_ptr ParseWhereInit(ClientContext &context, + TableFunctionInitInput &input) { + return make_uniq(); +} + +static string ExpressionToString(const ParsedExpression &expr) { + return expr.ToString(); +} + +static void ExtractWhereConditionsFromExpression( + const ParsedExpression &expr, + vector &results, + const string &context = "WHERE", + const string &table_name = "" +) { + if (expr.type == ExpressionType::INVALID) return; + + switch (expr.GetExpressionClass()) { + case ExpressionClass::CONJUNCTION: { + auto &conj = (ConjunctionExpression &)expr; + for (auto &child : conj.children) { + ExtractWhereConditionsFromExpression(*child, results, context, table_name); + } + break; + } + case ExpressionClass::COMPARISON: { + auto &comp = (ComparisonExpression &)expr; + results.push_back(WhereConditionResult{ + ExpressionToString(comp), + table_name, + context + }); + break; + } + case ExpressionClass::OPERATOR: { + auto &op = (OperatorExpression &)expr; + results.push_back(WhereConditionResult{ + ExpressionToString(op), + table_name, + context + }); + break; + } + case ExpressionClass::FUNCTION: { + auto &func = (FunctionExpression &)expr; + results.push_back(WhereConditionResult{ + ExpressionToString(func), + table_name, + context + }); + break; + } + case ExpressionClass::BETWEEN: { + auto &between = (BetweenExpression &)expr; + results.push_back(WhereConditionResult{ + ExpressionToString(between), + table_name, + context + }); + break; + } + case ExpressionClass::CASE: { + auto &case_expr = (CaseExpression &)expr; + results.push_back(WhereConditionResult{ + ExpressionToString(case_expr), + table_name, + context + }); + break; + } + default: + break; + } +} + +static void ExtractWhereConditionsFromQueryNode( + const QueryNode &node, + vector &results +) { + if (node.type == QueryNodeType::SELECT_NODE) { + auto &select_node = (SelectNode &)node; + string table_name = "(empty)"; // Default table name + + // Extract table name from FROM clause + if (select_node.from_table) { + if (select_node.from_table->type == TableReferenceType::BASE_TABLE) { + auto &base = (BaseTableRef &)*select_node.from_table; + table_name = base.table_name; + } + } + + // Extract WHERE conditions + if (select_node.where_clause) { + ExtractWhereConditionsFromExpression(*select_node.where_clause, results, "WHERE", table_name); + } + + // Extract HAVING conditions + if (select_node.having) { + ExtractWhereConditionsFromExpression(*select_node.having, results, "HAVING", table_name); + } + } +} + +static void ExtractWhereConditionsFromSQL(const string &sql, vector &results) { + Parser parser; + + try { + parser.ParseQuery(sql); + } catch (const ParserException &ex) { + return; + } + + for (auto &stmt : parser.statements) { + if (stmt->type == StatementType::SELECT_STATEMENT) { + auto &select_stmt = (SelectStatement &)*stmt; + if (select_stmt.node) { + ExtractWhereConditionsFromQueryNode(*select_stmt.node, results); + } + } + } +} + +static void ParseWhereFunction(ClientContext &context, + TableFunctionInput &data, + DataChunk &output) { + auto &state = (ParseWhereState &)*data.global_state; + auto &bind_data = (ParseWhereBindData &)*data.bind_data; + + if (state.results.empty() && state.row == 0) { + ExtractWhereConditionsFromSQL(bind_data.sql, state.results); + } + + if (state.row >= state.results.size()) { + return; + } + + auto &result = state.results[state.row]; + output.SetCardinality(1); + output.SetValue(0, 0, Value(result.condition)); + output.SetValue(1, 0, Value(result.table_name)); + output.SetValue(2, 0, Value(result.context)); + + state.row++; +} + +static void ParseWhereScalarFunction(DataChunk &args, ExpressionState &state, Vector &result) { + UnaryExecutor::Execute(args.data[0], result, args.size(), + [&result](string_t query) -> list_entry_t { + auto query_string = query.GetString(); + vector conditions; + ExtractWhereConditionsFromSQL(query_string, conditions); + + auto current_size = ListVector::GetListSize(result); + auto number_of_conditions = conditions.size(); + auto new_size = current_size + number_of_conditions; + + if (ListVector::GetListCapacity(result) < new_size) { + ListVector::Reserve(result, new_size); + } + + auto &struct_vector = ListVector::GetEntry(result); + auto &entries = StructVector::GetEntries(struct_vector); + auto &condition_entry = *entries[0]; + auto &table_entry = *entries[1]; + auto &context_entry = *entries[2]; + + auto condition_data = FlatVector::GetData(condition_entry); + auto table_data = FlatVector::GetData(table_entry); + auto context_data = FlatVector::GetData(context_entry); + + for (size_t i = 0; i < number_of_conditions; i++) { + const auto &condition = conditions[i]; + auto idx = current_size + i; + + condition_data[idx] = StringVector::AddStringOrBlob(condition_entry, condition.condition); + table_data[idx] = StringVector::AddStringOrBlob(table_entry, condition.table_name); + context_data[idx] = StringVector::AddStringOrBlob(context_entry, condition.context); + } + + ListVector::SetListSize(result, new_size); + return list_entry_t(current_size, number_of_conditions); + }); +} + +void RegisterParseWhereFunction(DatabaseInstance &db) { + TableFunction tf("parse_where", {LogicalType::VARCHAR}, ParseWhereFunction, ParseWhereBind, ParseWhereInit); + ExtensionUtil::RegisterFunction(db, tf); +} + +void RegisterParseWhereScalarFunction(DatabaseInstance &db) { + auto return_type = LogicalType::LIST(LogicalType::STRUCT({ + {"condition", LogicalType::VARCHAR}, + {"table_name", LogicalType::VARCHAR}, + {"context", LogicalType::VARCHAR} + })); + ScalarFunction sf("parse_where", {LogicalType::VARCHAR}, return_type, ParseWhereScalarFunction); + ExtensionUtil::RegisterFunction(db, sf); +} + +static string DetailedExpressionTypeToOperator(ExpressionType type) { + switch (type) { + case ExpressionType::COMPARE_EQUAL: + return "="; + case ExpressionType::COMPARE_NOTEQUAL: + return "!="; + case ExpressionType::COMPARE_LESSTHAN: + return "<"; + case ExpressionType::COMPARE_GREATERTHAN: + return ">"; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + return "<="; + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return ">="; + case ExpressionType::COMPARE_DISTINCT_FROM: + return "IS DISTINCT FROM"; + case ExpressionType::COMPARE_NOT_DISTINCT_FROM: + return "IS NOT DISTINCT FROM"; + default: + return "UNKNOWN"; + } +} + +static void ExtractDetailedWhereConditionsFromExpression( + const ParsedExpression &expr, + vector &results, + const string &context = "WHERE", + const string &table_name = "" +) { + if (expr.type == ExpressionType::INVALID) return; + + switch (expr.GetExpressionClass()) { + case ExpressionClass::CONJUNCTION: { + auto &conj = (ConjunctionExpression &)expr; + for (auto &child : conj.children) { + ExtractDetailedWhereConditionsFromExpression(*child, results, context, table_name); + } + break; + } + case ExpressionClass::COMPARISON: { + auto &comp = (ComparisonExpression &)expr; + DetailedWhereConditionResult result; + result.context = context; + result.table_name = table_name; + + // Extract column name + if (comp.left->GetExpressionClass() == ExpressionClass::COLUMN_REF) { + auto &col_ref = (ColumnRefExpression &)*comp.left; + result.column_name = col_ref.GetColumnName(); + } + + // Extract operator + result.operator_type = DetailedExpressionTypeToOperator(comp.type); + + // Extract value + if (comp.right->GetExpressionClass() == ExpressionClass::CONSTANT) { + auto &const_expr = (ConstantExpression &)*comp.right; + result.value = const_expr.value.ToString(); + } else { + result.value = comp.right->ToString(); + } + + results.push_back(result); + break; + } + case ExpressionClass::BETWEEN: { + auto &between = (BetweenExpression &)expr; + DetailedWhereConditionResult result; + result.context = context; + result.table_name = table_name; + + // Extract column name + if (between.input->GetExpressionClass() == ExpressionClass::COLUMN_REF) { + auto &col_ref = (ColumnRefExpression &)*between.input; + result.column_name = col_ref.GetColumnName(); + } + + // For BETWEEN, we'll create two conditions: >= lower AND <= upper + result.operator_type = ">="; + if (between.lower->GetExpressionClass() == ExpressionClass::CONSTANT) { + auto &const_expr = (ConstantExpression &)*between.lower; + result.value = const_expr.value.ToString(); + } else { + result.value = between.lower->ToString(); + } + results.push_back(result); + + // Add the upper bound condition + DetailedWhereConditionResult upper_result = result; + upper_result.operator_type = "<="; + if (between.upper->GetExpressionClass() == ExpressionClass::CONSTANT) { + auto &const_expr = (ConstantExpression &)*between.upper; + upper_result.value = const_expr.value.ToString(); + } else { + upper_result.value = between.upper->ToString(); + } + results.push_back(upper_result); + break; + } + case ExpressionClass::OPERATOR: { + auto &op = (OperatorExpression &)expr; + if (op.children.size() >= 2) { + DetailedWhereConditionResult result; + result.context = context; + result.table_name = table_name; + + // Extract column name + if (op.children[0]->GetExpressionClass() == ExpressionClass::COLUMN_REF) { + auto &col_ref = (ColumnRefExpression &)*op.children[0]; + result.column_name = col_ref.GetColumnName(); + } + + // Extract operator + result.operator_type = DetailedExpressionTypeToOperator(op.type); + + // Extract value + if (op.children[1]->GetExpressionClass() == ExpressionClass::CONSTANT) { + auto &const_expr = (ConstantExpression &)*op.children[1]; + result.value = const_expr.value.ToString(); + } else { + result.value = op.children[1]->ToString(); + } + + results.push_back(result); + } + break; + } + default: + break; + } +} + +struct ParseWhereDetailedState : public GlobalTableFunctionState { + idx_t row = 0; + vector results; +}; + +struct ParseWhereDetailedBindData : public TableFunctionData { + string sql; +}; + +static unique_ptr ParseWhereDetailedBind(ClientContext &context, + TableFunctionBindInput &input, + vector &return_types, + vector &names) { + string sql_input = StringValue::Get(input.inputs[0]); + + return_types = { + LogicalType::VARCHAR, // column_name + LogicalType::VARCHAR, // operator_type + LogicalType::VARCHAR, // value + LogicalType::VARCHAR, // table_name + LogicalType::VARCHAR // context + }; + + names = {"column_name", "operator_type", "value", "table_name", "context"}; + + auto result = make_uniq(); + result->sql = sql_input; + + return std::move(result); +} + +static unique_ptr ParseWhereDetailedInit(ClientContext &context, + TableFunctionInitInput &input) { + return make_uniq(); +} + +static void ParseWhereDetailedFunction(ClientContext &context, + TableFunctionInput &data, + DataChunk &output) { + auto &state = (ParseWhereDetailedState &)*data.global_state; + auto &bind_data = (ParseWhereDetailedBindData &)*data.bind_data; + + if (state.results.empty() && state.row == 0) { + Parser parser; + try { + parser.ParseQuery(bind_data.sql); + } catch (const ParserException &ex) { + return; + } + + for (auto &stmt : parser.statements) { + if (stmt->type == StatementType::SELECT_STATEMENT) { + auto &select_stmt = (SelectStatement &)*stmt; + if (select_stmt.node) { + if (select_stmt.node->type == QueryNodeType::SELECT_NODE) { + auto &select_node = (SelectNode &)*select_stmt.node; + string table_name = "(empty)"; // Default table name + + // Try to extract table name from FROM clause + if (select_node.from_table) { + if (select_node.from_table->type == TableReferenceType::BASE_TABLE) { + auto &base_table = (BaseTableRef &)*select_node.from_table; + table_name = base_table.table_name; + } + } + + if (select_node.where_clause) { + ExtractDetailedWhereConditionsFromExpression(*select_node.where_clause, state.results, "WHERE", table_name); + } + if (select_node.having) { + ExtractDetailedWhereConditionsFromExpression(*select_node.having, state.results, "HAVING", table_name); + } + } + } + } + } + } + + if (state.row >= state.results.size()) { + return; + } + + auto &result = state.results[state.row]; + output.SetCardinality(1); + output.SetValue(0, 0, Value(result.column_name)); + output.SetValue(1, 0, Value(result.operator_type)); + output.SetValue(2, 0, Value(result.value)); + output.SetValue(3, 0, Value(result.table_name)); + output.SetValue(4, 0, Value(result.context)); + + state.row++; +} + +void RegisterParseWhereDetailedFunction(DatabaseInstance &db) { + TableFunction tf("parse_where_detailed", {LogicalType::VARCHAR}, ParseWhereDetailedFunction, ParseWhereDetailedBind, ParseWhereDetailedInit); + ExtensionUtil::RegisterFunction(db, tf); +} + +} // namespace duckdb diff --git a/src/parser_tools_extension.cpp b/src/parser_tools_extension.cpp index f25370b..c70f8f0 100644 --- a/src/parser_tools_extension.cpp +++ b/src/parser_tools_extension.cpp @@ -2,6 +2,7 @@ #include "parser_tools_extension.hpp" #include "parse_tables.hpp" +#include "parse_where.hpp" #include "duckdb.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/string_util.hpp" @@ -23,6 +24,9 @@ namespace duckdb { static void LoadInternal(DatabaseInstance &instance) { RegisterParseTablesFunction(instance); RegisterParseTableScalarFunction(instance); + RegisterParseWhereFunction(instance); + RegisterParseWhereScalarFunction(instance); + RegisterParseWhereDetailedFunction(instance); } void ParserToolsExtension::Load(DuckDB &db) { diff --git a/test/sql/parse_tools/table_functions/parse_where.test b/test/sql/parse_tools/table_functions/parse_where.test new file mode 100644 index 0000000..c8e19aa --- /dev/null +++ b/test/sql/parse_tools/table_functions/parse_where.test @@ -0,0 +1,103 @@ +# name: test/sql/parser_tools/tables_functions/parse_where.test +# description: test parse_where and parse_where_detailed table functions +# group: [parse_where] + +# Before we load the extension, this will fail +statement error +SELECT * FROM parse_where('SELECT * FROM my_table WHERE x > 1;'); +---- +Catalog Error: Table Function with name parse_where does not exist! + +# Require statement will ensure this test is run with this extension loaded +require parser_tools + +# Simple comparison +query III +SELECT * FROM parse_where('SELECT * FROM my_table WHERE x > 1;'); +---- +(x > 1) my_table WHERE + +# Simple comparison with detailed parser +query IIIII +SELECT * FROM parse_where_detailed('SELECT * FROM my_table WHERE x > 1;'); +---- +x > 1 my_table WHERE + +# Multiple conditions with AND +query III +SELECT * FROM parse_where('SELECT * FROM my_table WHERE x > 1 AND y < 100;'); +---- +(x > 1) my_table WHERE +(y < 100) my_table WHERE + +# Multiple conditions with AND (detailed) +query IIIII +SELECT * FROM parse_where_detailed('SELECT * FROM my_table WHERE x > 1 AND y < 100;'); +---- +x > 1 my_table WHERE +y < 100 my_table WHERE + +# BETWEEN condition +query III +SELECT * FROM parse_where('SELECT * FROM my_table WHERE x BETWEEN 1 AND 100;'); +---- +(x BETWEEN 1 AND 100) my_table WHERE + +# BETWEEN condition (detailed) +query IIIII +SELECT * FROM parse_where_detailed('SELECT * FROM my_table WHERE x BETWEEN 1 AND 100;'); +---- +x >= 1 my_table WHERE +x <= 100 my_table WHERE + +# Complex conditions with AND/OR +query III +SELECT * FROM parse_where('SELECT * FROM my_table WHERE (x > 1 AND y < 100) OR z = 42;'); +---- +(x > 1) my_table WHERE +(y < 100) my_table WHERE +(z = 42) my_table WHERE + +# Complex conditions with AND/OR (detailed) +query IIIII +SELECT * FROM parse_where_detailed('SELECT * FROM my_table WHERE (x > 1 AND y < 100) OR z = 42;'); +---- +x > 1 my_table WHERE +y < 100 my_table WHERE +z = 42 my_table WHERE + +# Multiple operators +query III +SELECT * FROM parse_where('SELECT * FROM my_table WHERE x >= 1 AND x <= 100 AND y != 42;'); +---- +(x >= 1) my_table WHERE +(x <= 100) my_table WHERE +(y != 42) my_table WHERE + +# Multiple operators (detailed) +query IIIII +SELECT * FROM parse_where_detailed('SELECT * FROM my_table WHERE x >= 1 AND x <= 100 AND y != 42;'); +---- +x >= 1 my_table WHERE +x <= 100 my_table WHERE +y != 42 my_table WHERE + +# No WHERE clause +query III +SELECT * FROM parse_where('SELECT * FROM my_table;'); +---- + +# No WHERE clause (detailed) +query IIIII +SELECT * FROM parse_where_detailed('SELECT * FROM my_table;'); +---- + +# Malformed SQL should not error +query III +SELECT * FROM parse_where('SELECT * FROM my_table WHERE'); +---- + +# Malformed SQL should not error (detailed) +query IIIII +SELECT * FROM parse_where_detailed('SELECT * FROM my_table WHERE'); +----