Skip to content

Commit 8588caa

Browse files
committed
fix: cleanups
1 parent 2cae642 commit 8588caa

File tree

4 files changed

+578
-596
lines changed

4 files changed

+578
-596
lines changed

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ set(EXTENSION_SOURCES
2929
src/adbc_connection.cpp
3030
src/adbc_functions.cpp
3131
src/adbc_scan.cpp
32+
src/adbc_execute.cpp
33+
src/adbc_insert.cpp
3234
src/adbc_catalog.cpp
3335
src/adbc_secrets.cpp
3436
src/adbc_filter_pushdown.cpp

src/adbc_execute.cpp

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
#include "adbc_connection.hpp"
2+
#include "duckdb/main/extension/extension_loader.hpp"
3+
#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp"
4+
5+
namespace adbc_scanner {
6+
using namespace duckdb;
7+
8+
// Helper to format error messages with query context
9+
static string FormatError(const string &message, const string &query) {
10+
string result = message;
11+
// Truncate query if too long for error message
12+
if (query.length() > 100) {
13+
result += " [Query: " + query.substr(0, 100) + "...]";
14+
} else {
15+
result += " [Query: " + query + "]";
16+
}
17+
return result;
18+
}
19+
20+
// Bind data for adbc_execute
21+
struct AdbcExecuteBindData : public FunctionData {
22+
int64_t connection_id;
23+
string query;
24+
shared_ptr<AdbcConnectionWrapper> connection;
25+
vector<Value> params;
26+
vector<LogicalType> param_types;
27+
bool has_params = false;
28+
29+
unique_ptr<FunctionData> Copy() const override {
30+
auto copy = make_uniq<AdbcExecuteBindData>();
31+
copy->connection_id = connection_id;
32+
copy->query = query;
33+
copy->connection = connection;
34+
copy->params = params;
35+
copy->param_types = param_types;
36+
copy->has_params = has_params;
37+
return std::move(copy);
38+
}
39+
40+
bool Equals(const FunctionData &other_p) const override {
41+
auto &other = other_p.Cast<AdbcExecuteBindData>();
42+
return connection_id == other.connection_id && query == other.query;
43+
}
44+
};
45+
46+
// Bind function for adbc_execute
47+
static unique_ptr<FunctionData> AdbcExecuteBind(ClientContext &context, ScalarFunction &bound_function,
48+
vector<unique_ptr<Expression>> &arguments) {
49+
auto bind_data = make_uniq<AdbcExecuteBindData>();
50+
return std::move(bind_data);
51+
}
52+
53+
// Helper to execute a single DDL/DML statement and return rows affected
54+
static int64_t ExecuteStatement(int64_t connection_id, const string &query) {
55+
// Look up and validate connection
56+
auto connection = GetValidatedConnection(connection_id, "adbc_execute");
57+
58+
// Create and prepare statement
59+
auto statement = make_shared_ptr<AdbcStatementWrapper>(connection);
60+
statement->Init();
61+
statement->SetSqlQuery(query);
62+
63+
try {
64+
statement->Prepare();
65+
} catch (Exception &e) {
66+
throw InvalidInputException(FormatError("adbc_execute: Failed to prepare statement: " + string(e.what()), query));
67+
}
68+
69+
// Execute the statement
70+
ArrowArrayStream stream;
71+
memset(&stream, 0, sizeof(stream));
72+
int64_t rows_affected = -1;
73+
74+
try {
75+
statement->ExecuteQuery(&stream, &rows_affected);
76+
} catch (Exception &e) {
77+
throw IOException(FormatError("adbc_execute: Failed to execute statement: " + string(e.what()), query));
78+
}
79+
80+
// Release the stream if it was created (DDL/DML may or may not create one)
81+
if (stream.release) {
82+
stream.release(&stream);
83+
}
84+
85+
// Return rows affected (or 0 if not available)
86+
return rows_affected >= 0 ? rows_affected : 0;
87+
}
88+
89+
// Execute function - runs DDL/DML and returns rows affected
90+
static void AdbcExecuteFunction(DataChunk &args, ExpressionState &state, Vector &result) {
91+
auto &conn_vector = args.data[0];
92+
auto &query_vector = args.data[1];
93+
auto count = args.size();
94+
95+
// Handle constant input (for constant folding optimization)
96+
if (conn_vector.GetVectorType() == VectorType::CONSTANT_VECTOR &&
97+
query_vector.GetVectorType() == VectorType::CONSTANT_VECTOR) {
98+
if (ConstantVector::IsNull(conn_vector)) {
99+
throw InvalidInputException("adbc_execute: Connection handle cannot be NULL");
100+
}
101+
if (ConstantVector::IsNull(query_vector)) {
102+
throw InvalidInputException("adbc_execute: Query cannot be NULL");
103+
}
104+
auto connection_id = conn_vector.GetValue(0).GetValue<int64_t>();
105+
auto query = query_vector.GetValue(0).GetValue<string>();
106+
auto rows_affected = ExecuteStatement(connection_id, query);
107+
result.SetVectorType(VectorType::CONSTANT_VECTOR);
108+
ConstantVector::GetData<int64_t>(result)[0] = rows_affected;
109+
return;
110+
}
111+
112+
// Handle flat/dictionary vectors
113+
result.SetVectorType(VectorType::FLAT_VECTOR);
114+
auto result_data = FlatVector::GetData<int64_t>(result);
115+
auto &validity = FlatVector::Validity(result);
116+
117+
for (idx_t row_idx = 0; row_idx < count; row_idx++) {
118+
auto conn_value = conn_vector.GetValue(row_idx);
119+
auto query_value = query_vector.GetValue(row_idx);
120+
121+
if (conn_value.IsNull()) {
122+
throw InvalidInputException("adbc_execute: Connection handle cannot be NULL");
123+
}
124+
if (query_value.IsNull()) {
125+
throw InvalidInputException("adbc_execute: Query cannot be NULL");
126+
}
127+
128+
auto connection_id = conn_value.GetValue<int64_t>();
129+
auto query = query_value.GetValue<string>();
130+
result_data[row_idx] = ExecuteStatement(connection_id, query);
131+
}
132+
}
133+
134+
// Register adbc_execute scalar function
135+
void RegisterAdbcExecuteFunction(DatabaseInstance &db) {
136+
ExtensionLoader loader(db, "adbc");
137+
138+
ScalarFunction adbc_execute_function(
139+
"adbc_execute",
140+
{LogicalType::BIGINT, LogicalType::VARCHAR},
141+
LogicalType::BIGINT,
142+
AdbcExecuteFunction,
143+
AdbcExecuteBind
144+
);
145+
// Disable automatic NULL propagation so we can throw a meaningful error
146+
adbc_execute_function.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
147+
148+
CreateScalarFunctionInfo info(adbc_execute_function);
149+
FunctionDescription desc;
150+
desc.description = "Execute DDL/DML statements (CREATE, INSERT, UPDATE, DELETE) on an ADBC connection";
151+
desc.parameter_names = {"connection_handle", "query"};
152+
desc.parameter_types = {LogicalType::BIGINT, LogicalType::VARCHAR};
153+
desc.examples = {"SELECT adbc_execute(conn, 'CREATE TABLE test (id INTEGER)')",
154+
"SELECT adbc_execute(conn, 'INSERT INTO test VALUES (1)')",
155+
"SELECT adbc_execute(conn, 'DELETE FROM test WHERE id = 1')"};
156+
desc.categories = {"adbc"};
157+
info.descriptions.push_back(std::move(desc));
158+
loader.RegisterFunction(info);
159+
}
160+
161+
} // namespace adbc_scanner

0 commit comments

Comments
 (0)