Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/include/s3fs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ class S3FileSystem : public HTTPFileSystem {
// Helper class to do s3 ListObjectV2 api call https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListObjectsV2.html
struct AWSListObjectV2 {
static string Request(const string &path, HTTPParams &http_params, S3AuthParams &s3_auth_params,
string &continuation_token, optional_idx max_keys = optional_idx());
string &continuation_token, bool use_delimiter = false, optional_idx max_keys = optional_idx());
static void ParseFileList(string &aws_response, vector<OpenFileInfo> &result);
static vector<string> ParseCommonPrefix(string &aws_response);
static string ParseContinuationToken(string &aws_response);
Expand Down
73 changes: 54 additions & 19 deletions src/s3fs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include "create_secret_functions.hpp"

#include <iostream>
#include <iostream>
#include <thread>
#ifdef EMSCRIPTEN
Expand Down Expand Up @@ -1193,19 +1194,25 @@ void S3FileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes, idx
}

static bool Match(vector<string>::const_iterator key, vector<string>::const_iterator key_end,
vector<string>::const_iterator pattern, vector<string>::const_iterator pattern_end) {
vector<string>::const_iterator pattern, vector<string>::const_iterator pattern_end, bool completed) {

if (key == key_end && !completed) {
return true;
}

while (key != key_end && pattern != pattern_end) {
if (*pattern == "**") {
if (std::next(pattern) == pattern_end) {
return true;
}
pattern ++;
while (key != key_end) {
if (Match(key, key_end, std::next(pattern), pattern_end)) {
if (Match(key, key_end, pattern, pattern_end, completed)) {
return true;
}
key++;
}
if (!completed) return true;
return false;
}
if (!Glob(key->data(), key->length(), pattern->data(), pattern->length())) {
Expand All @@ -1214,6 +1221,9 @@ static bool Match(vector<string>::const_iterator key, vector<string>::const_iter
key++;
pattern++;
}
if (pattern != pattern_end && !completed) {
return true;
}
return key == key_end && pattern == pattern_end;
}

Expand Down Expand Up @@ -1284,12 +1294,23 @@ bool S3GlobResult::ExpandNextPath() const {
// we have common prefixes left to scan - perform the request
auto prefix_path = parsed_s3_url.prefix + parsed_s3_url.bucket + '/' + current_common_prefix;

auto prefix_res =
AWSListObjectV2::Request(prefix_path, *http_params, s3_auth_params, common_prefix_continuation_token);
AWSListObjectV2::ParseFileList(prefix_res, s3_keys);
auto more_prefixes = AWSListObjectV2::ParseCommonPrefix(prefix_res);
common_prefixes.insert(common_prefixes.end(), more_prefixes.begin(), more_prefixes.end());
common_prefix_continuation_token = AWSListObjectV2::ParseContinuationToken(prefix_res);

vector<string> pattern_splits = StringUtil::Split(parsed_s3_url.key, "/");
vector<string> key_splits = StringUtil::Split(current_common_prefix, "/");
//pattern_splits.resize(key_splits.size());
const bool is_match = Match(key_splits.begin(), key_splits.end(), pattern_splits.begin(), pattern_splits.end(), false);
if (is_match) {
prefix_path = S3FileSystem::UrlDecode(prefix_path);
auto prefix_res = AWSListObjectV2::Request(prefix_path, *http_params, s3_auth_params,
common_prefix_continuation_token, true);

AWSListObjectV2::ParseFileList(prefix_res, s3_keys);
auto more_prefixes = AWSListObjectV2::ParseCommonPrefix(prefix_res);
common_prefixes.insert(common_prefixes.end(), more_prefixes.begin(), more_prefixes.end());

common_prefix_continuation_token = AWSListObjectV2::ParseContinuationToken(prefix_res);
}

if (common_prefix_continuation_token.empty()) {
// we are done with the current common prefix
// either move on to the next one, or finish up
Expand All @@ -1308,7 +1329,7 @@ bool S3GlobResult::ExpandNextPath() const {
}
// issue the main request
string response_str =
AWSListObjectV2::Request(shared_path, *http_params, s3_auth_params, main_continuation_token);
AWSListObjectV2::Request(shared_path, *http_params, s3_auth_params, main_continuation_token, true);
main_continuation_token = AWSListObjectV2::ParseContinuationToken(response_str);
AWSListObjectV2::ParseFileList(response_str, s3_keys);

Expand All @@ -1330,7 +1351,7 @@ bool S3GlobResult::ExpandNextPath() const {
for (auto &s3_key : s3_keys) {

vector<string> key_splits = StringUtil::Split(s3_key.path, "/");
bool is_match = Match(key_splits.begin(), key_splits.end(), pattern_splits.begin(), pattern_splits.end());
bool is_match = Match(key_splits.begin(), key_splits.end(), pattern_splits.begin(), pattern_splits.end(), true);

if (is_match) {
auto result_full_url = parsed_s3_url.prefix + parsed_s3_url.bucket + "/" + s3_key.path;
Expand Down Expand Up @@ -1510,31 +1531,45 @@ HTTPException S3FileSystem::GetHTTPError(FileHandle &handle, const HTTPResponse
}

string AWSListObjectV2::Request(const string &path, HTTPParams &http_params, S3AuthParams &s3_auth_params,
string &continuation_token, optional_idx max_keys) {
string &continuation_token, bool use_delimiter, optional_idx max_keys) {
const idx_t MAX_RETRIES = 1;
for (idx_t it = 0; it <= MAX_RETRIES; it++) {
auto parsed_url = S3FileSystem::S3UrlParse(path, s3_auth_params);

// Construct the ListObjectsV2 call
string req_path = parsed_url.path.substr(0, parsed_url.path.length() - parsed_url.key.length());

string req_params;
map<string,string> req_params;
// NOTE: req_params needs to be sorted before passing to sigv4 code
if (!continuation_token.empty()) {
req_params += "continuation-token=" + S3FileSystem::UrlEncode(continuation_token, true);
req_params += "&";
req_params["continuation-token"] = S3FileSystem::UrlEncode(continuation_token, true);
}

if (use_delimiter) {
req_params["delimiter"] ="%2F";
}
req_params += "encoding-type=url&list-type=2";
req_params += "&prefix=" + S3FileSystem::UrlEncode(parsed_url.key, true);

req_params["encoding-type"] = "url";
req_params["list-type"] = "2";
if (max_keys.IsValid()) {
req_params += "&max-keys=" + to_string(max_keys.GetIndex());
req_params["max-keys"] = to_string(max_keys.GetIndex());
}
req_params["prefix"] = S3FileSystem::UrlEncode(parsed_url.key, true);

string encoded_params = "";
for (const auto & p : req_params) {
encoded_params += p.first + "=" + p.second + "&";
}
if (!encoded_params.empty()) {
// Remove last '&'
encoded_params.pop_back();
}
auto header_map =
CreateS3Header(req_path, req_params, parsed_url.host, "s3", "GET", s3_auth_params, "", "", "", "");
CreateS3Header(req_path, encoded_params, parsed_url.host, "s3", "GET", s3_auth_params, "", "", "", "");

// Get requests use fresh connection
string full_host = parsed_url.http_proto + parsed_url.host;
string listobjectv2_url = req_path + "?" + req_params;
string listobjectv2_url = req_path + "?" + encoded_params;
std::stringstream response;
ErrorData error;
GetRequestInfo get_request(
Expand Down
14 changes: 7 additions & 7 deletions test/sql/copy/s3/starstar.test
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,14 @@ FROM GLOB('s3://test-bucket/glob_ss/*/*/t0.csv');
----
s3://test-bucket/glob_ss/a/b/t0.csv

query I
query I rowsort
FROM GLOB('s3://test-bucket/glob_ss/**/t0.csv');
----
s3://test-bucket/glob_ss/a/b/t0.csv
s3://test-bucket/glob_ss/a/t0.csv
s3://test-bucket/glob_ss/t0.csv

query I
query I rowsort
FROM GLOB('s3://test-bucket/glob_ss/**/*/t0.csv');
----
s3://test-bucket/glob_ss/a/b/t0.csv
Expand All @@ -143,27 +143,27 @@ query I
FROM GLOB('s3://test-bucket/glob_ss/*/*/*/t0.csv');
----

query I
query I rowsort
FROM GLOB('s3://test-bucket/glob_ss/**');
----
s3://test-bucket/glob_ss/a/b/t0.csv
s3://test-bucket/glob_ss/a/t0.csv
s3://test-bucket/glob_ss/t0.csv

query I
query I rowsort
FROM GLOB('s3://test-bucket/glob_ss/**/*');
----
s3://test-bucket/glob_ss/a/b/t0.csv
s3://test-bucket/glob_ss/a/t0.csv
s3://test-bucket/glob_ss/t0.csv

query I
query I rowsort
FROM GLOB('s3://test-bucket/glob_ss/*/**');
----
s3://test-bucket/glob_ss/a/b/t0.csv
s3://test-bucket/glob_ss/a/t0.csv

query I
query I rowsort
FROM GLOB('s3://test-bucket/glob_ss/a/**');
----
s3://test-bucket/glob_ss/a/b/t0.csv
Expand Down Expand Up @@ -221,7 +221,7 @@ SELECT COUNT(*) FROM GLOB('s3://test-bucket/glob_ss/partitioned/**/*');
----
10

query I
query I rowsort
FROM GLOB('s3://test-bucket/glob_ss/partitioned/**/*.parquet');
----
s3://test-bucket/glob_ss/partitioned/a=0/b=0/data_0.parquet
Expand Down
Loading