Skip to content

Commit 813c32c

Browse files
authored
Merge pull request #12 from altertable-ai/su/fix-curl-segv
Fix cURL's WRITEFUNCTION undefined behavior leading to segfault
2 parents 9d1b26f + 98f9885 commit 813c32c

File tree

6 files changed

+74
-53
lines changed

6 files changed

+74
-53
lines changed

src/functions/get_tranco.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@ namespace duckdb
1717
// Function to get the download code for the Tranco list
1818
std::string GetTrancoDownloadCode (char *date)
1919
{
20-
CURL *curl = CreateCurlHandler ();
20+
CURL *curl = CreateCurlHandler (WriteStringCallback);
2121
CURLcode res;
2222
std::string readBuffer;
2323

2424
// Construct the URL for the daily list
2525
std::string url = "https://tranco-list.eu/daily_list?date=" + std::string (date) + "&subdomains=true";
2626

27-
LogMessage (LogLevel::INFO, "Get Tranco download code for date: " + std::string (date));
27+
LogMessage (LogLevel::LOG_INFO, "Get Tranco download code for date: " + std::string (date));
2828

2929
curl_easy_setopt (curl, CURLOPT_URL, url.c_str ());
3030
curl_easy_setopt (curl, CURLOPT_WRITEDATA, &readBuffer);
@@ -33,20 +33,20 @@ namespace duckdb
3333

3434
if (res != CURLE_OK)
3535
{
36-
LogMessage (LogLevel::ERROR, std::string (curl_easy_strerror (res)));
37-
LogMessage (LogLevel::CRITICAL, "Failed to fetch Tranco download code.");
36+
LogMessage (LogLevel::LOG_ERROR, std::string (curl_easy_strerror (res)));
37+
LogMessage (LogLevel::LOG_CRITICAL, "Failed to fetch Tranco download code.");
3838
}
3939

4040
// Extract the download code from the response body
4141
std::regex code_regex (R"(Information on the Tranco list with ID ([A-Z0-9]+))");
4242
std::smatch code_match;
4343
if (std::regex_search (readBuffer, code_match, code_regex) && code_match.size () > 1)
4444
{
45-
LogMessage (LogLevel::INFO, "Tranco download code: " + code_match[1].str ());
45+
LogMessage (LogLevel::LOG_INFO, "Tranco download code: " + code_match[1].str ());
4646
return code_match[1].str ();
4747
}
4848

49-
LogMessage (LogLevel::CRITICAL, "Failed to extract Tranco download code.");
49+
LogMessage (LogLevel::LOG_CRITICAL, "Failed to extract Tranco download code.");
5050
}
5151

5252
// Function to download the Tranco list and create a table
@@ -78,16 +78,16 @@ namespace duckdb
7878
// Construct the download URL
7979
std::string download_url = "https://tranco-list.eu/download/" + download_code + "/full";
8080

81-
LogMessage (LogLevel::INFO, "Download Tranco list: " + download_url);
81+
LogMessage (LogLevel::LOG_INFO, "Download Tranco list: " + download_url);
8282

8383
// Download the CSV file to a temporary file
84-
CURL *curl = CreateCurlHandler ();
84+
CURL *curl = CreateCurlHandler (WriteFileCallback);
8585
CURLcode res;
8686
FILE *file = fopen (temp_file.c_str (), "wb");
8787
if (!file)
8888
{
8989
curl_easy_cleanup (curl);
90-
LogMessage (LogLevel::CRITICAL, "Failed to create temporary file for Tranco list.");
90+
LogMessage (LogLevel::LOG_CRITICAL, "Failed to create temporary file for Tranco list.");
9191
}
9292

9393
curl_easy_setopt (curl, CURLOPT_URL, download_url.c_str ());
@@ -99,18 +99,18 @@ namespace duckdb
9999
if (res != CURLE_OK)
100100
{
101101
remove (temp_file.c_str ()); // Clean up the temporary file
102-
LogMessage (LogLevel::ERROR, std::string (curl_easy_strerror (res)));
103-
LogMessage (LogLevel::CRITICAL, "Failed to download Tranco list. Check logs for details.");
102+
LogMessage (LogLevel::LOG_ERROR, std::string (curl_easy_strerror (res)));
103+
LogMessage (LogLevel::LOG_CRITICAL, "Failed to download Tranco list. Check logs for details.");
104104
}
105105
}
106106

107107
if (!file.good ())
108108
{
109-
LogMessage (LogLevel::CRITICAL, "Tranco list `" + temp_file + "` not found. Download it first using `SELECT update_tranco(true);`");
109+
LogMessage (LogLevel::LOG_CRITICAL, "Tranco list `" + temp_file + "` not found. Download it first using `SELECT update_tranco(true);`");
110110
}
111111

112112
// Parse the CSV data and insert into a table
113-
LogMessage (LogLevel::INFO, "Inserting Tranco list into table");
113+
LogMessage (LogLevel::LOG_INFO, "Inserting Tranco list into table");
114114

115115
Connection con (db);
116116
string query = "CREATE OR REPLACE TABLE tranco_list AS"
@@ -133,7 +133,7 @@ namespace duckdb
133133

134134
if (result->HasError ())
135135
{
136-
LogMessage (LogLevel::CRITICAL, result->GetError ());
136+
LogMessage (LogLevel::LOG_CRITICAL, result->GetError ());
137137
}
138138
}
139139

@@ -162,7 +162,7 @@ namespace duckdb
162162

163163
if (table_exists->RowCount () == 0)
164164
{
165-
LogMessage (LogLevel::CRITICAL, "Tranco table not found. Download it first using `SELECT update_tranco(true);`");
165+
LogMessage (LogLevel::LOG_CRITICAL, "Tranco table not found. Download it first using `SELECT update_tranco(true);`");
166166
}
167167

168168
// Extract the input from the arguments
@@ -201,7 +201,7 @@ namespace duckdb
201201

202202
if (table_exists->RowCount () == 0)
203203
{
204-
LogMessage (LogLevel::CRITICAL, "Tranco table not found. Download it first using `SELECT update_tranco(true);`");
204+
LogMessage (LogLevel::LOG_CRITICAL, "Tranco table not found. Download it first using `SELECT update_tranco(true);`");
205205
}
206206

207207
// Extract the input from the arguments

src/utils/logger.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,24 @@ namespace duckdb
1515
const char *env_level = std::getenv ("LOG_LEVEL");
1616
if (env_level == nullptr)
1717
{
18-
return LogLevel::INFO; // default level
18+
return LogLevel::LOG_INFO; // default level
1919
}
2020

2121
std::string level_str (env_level);
2222
if (level_str == "DEBUG")
23-
return LogLevel::DEBUG;
23+
return LogLevel::LOG_DEBUG;
2424
if (level_str == "INFO")
25-
return LogLevel::INFO;
25+
return LogLevel::LOG_INFO;
2626
if (level_str == "WARNING")
27-
return LogLevel::WARNING;
27+
return LogLevel::LOG_WARNING;
2828
if (level_str == "ERROR")
29-
return LogLevel::ERROR;
29+
return LogLevel::LOG_ERROR;
3030
if (level_str == "CRITICAL")
31-
return LogLevel::CRITICAL;
31+
return LogLevel::LOG_CRITICAL;
3232

3333
std::cerr << "Unknown LOG_LEVEL environment variable value: " << level_str
3434
<< ". Defaulting to INFO." << std::endl;
35-
return LogLevel::INFO;
35+
return LogLevel::LOG_INFO;
3636
}
3737

3838
std::string getCurrentTimestamp ()
@@ -57,11 +57,11 @@ namespace duckdb
5757
const char *level_str = "";
5858
switch (level)
5959
{
60-
case LogLevel::DEBUG: level_str = "DEBUG"; break;
61-
case LogLevel::INFO: level_str = "INFO"; break;
62-
case LogLevel::WARNING: level_str = "WARNING"; break;
63-
case LogLevel::ERROR: level_str = "ERROR"; break;
64-
case LogLevel::CRITICAL: level_str = "CRITICAL"; break;
60+
case LogLevel::LOG_DEBUG: level_str = "DEBUG"; break;
61+
case LogLevel::LOG_INFO: level_str = "INFO"; break;
62+
case LogLevel::LOG_WARNING: level_str = "WARNING"; break;
63+
case LogLevel::LOG_ERROR: level_str = "ERROR"; break;
64+
case LogLevel::LOG_CRITICAL: level_str = "CRITICAL"; break;
6565
}
6666

6767
std::ofstream log_file ("netquack.log", std::ios_base::app);
@@ -76,13 +76,13 @@ namespace duckdb
7676
<< message << std::endl;
7777

7878
// Also output to stderr for error levels
79-
if (level >= LogLevel::ERROR)
79+
if (level >= LogLevel::LOG_ERROR)
8080
{
8181
std::cerr << "[" << level_str << "] " << message << std::endl;
8282
}
8383

8484
// Throw exception for critical errors
85-
if (level == LogLevel::CRITICAL)
85+
if (level == LogLevel::LOG_CRITICAL)
8686
{
8787
throw std::runtime_error (message);
8888
}

src/utils/logger.hpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@ namespace duckdb
88
{
99
namespace netquack
1010
{
11+
// Note: `LOG_` prefix is to avoid problems with DEBUG and ERROR macros
1112
enum class LogLevel
1213
{
13-
DEBUG,
14-
INFO,
15-
WARNING,
16-
ERROR,
17-
CRITICAL
14+
LOG_DEBUG,
15+
LOG_INFO,
16+
LOG_WARNING,
17+
LOG_ERROR,
18+
LOG_CRITICAL
1819
};
1920

2021
// Function to log messages with a specified log level

src/utils/utils.cpp

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@ namespace duckdb
3232
#endif
3333
}
3434

35-
CURL *CreateCurlHandler ()
35+
CURL *CreateCurlHandler (curl_write_callback write_callback)
3636
{
3737
CURL *curl = curl_easy_init ();
3838
if (!curl)
3939
{
40-
LogMessage (LogLevel::CRITICAL, "Failed to initialize CURL");
40+
LogMessage (LogLevel::LOG_CRITICAL, "Failed to initialize CURL");
4141
}
4242

4343
const char *ca_info = std::getenv ("CURL_CA_INFO");
@@ -63,34 +63,41 @@ namespace duckdb
6363
}
6464
#endif
6565
curl_easy_setopt (curl, CURLOPT_FOLLOWLOCATION, 1L); // Follow redirects
66-
curl_easy_setopt (curl, CURLOPT_WRITEFUNCTION, WriteCallback);
66+
curl_easy_setopt (curl, CURLOPT_WRITEFUNCTION, write_callback);
67+
6768
if (ca_info)
6869
{
6970
// Set the custom CA certificate bundle file
7071
// https://github.com/hatamiarash7/duckdb-netquack/issues/6
71-
LogMessage (LogLevel::DEBUG, "Using custom CA certificate bundle: " + std::string (ca_info));
72+
LogMessage (LogLevel::LOG_DEBUG, "Using custom CA certificate bundle: " + std::string (ca_info));
7273
curl_easy_setopt (curl, CURLOPT_CAINFO, ca_info);
7374
}
7475
const char *ca_path = std::getenv ("CURL_CA_PATH");
7576
if (ca_path)
7677
{
7778
// Set the custom CA certificate directory
78-
LogMessage (LogLevel::DEBUG, "Using custom CA certificate directory: " + std::string (ca_path));
79+
LogMessage (LogLevel::LOG_DEBUG, "Using custom CA certificate directory: " + std::string (ca_path));
7980
curl_easy_setopt (curl, CURLOPT_CAPATH, ca_path);
8081
}
8182

8283
return curl;
8384
}
8485

85-
size_t WriteCallback (void *contents, size_t size, size_t nmemb, void *userp)
86+
// libcurl write callback that appends each received chunk to a std::string.
//
// contents: start of the chunk handed over by libcurl (treated as raw bytes,
//           not as a NUL-terminated string).
// size, nmemb: the chunk is size * nmemb bytes long.
// userp: the pointer registered through CURLOPT_WRITEDATA; must point to a
//        std::string.
//
// Returns the number of bytes consumed. On failure it returns 0, which for a
// non-empty chunk differs from size * nmemb and is how a write callback
// reports an error to libcurl.
size_t WriteStringCallback (char *contents, size_t size, size_t nmemb, void *userp)
{
	const size_t total_bytes = size * nmemb;

	auto *buffer = static_cast<std::string *> (userp);
	if (buffer == nullptr)
	{
		return 0;
	}

	try
	{
		buffer->append (contents, total_bytes);
	}
	catch (...)
	{
		// This function is invoked from libcurl's C code: an exception such
		// as std::bad_alloc must not unwind through those frames, so turn it
		// into a short count instead.
		return 0;
	}

	return total_bytes;
}
9091

92+
// libcurl write callback that streams each received chunk into a FILE.
//
// contents: start of the chunk handed over by libcurl.
// size, nmemb: the chunk is size * nmemb bytes long.
// userp: the pointer registered through CURLOPT_WRITEDATA; must be a FILE *
//        opened for writing.
//
// Returns the number of BYTES written. A value smaller than size * nmemb
// (including 0 when no file was supplied) tells libcurl the write failed.
size_t WriteFileCallback (char *contents, size_t size, size_t nmemb, void *userp)
{
	auto *file = static_cast<FILE *> (userp);
	if (file == nullptr)
	{
		return 0;
	}

	// fwrite returns a count of elements, not bytes. Writing with an element
	// size of 1 makes that count equal to the byte count libcurl expects,
	// regardless of the value of `size`.
	return fwrite (contents, 1, size * nmemb, file);
}
97+
9198
std::string DownloadPublicSuffixList ()
9299
{
93-
CURL *curl = CreateCurlHandler ();
100+
CURL *curl = CreateCurlHandler (WriteStringCallback);
94101
CURLcode res;
95102
std::string readBuffer;
96103

@@ -101,8 +108,8 @@ namespace duckdb
101108

102109
if (res != CURLE_OK)
103110
{
104-
LogMessage (LogLevel::ERROR, std::string (curl_easy_strerror (res)));
105-
LogMessage (LogLevel::CRITICAL, "Failed to download public suffix list. Check logs for details.");
111+
LogMessage (LogLevel::LOG_ERROR, std::string (curl_easy_strerror (res)));
112+
LogMessage (LogLevel::LOG_CRITICAL, "Failed to download public suffix list. Check logs for details.");
106113
}
107114

108115
return readBuffer;
@@ -118,14 +125,14 @@ namespace duckdb
118125

119126
if (table_exists->RowCount () == 0 || table_data->RowCount () <= 1 || force)
120127
{
121-
LogMessage (LogLevel::INFO, "Loading public suffix list...");
128+
LogMessage (LogLevel::LOG_INFO, "Loading public suffix list...");
122129
// Download the list
123130
auto list_data = DownloadPublicSuffixList ();
124131

125132
// Validate the downloaded data
126133
if (list_data.empty ())
127134
{
128-
LogMessage (LogLevel::CRITICAL, "Failed to download public suffix list: empty data received");
135+
LogMessage (LogLevel::LOG_CRITICAL, "Failed to download public suffix list: empty data received");
129136
}
130137

131138
// Count non-comment/non-empty lines for validation
@@ -143,11 +150,11 @@ namespace duckdb
143150

144151
if (valid_line_count <= 1)
145152
{
146-
LogMessage (LogLevel::ERROR, validation_line);
147-
LogMessage (LogLevel::CRITICAL, "Downloaded public suffix list contains no valid entries. Try again or run `SELECT update_suffixes();`.");
153+
LogMessage (LogLevel::LOG_ERROR, validation_line);
154+
LogMessage (LogLevel::LOG_CRITICAL, "Downloaded public suffix list contains no valid entries. Try again or run `SELECT update_suffixes();`.");
148155
}
149156

150-
LogMessage (LogLevel::INFO, "Downloaded public suffix list with " + std::to_string (valid_line_count) + " valid entries");
157+
LogMessage (LogLevel::LOG_INFO, "Downloaded public suffix list with " + std::to_string (valid_line_count) + " valid entries");
151158

152159
// Parse the list and insert into a table
153160
std::istringstream stream (list_data);

src/utils/utils.hpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@ namespace duckdb
1010
{
1111
namespace netquack
1212
{
13-
// Function to get a CURL handler
14-
CURL *CreateCurlHandler ();
13+
// Function to get a CURL handler with custom write callback
14+
CURL *CreateCurlHandler (curl_write_callback write_callback);
1515

16-
// Function to download a file from a URL
17-
size_t WriteCallback (void *contents, size_t size, size_t nmemb, void *userp);
16+
// Function to write data to a string (for HTTP responses)
17+
size_t WriteStringCallback (char *contents, size_t size, size_t nmemb, void *userp);
18+
19+
// Function to write data to a file (for file downloads)
20+
size_t WriteFileCallback (char *contents, size_t size, size_t nmemb, void *userp);
1821

1922
// Function to download the public suffix list
2023
std::string DownloadPublicSuffixList ();

test/sql/update_tranco.test_slow

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# name: test/sql/update_tranco.test_slow
2+
# description: test netquack extension update_tranco function
3+
# group: [netquack]
4+
5+
require netquack
6+
7+
query I
8+
SELECT update_tranco(true);
9+
----
10+
Tranco list updated

0 commit comments

Comments
 (0)