@@ -353,16 +353,13 @@ std::shared_ptr<Driver> Arbiter::getDriver(const std::string path) const
353
353
{
354
354
const auto type (getProtocol (path));
355
355
356
- {
357
- std::lock_guard<std::mutex> lock (m_mutex);
358
- auto it = m_drivers.find (type);
359
- if (it != m_drivers.end ()) return it->second ;
360
- }
356
+ std::lock_guard<std::mutex> lock (m_mutex);
357
+ auto it = m_drivers.find (type);
358
+ if (it != m_drivers.end ()) return it->second ;
361
359
362
360
const json config = getConfig (m_config);
363
361
if (auto driver = Driver::create (*m_pool, type, config.dump ()))
364
362
{
365
- std::lock_guard<std::mutex> lock (m_mutex);
366
363
m_drivers[type] = driver;
367
364
return driver;
368
365
}
@@ -1372,6 +1369,7 @@ LocalHandle::~LocalHandle()
1372
1369
#include < algorithm>
1373
1370
#include < cstring>
1374
1371
#include < iostream>
1372
+ #include < sstream>
1375
1373
1376
1374
#ifdef ARBITER_CUSTOM_NAMESPACE
1377
1375
namespace ARBITER_CUSTOM_NAMESPACE
@@ -1470,7 +1468,12 @@ std::vector<char> Http::getBinary(
1470
1468
std::vector<char > data;
1471
1469
if (!get (path, data, headers, query))
1472
1470
{
1473
- throw ArbiterError (" Could not read from " + path);
1471
+ std::stringstream oss;
1472
+ oss << " Could not read from '" << path << " '." ;
1473
+
1474
+ if (data.size ())
1475
+ oss << " Response message returned '" << std::string (data.data ()) << " '" ;
1476
+ throw ArbiterError (oss.str ());
1474
1477
}
1475
1478
return data;
1476
1479
}
@@ -1509,11 +1512,10 @@ bool Http::get(
1509
1512
auto http (m_pool.acquire ());
1510
1513
Response res (http.get (typedPath (path), headers, query));
1511
1514
1515
+
1516
+ data = res.data ();
1512
1517
if (res.ok ())
1513
- {
1514
- data = res.data ();
1515
1518
good = true ;
1516
- }
1517
1519
1518
1520
return good;
1519
1521
}
@@ -1867,9 +1869,58 @@ std::unique_ptr<S3::Auth> S3::Auth::create(
1867
1869
drivers::Http httpDriver (pool);
1868
1870
1869
1871
// Nothing found in the environment or on the filesystem. However we may
1870
- // be running in an EC2 instance with an instance profile set up.
1872
+ // be running in an EC2 instance with a service account or an instance profile set up.
1871
1873
try
1872
1874
{
1875
+ // Try to "assume role with web identity" - see https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
1876
+ try
1877
+ {
1878
+ const auto roleArn = env (" AWS_ROLE_ARN" );
1879
+ const auto webIdentityTokenFile = env (" AWS_WEB_IDENTITY_TOKEN_FILE" );
1880
+ std::unique_ptr<std::string> webIdentityToken;
1881
+
1882
+ if (roleArn && webIdentityTokenFile && (webIdentityToken = fsDriver.tryGet (*webIdentityTokenFile)))
1883
+ {
1884
+ // Decide on the STS root URL, defaults to https://sts.<AWS_REGION>.amazonaws.com
1885
+ std::string stsRootUrl;
1886
+ if (env (" AWS_STS_ROOT_URL" ))
1887
+ {
1888
+ stsRootUrl = *env (" AWS_STS_ROOT_URL" );
1889
+ }
1890
+ else
1891
+ {
1892
+ bool useRegionalEndpoint = env (" AWS_STS_REGIONAL_ENDPOINTS" )
1893
+ ? (*env (" AWS_STS_REGIONAL_ENDPOINTS" ) == " regional" )
1894
+ : true ;
1895
+ if (useRegionalEndpoint)
1896
+ {
1897
+ stsRootUrl = " https://sts." + S3::Config::extractRegion (s, profile) + " .amazonaws.com" ;
1898
+ }
1899
+ else
1900
+ {
1901
+ stsRootUrl = " https://sts.amazonaws.com" ;
1902
+ }
1903
+ }
1904
+
1905
+ const std::string roleSessionName = env (" AWS_ROLE_SESSION_NAME" ) ? *env (" AWS_ROLE_SESSION_NAME" ) : " pdal" ;
1906
+
1907
+ const std::string stsAssumeRoleWithWebIdentityUrl = stsRootUrl
1908
+ + " /?Action=AssumeRoleWithWebIdentity&Version=2011-06-15"
1909
+ + " &RoleSessionName=" + roleSessionName
1910
+ + " &RoleArn=" + *roleArn
1911
+ + " &WebIdentityToken=" + *webIdentityToken;
1912
+
1913
+ const auto res = httpDriver.internalGet (stsAssumeRoleWithWebIdentityUrl);
1914
+ if (!res.ok ())
1915
+ {
1916
+ throw ArbiterError (" Failed to assume role with web identity" );
1917
+ }
1918
+
1919
+ return makeUnique<Auth>(stsAssumeRoleWithWebIdentityUrl, ReauthMethod::ASSUME_ROLE_WITH_WEB_IDENTITY);
1920
+ }
1921
+ }
1922
+ catch (...) { }
1923
+
1873
1924
std::string token;
1874
1925
1875
1926
try
@@ -1911,8 +1962,8 @@ std::unique_ptr<S3::Auth> S3::Auth::create(
1911
1962
1912
1963
if (!iamRole.empty ())
1913
1964
{
1914
- const bool imdsv2 = !token.empty ();
1915
- return makeUnique<Auth>(ec2CredBase + " /" + iamRole, imdsv2 );
1965
+ const ReauthMethod reauthMethod = !token.empty () ? ReauthMethod::IMDS_V2 : ReauthMethod::IMDS_V1 ;
1966
+ return makeUnique<Auth>(ec2CredBase + " /" + iamRole, reauthMethod );
1916
1967
}
1917
1968
}
1918
1969
catch (...) { }
@@ -1921,7 +1972,7 @@ std::unique_ptr<S3::Auth> S3::Auth::create(
1921
1972
// different IP.
1922
1973
if (const auto relUri = env (" AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" ))
1923
1974
{
1924
- return makeUnique<Auth>(fargateCredIp + " /" + *relUri);
1975
+ return makeUnique<Auth>(fargateCredIp + " /" + *relUri, ReauthMethod::IMDS_V2 );
1925
1976
}
1926
1977
#endif
1927
1978
@@ -2038,15 +2089,15 @@ std::string S3::Config::extractBaseUrl(
2038
2089
for (const auto & partition : ep[" partitions" ])
2039
2090
{
2040
2091
if (
2041
- !partition.count (" regions" ) ||
2092
+ !partition.count (" regions" ) ||
2042
2093
!partition.at (" regions" ).count (region))
2043
2094
{
2044
2095
continue ;
2045
2096
}
2046
2097
2047
2098
// Look for an explicit hostname for this region/service.
2048
2099
if (
2049
- partition.count (" services" ) &&
2100
+ partition.count (" services" ) &&
2050
2101
partition[" services" ].count (" s3" ) &&
2051
2102
partition[" services" ][" s3" ].count (" endpoints" ))
2052
2103
{
@@ -2090,7 +2141,7 @@ S3::AuthFields S3::Auth::fields() const
2090
2141
2091
2142
std::string token;
2092
2143
2093
- if (m_imdsv2 )
2144
+ if (m_reauthMethod == ReauthMethod::IMDS_V2 )
2094
2145
{
2095
2146
try
2096
2147
{
@@ -2116,16 +2167,67 @@ S3::AuthFields S3::Auth::fields() const
2116
2167
http::Headers headers;
2117
2168
if (!token.empty ()) headers[" X-aws-ec2-metadata-token" ] = token;
2118
2169
2119
- const json creds = json::parse (
2120
- httpDriver.get (*m_credUrl, headers));
2170
+ const auto res = httpDriver.internalGet (*m_credUrl, headers);
2171
+ if (!res.ok ())
2172
+ {
2173
+ throw ArbiterError (" Failed to get token" );
2174
+ }
2175
+ std::vector<char > data (res.data ());
2176
+ data.push_back (' \0 ' );
2121
2177
2122
- m_access = creds.at (" AccessKeyId" ).get <std::string>();
2123
- m_hidden = creds.at (" SecretAccessKey" ).get <std::string>();
2124
- m_token = creds.at (" Token" ).get <std::string>();
2125
- m_expiration.reset (
2126
- new Time (
2127
- creds.at (" Expiration" ).get <std::string>(),
2128
- arbiter::Time::iso8601));
2178
+ if (m_reauthMethod == ReauthMethod::ASSUME_ROLE_WITH_WEB_IDENTITY)
2179
+ {
2180
+ // Parse XML response.
2181
+ Xml::xml_document<> xml;
2182
+ try
2183
+ {
2184
+ xml.parse <0 >(data.data ());
2185
+ }
2186
+ catch (Xml::parse_error&)
2187
+ {
2188
+ throw ArbiterError (" Could not parse S3 response." );
2189
+ }
2190
+ bool parsed = false ;
2191
+ if (XmlNode* topNode = xml.first_node (" AssumeRoleWithWebIdentityResponse" ))
2192
+ {
2193
+ if (XmlNode* resultNode = topNode->first_node (" AssumeRoleWithWebIdentityResult" ))
2194
+ {
2195
+ if (XmlNode* credsNode = resultNode->first_node (" Credentials" ))
2196
+ {
2197
+ XmlNode* accessNode = credsNode->first_node (" AccessKeyId" );
2198
+ XmlNode* hiddenNode = credsNode->first_node (" SecretAccessKey" );
2199
+ XmlNode* tokenNode = credsNode->first_node (" SessionToken" );
2200
+ XmlNode* expirationNode = credsNode->first_node (" Expiration" );
2201
+ if (accessNode && hiddenNode && tokenNode && expirationNode)
2202
+ {
2203
+ m_access = accessNode->value ();
2204
+ m_hidden = hiddenNode->value ();
2205
+ m_token = tokenNode->value ();
2206
+ m_expiration.reset (new Time (expirationNode->value (), arbiter::Time::iso8601));
2207
+ parsed = true ;
2208
+ }
2209
+ }
2210
+ }
2211
+ }
2212
+ if (!parsed)
2213
+ {
2214
+ throw ArbiterError (" Could not parse S3 response." );
2215
+ }
2216
+
2217
+ }
2218
+ else
2219
+ {
2220
+ // Parse JSON response.
2221
+ const json creds = json::parse (res.data ());
2222
+
2223
+ m_access = creds.at (" AccessKeyId" ).get <std::string>();
2224
+ m_hidden = creds.at (" SecretAccessKey" ).get <std::string>();
2225
+ m_token = creds.at (" Token" ).get <std::string>();
2226
+ m_expiration.reset (
2227
+ new Time (
2228
+ creds.at (" Expiration" ).get <std::string>(),
2229
+ arbiter::Time::iso8601));
2230
+ }
2129
2231
2130
2232
if (*m_expiration - now < reauthSeconds)
2131
2233
{
@@ -2205,9 +2307,10 @@ bool S3::get(
2205
2307
apiV4.query (),
2206
2308
size ? *size : 0 ));
2207
2309
2310
+ data = res.data ();
2311
+
2208
2312
if (res.ok ())
2209
2313
{
2210
- data = res.data ();
2211
2314
return true ;
2212
2315
}
2213
2316
0 commit comments