diff --git a/src/functions/get_tranco.cpp b/src/functions/get_tranco.cpp index f53bedd..406d49e 100644 --- a/src/functions/get_tranco.cpp +++ b/src/functions/get_tranco.cpp @@ -17,14 +17,14 @@ namespace duckdb // Function to get the download code for the Tranco list std::string GetTrancoDownloadCode (char *date) { - CURL *curl = CreateCurlHandler (); + CURL *curl = CreateCurlHandler (WriteStringCallback); CURLcode res; std::string readBuffer; // Construct the URL for the daily list std::string url = "https://tranco-list.eu/daily_list?date=" + std::string (date) + "&subdomains=true"; - LogMessage (LogLevel::INFO, "Get Tranco download code for date: " + std::string (date)); + LogMessage (LogLevel::LOG_INFO, "Get Tranco download code for date: " + std::string (date)); curl_easy_setopt (curl, CURLOPT_URL, url.c_str ()); curl_easy_setopt (curl, CURLOPT_WRITEDATA, &readBuffer); @@ -33,8 +33,8 @@ namespace duckdb if (res != CURLE_OK) { - LogMessage (LogLevel::ERROR, std::string (curl_easy_strerror (res))); - LogMessage (LogLevel::CRITICAL, "Failed to fetch Tranco download code."); + LogMessage (LogLevel::LOG_ERROR, std::string (curl_easy_strerror (res))); + LogMessage (LogLevel::LOG_CRITICAL, "Failed to fetch Tranco download code."); } // Extract the download code from the URL @@ -42,11 +42,11 @@ namespace duckdb std::smatch code_match; if (std::regex_search (readBuffer, code_match, code_regex) && code_match.size () > 1) { - LogMessage (LogLevel::INFO, "Tranco download code: " + code_match[1].str ()); + LogMessage (LogLevel::LOG_INFO, "Tranco download code: " + code_match[1].str ()); return code_match[1].str (); } - LogMessage (LogLevel::CRITICAL, "Failed to extract Tranco download code."); + LogMessage (LogLevel::LOG_CRITICAL, "Failed to extract Tranco download code."); } // Function to download the Tranco list and create a table @@ -78,16 +78,16 @@ namespace duckdb // Construct the download URL std::string download_url = "https://tranco-list.eu/download/" + download_code + "/full"; - LogMessage (LogLevel::INFO, "Download Tranco list: " + download_url); + LogMessage (LogLevel::LOG_INFO, "Download Tranco list: " + download_url); // Download the CSV file to a temporary file - CURL *curl = CreateCurlHandler (); + CURL *curl = CreateCurlHandler (WriteFileCallback); CURLcode res; FILE *file = fopen (temp_file.c_str (), "wb"); if (!file) { curl_easy_cleanup (curl); - LogMessage (LogLevel::CRITICAL, "Failed to create temporary file for Tranco list."); + LogMessage (LogLevel::LOG_CRITICAL, "Failed to create temporary file for Tranco list."); } curl_easy_setopt (curl, CURLOPT_URL, download_url.c_str ()); @@ -99,18 +99,18 @@ namespace duckdb if (res != CURLE_OK) { remove (temp_file.c_str ()); // Clean up the temporary file - LogMessage (LogLevel::ERROR, std::string (curl_easy_strerror (res))); - LogMessage (LogLevel::CRITICAL, "Failed to download Tranco list. Check logs for details."); + LogMessage (LogLevel::LOG_ERROR, std::string (curl_easy_strerror (res))); + LogMessage (LogLevel::LOG_CRITICAL, "Failed to download Tranco list. Check logs for details."); } } if (!file.good ()) { - LogMessage (LogLevel::CRITICAL, "Tranco list `" + temp_file + "` not found. Download it first using `SELECT update_tranco(true);`"); + LogMessage (LogLevel::LOG_CRITICAL, "Tranco list `" + temp_file + "` not found. Download it first using `SELECT update_tranco(true);`"); } // Parse the CSV data and insert into a table - LogMessage (LogLevel::INFO, "Inserting Tranco list into table"); + LogMessage (LogLevel::LOG_INFO, "Inserting Tranco list into table"); Connection con (db); string query = "CREATE OR REPLACE TABLE tranco_list AS" @@ -133,7 +133,7 @@ namespace duckdb if (result->HasError ()) { - LogMessage (LogLevel::CRITICAL, result->GetError ()); + LogMessage (LogLevel::LOG_CRITICAL, result->GetError ()); } } @@ -162,7 +162,7 @@ namespace duckdb if (table_exists->RowCount () == 0) { - LogMessage (LogLevel::CRITICAL, "Tranco table not found. Download it first using `SELECT update_tranco(true);`"); + LogMessage (LogLevel::LOG_CRITICAL, "Tranco table not found. Download it first using `SELECT update_tranco(true);`"); } // Extract the input from the arguments @@ -201,7 +201,7 @@ namespace duckdb if (table_exists->RowCount () == 0) { - LogMessage (LogLevel::CRITICAL, "Tranco table not found. Download it first using `SELECT update_tranco(true);`"); + LogMessage (LogLevel::LOG_CRITICAL, "Tranco table not found. Download it first using `SELECT update_tranco(true);`"); } // Extract the input from the arguments diff --git a/src/utils/logger.cpp b/src/utils/logger.cpp index 98bfb36..d5ebffd 100644 --- a/src/utils/logger.cpp +++ b/src/utils/logger.cpp @@ -15,24 +15,24 @@ namespace duckdb const char *env_level = std::getenv ("LOG_LEVEL"); if (env_level == nullptr) { - return LogLevel::INFO; // default level + return LogLevel::LOG_INFO; // default level } std::string level_str (env_level); if (level_str == "DEBUG") - return LogLevel::DEBUG; + return LogLevel::LOG_DEBUG; if (level_str == "INFO") - return LogLevel::INFO; + return LogLevel::LOG_INFO; if (level_str == "WARNING") - return LogLevel::WARNING; + return LogLevel::LOG_WARNING; if (level_str == "ERROR") - return LogLevel::ERROR; + return LogLevel::LOG_ERROR; if (level_str == "CRITICAL") - return LogLevel::CRITICAL; + return LogLevel::LOG_CRITICAL; std::cerr << "Unknown LOG_LEVEL environment variable value: " << level_str << ". Defaulting to INFO." << std::endl; - return LogLevel::INFO; + return LogLevel::LOG_INFO; } std::string getCurrentTimestamp () @@ -57,11 +57,11 @@ namespace duckdb const char *level_str = ""; switch (level) { - case LogLevel::DEBUG: level_str = "DEBUG"; break; - case LogLevel::INFO: level_str = "INFO"; break; - case LogLevel::WARNING: level_str = "WARNING"; break; - case LogLevel::ERROR: level_str = "ERROR"; break; - case LogLevel::CRITICAL: level_str = "CRITICAL"; break; + case LogLevel::LOG_DEBUG: level_str = "DEBUG"; break; + case LogLevel::LOG_INFO: level_str = "INFO"; break; + case LogLevel::LOG_WARNING: level_str = "WARNING"; break; + case LogLevel::LOG_ERROR: level_str = "ERROR"; break; + case LogLevel::LOG_CRITICAL: level_str = "CRITICAL"; break; } std::ofstream log_file ("netquack.log", std::ios_base::app); @@ -76,13 +76,13 @@ namespace duckdb << message << std::endl; // Also output to stderr for error levels - if (level >= LogLevel::ERROR) + if (level >= LogLevel::LOG_ERROR) { std::cerr << "[" << level_str << "] " << message << std::endl; } // Throw exception for critical errors - if (level == LogLevel::CRITICAL) + if (level == LogLevel::LOG_CRITICAL) { throw std::runtime_error (message); } diff --git a/src/utils/logger.hpp b/src/utils/logger.hpp index 4ab3292..3924314 100644 --- a/src/utils/logger.hpp +++ b/src/utils/logger.hpp @@ -8,13 +8,14 @@ namespace duckdb { namespace netquack { + // Note: `LOG_` prefix is to avoid problems with DEBUG and ERROR macros enum class LogLevel { - DEBUG, - INFO, - WARNING, - ERROR, - CRITICAL + LOG_DEBUG, + LOG_INFO, + LOG_WARNING, + LOG_ERROR, + LOG_CRITICAL }; // Function to log messages with a specified log level diff --git a/src/utils/utils.cpp b/src/utils/utils.cpp index 4fd9087..02c55af 100644 --- a/src/utils/utils.cpp +++ b/src/utils/utils.cpp @@ -32,12 +32,12 @@ namespace duckdb #endif } - CURL *CreateCurlHandler () + CURL *CreateCurlHandler (curl_write_callback write_callback) { CURL *curl = curl_easy_init (); if (!curl) { - LogMessage (LogLevel::CRITICAL, "Failed to initialize CURL"); + LogMessage (LogLevel::LOG_CRITICAL, "Failed to initialize CURL"); } const char *ca_info = std::getenv ("CURL_CA_INFO"); @@ -63,34 +63,41 @@ namespace duckdb } #endif curl_easy_setopt (curl, CURLOPT_FOLLOWLOCATION, 1L); // Follow redirects - curl_easy_setopt (curl, CURLOPT_WRITEFUNCTION, WriteCallback); + curl_easy_setopt (curl, CURLOPT_WRITEFUNCTION, write_callback); + if (ca_info) { // Set the custom CA certificate bundle file // https://github.com/hatamiarash7/duckdb-netquack/issues/6 - LogMessage (LogLevel::DEBUG, "Using custom CA certificate bundle: " + std::string (ca_info)); + LogMessage (LogLevel::LOG_DEBUG, "Using custom CA certificate bundle: " + std::string (ca_info)); curl_easy_setopt (curl, CURLOPT_CAINFO, ca_info); } const char *ca_path = std::getenv ("CURL_CA_PATH"); if (ca_path) { // Set the custom CA certificate directory - LogMessage (LogLevel::DEBUG, "Using custom CA certificate directory: " + std::string (ca_path)); + LogMessage (LogLevel::LOG_DEBUG, "Using custom CA certificate directory: " + std::string (ca_path)); curl_easy_setopt (curl, CURLOPT_CAPATH, ca_path); } return curl; } - size_t WriteCallback (void *contents, size_t size, size_t nmemb, void *userp) + size_t WriteStringCallback (char *contents, size_t size, size_t nmemb, void *userp) { ((std::string *)userp)->append ((char *)contents, size * nmemb); return size * nmemb; } + size_t WriteFileCallback (char *contents, size_t size, size_t nmemb, void *userp) + { + FILE *file = (FILE *)userp; + return fwrite (contents, size, nmemb, file); + } + std::string DownloadPublicSuffixList () { - CURL *curl = CreateCurlHandler (); + CURL *curl = CreateCurlHandler (WriteStringCallback); CURLcode res; std::string readBuffer; @@ -101,8 +108,8 @@ namespace duckdb if (res != CURLE_OK) { - LogMessage (LogLevel::ERROR, std::string (curl_easy_strerror (res))); - LogMessage (LogLevel::CRITICAL, "Failed to download public suffix list. Check logs for details."); + LogMessage (LogLevel::LOG_ERROR, std::string (curl_easy_strerror (res))); + LogMessage (LogLevel::LOG_CRITICAL, "Failed to download public suffix list. Check logs for details."); } return readBuffer; @@ -118,14 +125,14 @@ namespace duckdb if (table_exists->RowCount () == 0 || table_data->RowCount () <= 1 || force) { - LogMessage (LogLevel::INFO, "Loading public suffix list..."); + LogMessage (LogLevel::LOG_INFO, "Loading public suffix list..."); // Download the list auto list_data = DownloadPublicSuffixList (); // Validate the downloaded data if (list_data.empty ()) { - LogMessage (LogLevel::CRITICAL, "Failed to download public suffix list: empty data received"); + LogMessage (LogLevel::LOG_CRITICAL, "Failed to download public suffix list: empty data received"); } // Count non-comment/non-empty lines for validation @@ -143,11 +150,11 @@ namespace duckdb if (valid_line_count <= 1) { - LogMessage (LogLevel::ERROR, validation_line); - LogMessage (LogLevel::CRITICAL, "Downloaded public suffix list contains no valid entries. Try again or run `SELECT update_suffixes();`."); + LogMessage (LogLevel::LOG_ERROR, validation_line); + LogMessage (LogLevel::LOG_CRITICAL, "Downloaded public suffix list contains no valid entries. Try again or run `SELECT update_suffixes();`."); } - LogMessage (LogLevel::INFO, "Downloaded public suffix list with " + std::to_string (valid_line_count) + " valid entries"); + LogMessage (LogLevel::LOG_INFO, "Downloaded public suffix list with " + std::to_string (valid_line_count) + " valid entries"); // Parse the list and insert into a table std::istringstream stream (list_data); diff --git a/src/utils/utils.hpp b/src/utils/utils.hpp index 7d18f13..a272c2a 100644 --- a/src/utils/utils.hpp +++ b/src/utils/utils.hpp @@ -10,11 +10,14 @@ namespace duckdb { namespace netquack { - // Function to get a CURL handler - CURL *CreateCurlHandler (); + // Function to get a CURL handler with custom write callback + CURL *CreateCurlHandler (curl_write_callback write_callback); - // Function to download a file from a URL - size_t WriteCallback (void *contents, size_t size, size_t nmemb, void *userp); + // Function to write data to a string (for HTTP responses) + size_t WriteStringCallback (char *contents, size_t size, size_t nmemb, void *userp); + + // Function to write data to a file (for file downloads) + size_t WriteFileCallback (char *contents, size_t size, size_t nmemb, void *userp); // Function to download the public suffix list std::string DownloadPublicSuffixList (); diff --git a/test/sql/update_tranco.test_slow b/test/sql/update_tranco.test_slow new file mode 100644 index 0000000..e99af97 --- /dev/null +++ b/test/sql/update_tranco.test_slow @@ -0,0 +1,10 @@ +# name: test/sql/update_tranco.test_slow +# description: test netquack extension update_tranco function +# group: [netquack] + +require netquack + +query I +SELECT update_tranco(true); +---- +Tranco list updated