|
6 | 6 | #include "duckdb/common/string_util.hpp" |
7 | 7 | #include "duckdb/function/scalar_function.hpp" |
8 | 8 | #include "duckdb/main/extension_util.hpp" |
| 9 | +#include "duckdb/parser/parsed_data/create_macro_info.hpp" |
9 | 10 | #include <duckdb/parser/parsed_data/create_scalar_function_info.hpp> |
| 11 | +#include "duckdb/common/exception/http_exception.hpp" |
10 | 12 |
|
11 | | -// OpenSSL linked through vcpkg |
12 | | -#include <openssl/opensslv.h> |
| 13 | +#define CPPHTTPLIB_OPENSSL_SUPPORT |
| 14 | +#include "httplib.hpp" |
| 15 | +#include "yyjson.hpp" |
13 | 16 |
|
14 | 17 | namespace duckdb { |
15 | 18 |
|
16 | | -inline void WebxtensionScalarFun(DataChunk &args, ExpressionState &state, Vector &result) { |
17 | | - auto &name_vector = args.data[0]; |
18 | | - UnaryExecutor::Execute<string_t, string_t>( |
19 | | - name_vector, result, args.size(), |
20 | | - [&](string_t name) { |
21 | | - return StringVector::AddString(result, "Webxtension "+name.GetString()+" 🐥");; |
22 | | - }); |
| 19 | +// Helper function to setup HTTP client |
| 20 | +static std::pair<duckdb_httplib_openssl::Client, std::string> SetupHttpClient(const std::string &url) { |
| 21 | + std::string scheme, domain, path; |
| 22 | + size_t pos = url.find("://"); |
| 23 | + std::string mod_url = url; |
| 24 | + if (pos != std::string::npos) { |
| 25 | + scheme = mod_url.substr(0, pos); |
| 26 | + mod_url.erase(0, pos + 3); |
| 27 | + } |
| 28 | + |
| 29 | + pos = mod_url.find("/"); |
| 30 | + if (pos != std::string::npos) { |
| 31 | + domain = mod_url.substr(0, pos); |
| 32 | + path = mod_url.substr(pos); |
| 33 | + } else { |
| 34 | + domain = mod_url; |
| 35 | + path = "/"; |
| 36 | + } |
| 37 | + |
| 38 | + duckdb_httplib_openssl::Client client(domain.c_str()); |
| 39 | + client.set_read_timeout(10, 0); |
| 40 | + client.set_follow_location(true); |
| 41 | + |
| 42 | + return std::make_pair(std::move(client), path); |
23 | 43 | } |
24 | 44 |
|
25 | | -inline void WebxtensionOpenSSLVersionScalarFun(DataChunk &args, ExpressionState &state, Vector &result) { |
26 | | - auto &name_vector = args.data[0]; |
| 45 | +// Helper function to handle HTTP errors |
| 46 | +static void HandleHttpError(const duckdb_httplib_openssl::Result &res, const std::string &request_type) { |
| 47 | + std::string err_message = "HTTP " + request_type + " request failed. "; |
| 48 | + |
| 49 | + switch (res.error()) { |
| 50 | + case duckdb_httplib_openssl::Error::Connection: |
| 51 | + err_message += "Connection error."; |
| 52 | + break; |
| 53 | + case duckdb_httplib_openssl::Error::Read: |
| 54 | + err_message += "Error reading response."; |
| 55 | + break; |
| 56 | + default: |
| 57 | + err_message += "Unknown error."; |
| 58 | + break; |
| 59 | + } |
| 60 | + throw std::runtime_error(err_message); |
| 61 | +} |
| 62 | + |
| 63 | + |
| 64 | +static bool ContainsMacroDefinition(const std::string &content) { |
| 65 | + std::string upper_content = StringUtil::Upper(content); |
| 66 | + const char* patterns[] = { |
| 67 | + "CREATE MACRO", |
| 68 | + "CREATE OR REPLACE MACRO", |
| 69 | + "CREATE TEMP MACRO", |
| 70 | + "CREATE TEMPORARY MACRO", |
| 71 | + "CREATE OR REPLACE TEMP MACRO", |
| 72 | + "CREATE OR REPLACE TEMPORARY MACRO" |
| 73 | + }; |
| 74 | + |
| 75 | + for (const auto& pattern : patterns) { |
| 76 | + if (upper_content.find(pattern) != std::string::npos) { |
| 77 | + return true; |
| 78 | + } |
| 79 | + } |
| 80 | + return false; |
| 81 | +} |
| 82 | + |
| 83 | +// Function to fetch and create macro from URL |
| 84 | +static void LoadMacroFromUrlFunction(DataChunk &args, ExpressionState &state, Vector &result, DatabaseInstance *db_instance) { |
| 85 | + auto &context = state.GetContext(); |
| 86 | + |
27 | 87 | UnaryExecutor::Execute<string_t, string_t>( |
28 | | - name_vector, result, args.size(), |
29 | | - [&](string_t name) { |
30 | | - return StringVector::AddString(result, "Webxtension " + name.GetString() + |
31 | | - ", my linked OpenSSL version is " + |
32 | | - OPENSSL_VERSION_TEXT );; |
| 88 | + args.data[0], result, args.size(), |
| 89 | + [&](string_t url) { |
| 90 | + try { |
| 91 | + // Setup HTTP client |
| 92 | + auto client_and_path = SetupHttpClient(url.GetString()); |
| 93 | + auto &client = client_and_path.first; |
| 94 | + auto &path = client_and_path.second; |
| 95 | + |
| 96 | + // Make GET request |
| 97 | + auto res = client.Get(path.c_str()); |
| 98 | + if (!res) { |
| 99 | + HandleHttpError(res, "GET"); |
| 100 | + } |
| 101 | + |
| 102 | + if (res->status != 200) { |
| 103 | + throw std::runtime_error("HTTP error " + std::to_string(res->status) + ": " + res->reason); |
| 104 | + } |
| 105 | + |
| 106 | + // Get the SQL content |
| 107 | + std::string macro_sql = res->body; |
| 108 | + |
| 109 | + // Replace all \r\n with \n |
| 110 | + macro_sql = StringUtil::Replace(macro_sql, "\r\n", "\n"); |
| 111 | + // Replace any remaining \r with \n |
| 112 | + macro_sql = StringUtil::Replace(macro_sql, "\r", "\n"); |
| 113 | + // Normalize multiple newlines to single newlines |
| 114 | + macro_sql = StringUtil::Replace(macro_sql, "\n\n", "\n"); |
| 115 | + // Trim in place |
| 116 | + StringUtil::Trim(macro_sql); |
| 117 | + |
| 118 | + if (!ContainsMacroDefinition(macro_sql)) { |
| 119 | + throw std::runtime_error("URL content does not contain a valid macro definition"); |
| 120 | + } |
| 121 | + |
| 122 | + //std::cout << macro_sql << "\n"; |
| 123 | + Connection conn(*db_instance); |
| 124 | + |
| 125 | + // Execute the macro directly in the current context |
| 126 | + auto query_result = conn.Query(macro_sql); |
| 127 | + |
| 128 | + if (query_result->HasError()) { |
| 129 | + throw std::runtime_error("Failed loading Macro: " + query_result->GetError()); |
| 130 | + } |
| 131 | + |
| 132 | + return StringVector::AddString(result, "Successfully loaded macro"); |
| 133 | + |
| 134 | + } catch (std::exception &e) { |
| 135 | + std::string error_msg = "Error: " + std::string(e.what()); |
| 136 | + throw std::runtime_error(error_msg); |
| 137 | + } |
33 | 138 | }); |
34 | 139 | } |
35 | 140 |
|
36 | 141 | static void LoadInternal(DatabaseInstance &instance) { |
37 | | - // Register a scalar function |
38 | | - auto webxtension_scalar_function = ScalarFunction("webxtension", {LogicalType::VARCHAR}, LogicalType::VARCHAR, WebxtensionScalarFun); |
39 | | - ExtensionUtil::RegisterFunction(instance, webxtension_scalar_function); |
40 | | - |
41 | | - // Register another scalar function |
42 | | - auto webxtension_openssl_version_scalar_function = ScalarFunction("webxtension_openssl_version", {LogicalType::VARCHAR}, |
43 | | - LogicalType::VARCHAR, WebxtensionOpenSSLVersionScalarFun); |
44 | | - ExtensionUtil::RegisterFunction(instance, webxtension_openssl_version_scalar_function); |
| 142 | + // Create lambda to capture database instance |
| 143 | + auto load_macro_func = [&instance](DataChunk &args, ExpressionState &state, Vector &result) { |
| 144 | + LoadMacroFromUrlFunction(args, state, result, &instance); |
| 145 | + }; |
| 146 | + |
| 147 | + // Register function with captured database instance |
| 148 | + ExtensionUtil::RegisterFunction( |
| 149 | + instance, |
| 150 | + ScalarFunction("load_macro_from_url", {LogicalType::VARCHAR}, |
| 151 | + LogicalType::VARCHAR, load_macro_func) |
| 152 | + ); |
45 | 153 | } |
46 | 154 |
|
47 | 155 | void WebxtensionExtension::Load(DuckDB &db) { |
48 | | - LoadInternal(*db.instance); |
| 156 | + LoadInternal(*db.instance); |
49 | 157 | } |
| 158 | + |
50 | 159 | std::string WebxtensionExtension::Name() { |
51 | | - return "webxtension"; |
| 160 | + return "webxtension"; |
52 | 161 | } |
53 | 162 |
|
54 | 163 | std::string WebxtensionExtension::Version() const { |
55 | 164 | #ifdef EXT_VERSION_WEBXTENSION |
56 | | - return EXT_VERSION_WEBXTENSION; |
| 165 | + return EXT_VERSION_WEBXTENSION; |
57 | 166 | #else |
58 | | - return ""; |
| 167 | + return ""; |
59 | 168 | #endif |
60 | 169 | } |
61 | 170 |
|
62 | 171 | } // namespace duckdb |
63 | 172 |
|
64 | 173 | extern "C" { |
65 | | - |
66 | 174 | DUCKDB_EXTENSION_API void webxtension_init(duckdb::DatabaseInstance &db) { |
67 | 175 | duckdb::DuckDB db_wrapper(db); |
68 | 176 | db_wrapper.LoadExtension<duckdb::WebxtensionExtension>(); |
69 | 177 | } |
70 | 178 |
|
71 | 179 | DUCKDB_EXTENSION_API const char *webxtension_version() { |
72 | | - return duckdb::DuckDB::LibraryVersion(); |
| 180 | + return duckdb::DuckDB::LibraryVersion(); |
73 | 181 | } |
74 | 182 | } |
75 | 183 |
|
|
0 commit comments