Skip to content

Commit c9cf0d9

Browse files
authored
Vendor downstream (PDAL#4739)
* Fix race condition connormanning/arbiter@5c3f36e * AssumeRoleWithWebIdentity connormanning/arbiter#55 * Return data even when return code is not 200 connormanning/arbiter#59
1 parent 3e03c25 commit c9cf0d9

File tree

2 files changed

+143
-32
lines changed

2 files changed

+143
-32
lines changed

vendor/arbiter/arbiter.cpp

Lines changed: 130 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -353,16 +353,13 @@ std::shared_ptr<Driver> Arbiter::getDriver(const std::string path) const
353353
{
354354
const auto type(getProtocol(path));
355355

356-
{
357-
std::lock_guard<std::mutex> lock(m_mutex);
358-
auto it = m_drivers.find(type);
359-
if (it != m_drivers.end()) return it->second;
360-
}
356+
std::lock_guard<std::mutex> lock(m_mutex);
357+
auto it = m_drivers.find(type);
358+
if (it != m_drivers.end()) return it->second;
361359

362360
const json config = getConfig(m_config);
363361
if (auto driver = Driver::create(*m_pool, type, config.dump()))
364362
{
365-
std::lock_guard<std::mutex> lock(m_mutex);
366363
m_drivers[type] = driver;
367364
return driver;
368365
}
@@ -1372,6 +1369,7 @@ LocalHandle::~LocalHandle()
13721369
#include <algorithm>
13731370
#include <cstring>
13741371
#include <iostream>
1372+
#include <sstream>
13751373

13761374
#ifdef ARBITER_CUSTOM_NAMESPACE
13771375
namespace ARBITER_CUSTOM_NAMESPACE
@@ -1470,7 +1468,12 @@ std::vector<char> Http::getBinary(
14701468
std::vector<char> data;
14711469
if (!get(path, data, headers, query))
14721470
{
1473-
throw ArbiterError("Could not read from " + path);
1471+
std::stringstream oss;
1472+
oss << "Could not read from '" << path << "'.";
1473+
1474+
if (data.size())
1475+
oss << " Response message returned '" << std::string(data.data()) << "'";
1476+
throw ArbiterError(oss.str());
14741477
}
14751478
return data;
14761479
}
@@ -1509,11 +1512,10 @@ bool Http::get(
15091512
auto http(m_pool.acquire());
15101513
Response res(http.get(typedPath(path), headers, query));
15111514

1515+
1516+
data = res.data();
15121517
if (res.ok())
1513-
{
1514-
data = res.data();
15151518
good = true;
1516-
}
15171519

15181520
return good;
15191521
}
@@ -1867,9 +1869,58 @@ std::unique_ptr<S3::Auth> S3::Auth::create(
18671869
drivers::Http httpDriver(pool);
18681870

18691871
// Nothing found in the environment or on the filesystem. However we may
1870-
// be running in an EC2 instance with an instance profile set up.
1872+
// be running in an EC2 instance with a service account or an instance profile set up.
18711873
try
18721874
{
1875+
// Try to "assume role with web identity" - see https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
1876+
try
1877+
{
1878+
const auto roleArn = env("AWS_ROLE_ARN");
1879+
const auto webIdentityTokenFile = env("AWS_WEB_IDENTITY_TOKEN_FILE");
1880+
std::unique_ptr<std::string> webIdentityToken;
1881+
1882+
if (roleArn && webIdentityTokenFile && (webIdentityToken = fsDriver.tryGet(*webIdentityTokenFile)))
1883+
{
1884+
// Decide on the STS root URL, defaults to https://sts.<AWS_REGION>.amazonaws.com
1885+
std::string stsRootUrl;
1886+
if (env("AWS_STS_ROOT_URL"))
1887+
{
1888+
stsRootUrl = *env("AWS_STS_ROOT_URL");
1889+
}
1890+
else
1891+
{
1892+
bool useRegionalEndpoint = env("AWS_STS_REGIONAL_ENDPOINTS")
1893+
? (*env("AWS_STS_REGIONAL_ENDPOINTS") == "regional")
1894+
: true;
1895+
if (useRegionalEndpoint)
1896+
{
1897+
stsRootUrl = "https://sts." + S3::Config::extractRegion(s, profile) + ".amazonaws.com";
1898+
}
1899+
else
1900+
{
1901+
stsRootUrl = "https://sts.amazonaws.com";
1902+
}
1903+
}
1904+
1905+
const std::string roleSessionName = env("AWS_ROLE_SESSION_NAME") ? *env("AWS_ROLE_SESSION_NAME") : "pdal";
1906+
1907+
const std::string stsAssumeRoleWithWebIdentityUrl = stsRootUrl
1908+
+ "/?Action=AssumeRoleWithWebIdentity&Version=2011-06-15"
1909+
+ "&RoleSessionName=" + roleSessionName
1910+
+ "&RoleArn=" + *roleArn
1911+
+ "&WebIdentityToken=" + *webIdentityToken;
1912+
1913+
const auto res = httpDriver.internalGet(stsAssumeRoleWithWebIdentityUrl);
1914+
if (!res.ok())
1915+
{
1916+
throw ArbiterError("Failed to assume role with web identity");
1917+
}
1918+
1919+
return makeUnique<Auth>(stsAssumeRoleWithWebIdentityUrl, ReauthMethod::ASSUME_ROLE_WITH_WEB_IDENTITY);
1920+
}
1921+
}
1922+
catch (...) { }
1923+
18731924
std::string token;
18741925

18751926
try
@@ -1911,8 +1962,8 @@ std::unique_ptr<S3::Auth> S3::Auth::create(
19111962

19121963
if (!iamRole.empty())
19131964
{
1914-
const bool imdsv2 = !token.empty();
1915-
return makeUnique<Auth>(ec2CredBase + "/" + iamRole, imdsv2);
1965+
const ReauthMethod reauthMethod = !token.empty() ? ReauthMethod::IMDS_V2 : ReauthMethod::IMDS_V1;
1966+
return makeUnique<Auth>(ec2CredBase + "/" + iamRole, reauthMethod);
19161967
}
19171968
}
19181969
catch (...) { }
@@ -1921,7 +1972,7 @@ std::unique_ptr<S3::Auth> S3::Auth::create(
19211972
// different IP.
19221973
if (const auto relUri = env("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"))
19231974
{
1924-
return makeUnique<Auth>(fargateCredIp + "/" + *relUri);
1975+
return makeUnique<Auth>(fargateCredIp + "/" + *relUri, ReauthMethod::IMDS_V2);
19251976
}
19261977
#endif
19271978

@@ -2038,15 +2089,15 @@ std::string S3::Config::extractBaseUrl(
20382089
for (const auto& partition : ep["partitions"])
20392090
{
20402091
if (
2041-
!partition.count("regions") ||
2092+
!partition.count("regions") ||
20422093
!partition.at("regions").count(region))
20432094
{
20442095
continue;
20452096
}
20462097

20472098
// Look for an explicit hostname for this region/service.
20482099
if (
2049-
partition.count("services") &&
2100+
partition.count("services") &&
20502101
partition["services"].count("s3") &&
20512102
partition["services"]["s3"].count("endpoints"))
20522103
{
@@ -2090,7 +2141,7 @@ S3::AuthFields S3::Auth::fields() const
20902141

20912142
std::string token;
20922143

2093-
if (m_imdsv2)
2144+
if (m_reauthMethod == ReauthMethod::IMDS_V2)
20942145
{
20952146
try
20962147
{
@@ -2116,16 +2167,67 @@ S3::AuthFields S3::Auth::fields() const
21162167
http::Headers headers;
21172168
if (!token.empty()) headers["X-aws-ec2-metadata-token"] = token;
21182169

2119-
const json creds = json::parse(
2120-
httpDriver.get(*m_credUrl, headers));
2170+
const auto res = httpDriver.internalGet(*m_credUrl, headers);
2171+
if (!res.ok())
2172+
{
2173+
throw ArbiterError("Failed to get token");
2174+
}
2175+
std::vector<char> data(res.data());
2176+
data.push_back('\0');
21212177

2122-
m_access = creds.at("AccessKeyId").get<std::string>();
2123-
m_hidden = creds.at("SecretAccessKey").get<std::string>();
2124-
m_token = creds.at("Token").get<std::string>();
2125-
m_expiration.reset(
2126-
new Time(
2127-
creds.at("Expiration").get<std::string>(),
2128-
arbiter::Time::iso8601));
2178+
if (m_reauthMethod == ReauthMethod::ASSUME_ROLE_WITH_WEB_IDENTITY)
2179+
{
2180+
// Parse XML response.
2181+
Xml::xml_document<> xml;
2182+
try
2183+
{
2184+
xml.parse<0>(data.data());
2185+
}
2186+
catch (Xml::parse_error&)
2187+
{
2188+
throw ArbiterError("Could not parse S3 response.");
2189+
}
2190+
bool parsed = false;
2191+
if (XmlNode* topNode = xml.first_node("AssumeRoleWithWebIdentityResponse"))
2192+
{
2193+
if (XmlNode* resultNode = topNode->first_node("AssumeRoleWithWebIdentityResult"))
2194+
{
2195+
if (XmlNode* credsNode = resultNode->first_node("Credentials"))
2196+
{
2197+
XmlNode* accessNode = credsNode->first_node("AccessKeyId");
2198+
XmlNode* hiddenNode = credsNode->first_node("SecretAccessKey");
2199+
XmlNode* tokenNode = credsNode->first_node("SessionToken");
2200+
XmlNode* expirationNode = credsNode->first_node("Expiration");
2201+
if (accessNode && hiddenNode && tokenNode && expirationNode)
2202+
{
2203+
m_access = accessNode->value();
2204+
m_hidden = hiddenNode->value();
2205+
m_token = tokenNode->value();
2206+
m_expiration.reset(new Time(expirationNode->value(), arbiter::Time::iso8601));
2207+
parsed = true;
2208+
}
2209+
}
2210+
}
2211+
}
2212+
if (!parsed)
2213+
{
2214+
throw ArbiterError("Could not parse S3 response.");
2215+
}
2216+
2217+
}
2218+
else
2219+
{
2220+
// Parse JSON response.
2221+
const json creds = json::parse(res.data());
2222+
2223+
m_access = creds.at("AccessKeyId").get<std::string>();
2224+
m_hidden = creds.at("SecretAccessKey").get<std::string>();
2225+
m_token = creds.at("Token").get<std::string>();
2226+
m_expiration.reset(
2227+
new Time(
2228+
creds.at("Expiration").get<std::string>(),
2229+
arbiter::Time::iso8601));
2230+
}
21292231

21302232
if (*m_expiration - now < reauthSeconds)
21312233
{
@@ -2205,9 +2307,10 @@ bool S3::get(
22052307
apiV4.query(),
22062308
size ? *size : 0));
22072309

2310+
data = res.data();
2311+
22082312
if (res.ok())
22092313
{
2210-
data = res.data();
22112314
return true;
22122315
}
22132316

vendor/arbiter/arbiter.hpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/// Arbiter amalgamated header (https://github.com/connormanning/arbiter).
22
/// It is intended to be used with #include "arbiter.hpp"
33

4-
// Git SHA: 4d61535996946414317c6617724c20a2bc4d9bbf
4+
// Git SHA: 9671686bd5cc1a5f2d56075f93d77d80c0c72a06
55

66
// //////////////////////////////////////////////////////////////////////
77
// Beginning of content of file: LICENSE
@@ -45,7 +45,6 @@ SOFTWARE.
4545
/// If defined, indicates that the source file is amalgamated
4646
/// to prevent private header inclusion.
4747
#define ARBITER_IS_AMALGAMATION
48-
#define ARBITER_CUSTOM_NAMESPACE pdal
4948

5049
// //////////////////////////////////////////////////////////////////////
5150
// Beginning of content of file: arbiter/third/xml/rapidxml.hpp
@@ -4225,6 +4224,12 @@ namespace arbiter
42254224
namespace drivers
42264225
{
42274226

4227+
enum class ReauthMethod {
4228+
ASSUME_ROLE_WITH_WEB_IDENTITY,
4229+
IMDS_V1,
4230+
IMDS_V2,
4231+
};
4232+
42284233
/** @brief Amazon %S3 driver. */
42294234
class S3 : public Http
42304235
{
@@ -4245,6 +4250,7 @@ class S3 : public Http
42454250
* - JSON configuration
42464251
* - Well-known files or their environment overrides, like
42474252
* `~/.aws/credentials` or the file at AWS_CREDENTIAL_FILE.
4253+
* - STS assume role with web identity.
42484254
* - EC2 instance profile.
42494255
*/
42504256
static std::unique_ptr<S3> create(
@@ -4316,9 +4322,9 @@ class S3::Auth
43164322
, m_token(token)
43174323
{ }
43184324

4319-
Auth(std::string credUrl, bool imdsv2 = true)
4325+
Auth(std::string credUrl, ReauthMethod reauthMethod)
43204326
: m_credUrl(internal::makeUnique<std::string>(credUrl))
4321-
, m_imdsv2(imdsv2)
4327+
, m_reauthMethod(reauthMethod)
43224328
{ }
43234329

43244330
static std::unique_ptr<Auth> create(std::string profile, std::string s);
@@ -4331,7 +4337,7 @@ class S3::Auth
43314337
mutable std::string m_token;
43324338

43334339
std::unique_ptr<std::string> m_credUrl;
4334-
bool m_imdsv2 = true;
4340+
ReauthMethod m_reauthMethod;
43354341
mutable std::unique_ptr<Time> m_expiration;
43364342
mutable std::mutex m_mutex;
43374343
};
@@ -4354,6 +4360,8 @@ class S3::Config
43544360
const std::string m_baseUrl;
43554361
http::Headers m_baseHeaders;
43564362
bool m_precheck;
4363+
4364+
friend class S3;
43574365
};
43584366

43594367

0 commit comments

Comments
 (0)