Skip to content

Commit d313cda

Browse files
committed
Update to use enum for table context
1 parent 6c81ffc commit d313cda

File tree

1 file changed

+37
-13
lines changed

1 file changed

+37
-13
lines changed

src/parse_tables_extension.cpp

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,34 @@
1717

1818
namespace duckdb {
1919

20+
/**
21+
* Represents where a table is used in a query.
22+
*/
23+
enum class TableContext {
24+
From, // table in from clause
25+
JoinLeft, // table in left side of a join
26+
JoinRight, // table in right side of a join
27+
FromCTE, // table in from clause that references a CTE
28+
CTE, // table is defined as a CTE
29+
Subquery // table in a subquery
30+
};
31+
32+
inline const char *ToString(TableContext context) {
33+
switch (context) {
34+
case TableContext::From: return "from";
35+
case TableContext::JoinLeft: return "join_left";
36+
case TableContext::JoinRight: return "join_right";
37+
case TableContext::FromCTE: return "from_cte";
38+
case TableContext::CTE: return "cte";
39+
case TableContext::Subquery: return "subquery";
40+
default: return "unknown";
41+
}
42+
}
43+
2044
struct TableRefResult {
2145
string schema;
2246
string table;
23-
string context;
47+
TableContext context;
2448
};
2549

2650
struct ParseTablesState : public GlobalTableFunctionState {
@@ -63,14 +87,14 @@ static unique_ptr<GlobalTableFunctionState> MyInit(ClientContext &context,
6387
static void ExtractTablesFromQueryNode(
6488
const duckdb::QueryNode &node,
6589
std::vector<TableRefResult> &results,
66-
const std::string &context = "from",
90+
const TableContext context = TableContext::From,
6791
const duckdb::CommonTableExpressionMap *cte_map = nullptr
6892
);
6993

7094
static void ExtractTablesFromRef(
7195
const duckdb::TableRef &ref,
7296
std::vector<TableRefResult> &results,
73-
const std::string &context = "from",
97+
const TableContext context = TableContext::From,
7498
bool is_top_level = false,
7599
const duckdb::CommonTableExpressionMap *cte_map = nullptr
76100
) {
@@ -79,12 +103,12 @@ static void ExtractTablesFromRef(
79103
switch (ref.type) {
80104
case TableReferenceType::BASE_TABLE: {
81105
auto &base = (BaseTableRef &)ref;
82-
std::string context_label = context;
106+
TableContext context_label = context;
83107

84108
if (cte_map && cte_map->map.find(base.table_name) != cte_map->map.end()) {
85-
context_label = "from_cte";
109+
context_label = TableContext::FromCTE;
86110
} else if (is_top_level) {
87-
context_label = "from";
111+
context_label = TableContext::From;
88112
}
89113

90114
results.push_back(TableRefResult{
@@ -96,14 +120,14 @@ static void ExtractTablesFromRef(
96120
}
97121
case TableReferenceType::JOIN: {
98122
auto &join = (JoinRef &)ref;
99-
ExtractTablesFromRef(*join.left, results, "join_left", is_top_level, cte_map);
100-
ExtractTablesFromRef(*join.right, results, "join_right", false, cte_map);
123+
ExtractTablesFromRef(*join.left, results, TableContext::JoinLeft, is_top_level, cte_map);
124+
ExtractTablesFromRef(*join.right, results, TableContext::JoinRight, false, cte_map);
101125
break;
102126
}
103127
case TableReferenceType::SUBQUERY: {
104128
auto &subquery = (SubqueryRef &)ref;
105129
if (subquery.subquery && subquery.subquery->node) {
106-
ExtractTablesFromQueryNode(*subquery.subquery->node, results, "subquery", cte_map);
130+
ExtractTablesFromQueryNode(*subquery.subquery->node, results, TableContext::Subquery, cte_map);
107131
}
108132
break;
109133
}
@@ -116,7 +140,7 @@ static void ExtractTablesFromRef(
116140
static void ExtractTablesFromQueryNode(
117141
const duckdb::QueryNode &node,
118142
std::vector<TableRefResult> &results,
119-
const std::string &context,
143+
const TableContext context,
120144
const duckdb::CommonTableExpressionMap *cte_map
121145
) {
122146
using namespace duckdb;
@@ -127,11 +151,11 @@ static void ExtractTablesFromQueryNode(
127151
// Emit CTE definitions
128152
for (const auto &entry : select_node.cte_map.map) {
129153
results.push_back(TableRefResult{
130-
"", entry.first, "cte"
154+
"", entry.first, TableContext::CTE
131155
});
132156

133157
if (entry.second && entry.second->query && entry.second->query->node) {
134-
ExtractTablesFromQueryNode(*entry.second->query->node, results, "from", &select_node.cte_map);
158+
ExtractTablesFromQueryNode(*entry.second->query->node, results, TableContext::From, &select_node.cte_map);
135159
}
136160
}
137161

@@ -177,7 +201,7 @@ static void MyFunc(ClientContext &context,
177201
output.SetCardinality(1);
178202
output.SetValue(0, 0, Value(ref.schema));
179203
output.SetValue(1, 0, Value(ref.table));
180-
output.SetValue(2, 0, Value(ref.context));
204+
output.SetValue(2, 0, Value(ToString(ref.context)));
181205

182206
state.row++;
183207
}

0 commit comments

Comments
 (0)