Skip to content

Commit 977716c

Browse files
doy-striperyan-lane
authored andcommitted
allow overriding endpoint_url (#11)
* allow overriding endpoint_url this allows support for using a kms reverse proxy * fix caching when specifying endpoint_url
1 parent ab4cc41 commit 977716c

File tree

2 files changed

+40
-12
lines changed

2 files changed

+40
-12
lines changed

kmsauth/__init__.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ def __init__(
3434
maximum_token_version=2,
3535
auth_token_max_lifetime=60,
3636
aws_creds=None,
37-
extra_context=None
37+
extra_context=None,
38+
endpoint_url=None
3839
):
3940
"""Create a KMSTokenValidator object.
4041
@@ -54,6 +55,8 @@ def __init__(
5455
aws_creds: A dict of AccessKeyId, SecretAccessKey, SessionToken.
5556
Useful if you wish to pass in assumed role credentials or MFA
5657
credentials. Default: None
58+
endpoint_url: A URL to override the default endpoint used to access
59+
the KMS service. Default: None
5760
"""
5861
self.auth_key = auth_key
5962
self.user_auth_key = user_auth_key
@@ -73,12 +76,14 @@ def __init__(
7376
region=self.region,
7477
aws_access_key_id=self.aws_creds['AccessKeyId'],
7578
aws_secret_access_key=self.aws_creds['SecretAccessKey'],
76-
aws_session_token=self.aws_creds['SessionToken']
79+
aws_session_token=self.aws_creds['SessionToken'],
80+
endpoint_url=endpoint_url
7781
)
7882
else:
7983
self.kms_client = kmsauth.services.get_boto_client(
8084
'kms',
81-
region=self.region
85+
region=self.region,
86+
endpoint_url=endpoint_url
8287
)
8388
if extra_context is None:
8489
self.extra_context = {}
@@ -308,7 +313,8 @@ def __init__(
308313
token_version=2,
309314
token_cache_file=None,
310315
token_lifetime=10,
311-
aws_creds=None
316+
aws_creds=None,
317+
endpoint_url=None
312318
):
313319
"""Create a KMSTokenGenerator object.
314320
@@ -326,6 +332,8 @@ def __init__(
326332
aws_creds: A dict of AccessKeyId, SecretAccessKey, SessionToken.
327333
Useful if you wish to pass in assumed role credentials or MFA
328334
credentials. Default: None
335+
endpoint_url: A URL to override the default endpoint used to access
336+
the KMS service. Default: None
329337
"""
330338
self.auth_key = auth_key
331339
if auth_context is None:
@@ -343,12 +351,14 @@ def __init__(
343351
region=self.region,
344352
aws_access_key_id=self.aws_creds['AccessKeyId'],
345353
aws_secret_access_key=self.aws_creds['SecretAccessKey'],
346-
aws_session_token=self.aws_creds['SessionToken']
354+
aws_session_token=self.aws_creds['SessionToken'],
355+
endpoint_url=endpoint_url
347356
)
348357
else:
349358
self.kms_client = kmsauth.services.get_boto_client(
350359
'kms',
351-
region=self.region
360+
region=self.region,
361+
endpoint_url=endpoint_url
352362
)
353363
self._validate()
354364

kmsauth/services.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,16 @@ def get_boto_client(
1212
region=None,
1313
aws_access_key_id=None,
1414
aws_secret_access_key=None,
15-
aws_session_token=None
15+
aws_session_token=None,
16+
endpoint_url=None
1617
):
1718
"""Get a boto3 client connection."""
18-
cache_key = '{0}:{1}:{2}'.format(client, region, aws_access_key_id)
19+
cache_key = '{0}:{1}:{2}:{3}'.format(
20+
client,
21+
region,
22+
aws_access_key_id,
23+
endpoint_url or ''
24+
)
1925
if not aws_session_token:
2026
if cache_key in CLIENT_CACHE:
2127
return CLIENT_CACHE[cache_key]
@@ -29,7 +35,10 @@ def get_boto_client(
2935
logging.error("Failed to get {0} client.".format(client))
3036
return None
3137

32-
CLIENT_CACHE[cache_key] = session.client(client)
38+
CLIENT_CACHE[cache_key] = session.client(
39+
client,
40+
endpoint_url=endpoint_url
41+
)
3342
return CLIENT_CACHE[cache_key]
3443

3544

@@ -38,10 +47,16 @@ def get_boto_resource(
3847
region=None,
3948
aws_access_key_id=None,
4049
aws_secret_access_key=None,
41-
aws_session_token=None
50+
aws_session_token=None,
51+
endpoint_url=None
4252
):
4353
"""Get a boto resource connection."""
44-
cache_key = '{0}:{1}:{2}'.format(resource, region, aws_access_key_id)
54+
cache_key = '{0}:{1}:{2}:{3}'.format(
55+
resource,
56+
region,
57+
aws_access_key_id,
58+
endpoint_url or ''
59+
)
4560
if not aws_session_token:
4661
if cache_key in RESOURCE_CACHE:
4762
return RESOURCE_CACHE[cache_key]
@@ -55,7 +70,10 @@ def get_boto_resource(
5570
logging.error("Failed to get {0} resource.".format(resource))
5671
return None
5772

58-
RESOURCE_CACHE[cache_key] = session.resource(resource)
73+
RESOURCE_CACHE[cache_key] = session.resource(
74+
resource,
75+
endpoint_url=endpoint_url
76+
)
5977
return RESOURCE_CACHE[cache_key]
6078

6179

0 commit comments

Comments
 (0)