Skip to content

Commit 9d50ff6

Browse files
authored
feat: Correctly support custom domain in private apis (#3750)
1 parent e2109cd commit 9d50ff6

File tree

30 files changed

+3441
-5
lines changed

30 files changed

+3441
-5
lines changed

samtranslator/internal/schema_source/aws_serverless_api.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,12 @@ class Route53(BaseModel):
154154
SetIdentifier: Optional[PassThroughProp] # TODO: add docs
155155
Region: Optional[PassThroughProp] # TODO: add docs
156156
SeparateRecordSetGroup: Optional[bool] # TODO: add docs
157+
VpcEndpointDomainName: Optional[PassThroughProp] # TODO: add docs
158+
VpcEndpointHostedZoneId: Optional[PassThroughProp] # TODO: add docs
159+
160+
161+
class AccessAssociation(BaseModel):
162+
VpcEndpointId: PassThroughProp # TODO: add docs
157163

158164

159165
class Domain(BaseModel):
@@ -185,6 +191,7 @@ class Domain(BaseModel):
185191
"SecurityPolicy",
186192
["AWS::ApiGateway::DomainName", "Properties", "SecurityPolicy"],
187193
)
194+
AccessAssociation: Optional[AccessAssociation]
188195

189196

190197
class DefinitionUri(BaseModel):
@@ -307,6 +314,7 @@ class Properties(BaseModel):
307314
OpenApiVersion: Optional[OpenApiVersion] = properties("OpenApiVersion")
308315
StageName: SamIntrinsicable[str] = properties("StageName")
309316
Tags: Optional[DictStrAny] = properties("Tags")
317+
Policy: Optional[PassThroughProp] # TODO: add docs
310318
PropagateTags: Optional[bool] # TODO: add docs
311319
TracingEnabled: Optional[TracingEnabled] = passthrough_prop(
312320
PROPERTIES_STEM,

samtranslator/model/api/api_generator.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
ApiGatewayBasePathMappingV2,
1414
ApiGatewayDeployment,
1515
ApiGatewayDomainName,
16+
ApiGatewayDomainNameAccessAssociation,
1617
ApiGatewayDomainNameV2,
1718
ApiGatewayResponse,
1819
ApiGatewayRestApi,
@@ -86,6 +87,7 @@ class ApiDomainResponseV2:
8687
domain: Optional[ApiGatewayDomainNameV2]
8788
apigw_basepath_mapping_list: Optional[List[ApiGatewayBasePathMappingV2]]
8889
recordset_group: Any
90+
domain_access_association: Any
8991

9092

9193
class SharedApiUsagePlan:
@@ -218,6 +220,7 @@ def __init__( # noqa: PLR0913
218220
api_key_source_type: Optional[Intrinsicable[str]] = None,
219221
always_deploy: Optional[bool] = False,
220222
feature_toggle: Optional[FeatureToggle] = None,
223+
policy: Optional[Union[Dict[str, Any], Intrinsicable[str]]] = None,
221224
):
222225
"""Constructs an API Generator class that generates API Gateway resources
223226
@@ -275,6 +278,7 @@ def __init__( # noqa: PLR0913
275278
self.api_key_source_type = api_key_source_type
276279
self.always_deploy = always_deploy
277280
self.feature_toggle = feature_toggle
281+
self.policy = policy
278282

279283
def _construct_rest_api(self) -> ApiGatewayRestApi:
280284
"""Constructs and returns the ApiGateway RestApi.
@@ -328,6 +332,9 @@ def _construct_rest_api(self) -> ApiGatewayRestApi:
328332
if self.api_key_source_type:
329333
rest_api.ApiKeySourceType = self.api_key_source_type
330334

335+
if self.policy:
336+
rest_api.Policy = self.policy
337+
331338
return rest_api
332339

333340
def _validate_properties(self) -> None:
@@ -602,7 +609,7 @@ def _construct_api_domain_v2(
602609
Constructs and returns the ApiGateway Domain V2 and BasepathMapping V2
603610
"""
604611
if self.domain is None:
605-
return ApiDomainResponseV2(None, None, None)
612+
return ApiDomainResponseV2(None, None, None, None)
606613

607614
sam_expect(self.domain, self.logical_id, "Domain").to_be_a_map()
608615
domain_name: PassThrough = sam_expect(
@@ -657,6 +664,14 @@ def _construct_api_domain_v2(
657664
basepath_mapping.BasePath = path if normalize_basepath else basepath
658665
basepath_resource_list.extend([basepath_mapping])
659666

667+
# Create the DomainNameAccessAssociation
668+
domain_access_association = self.domain.get("AccessAssociation")
669+
domain_access_association_resource = None
670+
if domain_access_association is not None:
671+
domain_access_association_resource = self._generate_domain_access_association(
672+
domain_access_association, domain_name_arn, api_domain_name
673+
)
674+
660675
# Create the Route53 RecordSetGroup resource
661676
record_set_group = None
662677
route53 = self.domain.get("Route53")
@@ -683,6 +698,7 @@ def _construct_api_domain_v2(
683698
domain,
684699
basepath_resource_list,
685700
self._construct_single_record_set_group(self.domain, domain_name, route53),
701+
domain_access_association_resource,
686702
)
687703

688704
if not record_set_group:
@@ -691,7 +707,7 @@ def _construct_api_domain_v2(
691707

692708
record_set_group.RecordSets += self._construct_record_sets_for_domain(self.domain, domain_name, route53)
693709

694-
return ApiDomainResponseV2(domain, basepath_resource_list, record_set_group)
710+
return ApiDomainResponseV2(domain, basepath_resource_list, record_set_group, domain_access_association_resource)
695711

696712
def _get_basepaths(self) -> Optional[List[str]]:
697713
if self.domain is None:
@@ -779,11 +795,14 @@ def _construct_alias_target(self, domain: Dict[str, Any], api_domain_name: str,
779795
if domain.get("EndpointConfiguration") == "REGIONAL":
780796
alias_target["HostedZoneId"] = fnGetAtt(api_domain_name, "RegionalHostedZoneId")
781797
alias_target["DNSName"] = fnGetAtt(api_domain_name, "RegionalDomainName")
782-
else:
798+
elif domain.get("EndpointConfiguration") == "EDGE":
783799
if route53.get("DistributionDomainName") is None:
784800
route53["DistributionDomainName"] = fnGetAtt(api_domain_name, "DistributionDomainName")
785801
alias_target["HostedZoneId"] = "Z2FDTNDATAQYW2"
786802
alias_target["DNSName"] = route53.get("DistributionDomainName")
803+
else:
804+
alias_target["HostedZoneId"] = route53.get("VpcEndpointHostedZoneId")
805+
alias_target["DNSName"] = route53.get("VpcEndpointDomainName")
787806
return alias_target
788807

789808
def _create_basepath_mapping(
@@ -833,12 +852,17 @@ def to_cloudformation(
833852
domain: Union[Resource, None]
834853
basepath_mapping: Union[List[ApiGatewayBasePathMapping], List[ApiGatewayBasePathMappingV2], None]
835854
rest_api = self._construct_rest_api()
855+
is_private_domain = isinstance(self.domain, dict) and self.domain.get("EndpointConfiguration") == "PRIVATE"
836856
api_domain_response = (
837857
self._construct_api_domain_v2(rest_api, route53_record_set_groups)
838-
if isinstance(self.domain, dict) and self.domain.get("EndpointConfiguration") == "PRIVATE"
858+
if is_private_domain
839859
else self._construct_api_domain(rest_api, route53_record_set_groups)
840860
)
841861

862+
domain_access_association = None
863+
if is_private_domain:
864+
domain_access_association = cast(ApiDomainResponseV2, api_domain_response).domain_access_association
865+
842866
domain = api_domain_response.domain
843867
basepath_mapping = api_domain_response.apigw_basepath_mapping_list
844868

@@ -882,6 +906,9 @@ def to_cloudformation(
882906
]
883907
)
884908

909+
if domain_access_association is not None:
910+
generated_resources.append(domain_access_association)
911+
885912
# Make a list of single resources
886913
generated_resources_list: List[Resource] = []
887914
for resource in generated_resources:
@@ -1513,3 +1540,24 @@ def _set_endpoint_configuration(self, rest_api: ApiGatewayRestApi, value: Union[
15131540
else:
15141541
rest_api.EndpointConfiguration = {"Types": [value]}
15151542
rest_api.Parameters = {"endpointConfigurationTypes": value}
1543+
1544+
def _generate_domain_access_association(
1545+
self,
1546+
domain_access_association: Dict[str, Any],
1547+
domain_name_arn: Dict[str, str],
1548+
domain_logical_id: str,
1549+
) -> ApiGatewayDomainNameAccessAssociation:
1550+
"""
1551+
Generate domain access association resource
1552+
"""
1553+
vpcEndpointId = domain_access_association.get("VpcEndpointId")
1554+
logical_id = LogicalIdGenerator("DomainNameAccessAssociation", [vpcEndpointId, domain_logical_id]).gen()
1555+
1556+
domain_access_association_resource = ApiGatewayDomainNameAccessAssociation(
1557+
logical_id, attributes=self.passthrough_resource_attributes
1558+
)
1559+
domain_access_association_resource.DomainNameArn = domain_name_arn
1560+
domain_access_association_resource.AccessAssociationSourceType = "VPCE"
1561+
domain_access_association_resource.AccessAssociationSource = vpcEndpointId
1562+
1563+
return domain_access_association_resource

samtranslator/model/apigateway.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class ApiGatewayRestApi(Resource):
2929
"Mode": GeneratedProperty(),
3030
"ApiKeySourceType": GeneratedProperty(),
3131
"Tags": GeneratedProperty(),
32+
"Policy": GeneratedProperty(),
3233
}
3334

3435
Body: Optional[Dict[str, Any]]
@@ -44,6 +45,7 @@ class ApiGatewayRestApi(Resource):
4445
Mode: Optional[PassThrough]
4546
ApiKeySourceType: Optional[PassThrough]
4647
Tags: Optional[PassThrough]
48+
Policy: Optional[PassThrough]
4749

4850
runtime_attrs = {"rest_api_id": lambda self: ref(self.logical_id)}
4951

@@ -307,6 +309,16 @@ class ApiGatewayApiKey(Resource):
307309
runtime_attrs = {"api_key_id": lambda self: ref(self.logical_id)}
308310

309311

312+
class ApiGatewayDomainNameAccessAssociation(Resource):
313+
resource_type = "AWS::ApiGateway::DomainNameAccessAssociation"
314+
property_types = {
315+
"AccessAssociationSource": GeneratedProperty(),
316+
"AccessAssociationSourceType": GeneratedProperty(),
317+
"DomainNameArn": GeneratedProperty(),
318+
"Tags": GeneratedProperty(),
319+
}
320+
321+
310322
class ApiGatewayAuthorizer:
311323
_VALID_FUNCTION_PAYLOAD_TYPES = [None, "TOKEN", "REQUEST"]
312324

samtranslator/model/sam_resources.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
""" SAM macro definitions """
1+
""" SAM macro definitions """
22

33
import copy
44
from contextlib import suppress
@@ -1275,6 +1275,7 @@ class SamApi(SamResourceMacro):
12751275
"DisableExecuteApiEndpoint": PropertyType(False, IS_BOOL),
12761276
"ApiKeySourceType": PropertyType(False, IS_STR),
12771277
"AlwaysDeploy": Property(False, IS_BOOL),
1278+
"Policy": PropertyType(False, one_of(IS_STR, IS_DICT)),
12781279
}
12791280

12801281
Name: Optional[Intrinsicable[str]]
@@ -1306,6 +1307,7 @@ class SamApi(SamResourceMacro):
13061307
DisableExecuteApiEndpoint: Optional[Intrinsicable[bool]]
13071308
ApiKeySourceType: Optional[Intrinsicable[str]]
13081309
AlwaysDeploy: Optional[bool]
1310+
Policy: Optional[Union[Dict[str, Any], Intrinsicable[str]]]
13091311

13101312
referable_properties = {
13111313
"Stage": ApiGatewayStage.resource_type,
@@ -1373,6 +1375,7 @@ def to_cloudformation(self, **kwargs) -> List[Resource]: # type: ignore[no-unty
13731375
api_key_source_type=self.ApiKeySourceType,
13741376
always_deploy=self.AlwaysDeploy,
13751377
feature_toggle=feature_toggle,
1378+
policy=self.Policy,
13761379
)
13771380

13781381
generated_resources = api_generator.to_cloudformation(redeploy_restapi_parameters, route53_record_set_groups)

samtranslator/schema/schema.json

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273504,6 +273504,19 @@
273504273504
],
273505273505
"type": "object"
273506273506
},
273507+
"AccessAssociation": {
273508+
"additionalProperties": false,
273509+
"properties": {
273510+
"VpcEndpointId": {
273511+
"$ref": "#/definitions/PassThroughProp"
273512+
}
273513+
},
273514+
"required": [
273515+
"VpcEndpointId"
273516+
],
273517+
"title": "AccessAssociation",
273518+
"type": "object"
273519+
},
273507273520
"Alexa::ASK::Skill": {
273508273521
"additionalProperties": false,
273509273522
"properties": {
@@ -277223,6 +277236,9 @@
277223277236
"samtranslator__internal__schema_source__aws_serverless_api__Domain": {
277224277237
"additionalProperties": false,
277225277238
"properties": {
277239+
"AccessAssociation": {
277240+
"$ref": "#/definitions/AccessAssociation"
277241+
},
277226277242
"BasePath": {
277227277243
"allOf": [
277228277244
{
@@ -277629,6 +277645,9 @@
277629277645
"markdownDescription": "Version of OpenApi to use\\. This can either be `2.0` for the Swagger specification, or one of the OpenApi 3\\.0 versions, like `3.0.1`\\. For more information about OpenAPI, see the [OpenAPI Specification](https://swagger.io/specification/)\\. \n AWS SAM creates a stage called `Stage` by default\\. Setting this property to any valid value will prevent the creation of the stage `Stage`\\. \n*Type*: String \n*Required*: No \n*AWS CloudFormation compatibility*: This property is unique to AWS SAM and doesn't have an AWS CloudFormation equivalent\\.",
277630277646
"title": "OpenApiVersion"
277631277647
},
277648+
"Policy": {
277649+
"$ref": "#/definitions/PassThroughProp"
277650+
},
277632277651
"PropagateTags": {
277633277652
"title": "Propagatetags",
277634277653
"type": "boolean"
@@ -277939,6 +277958,12 @@
277939277958
},
277940277959
"SetIdentifier": {
277941277960
"$ref": "#/definitions/PassThroughProp"
277961+
},
277962+
"VpcEndpointDomainName": {
277963+
"$ref": "#/definitions/PassThroughProp"
277964+
},
277965+
"VpcEndpointHostedZoneId": {
277966+
"$ref": "#/definitions/PassThroughProp"
277942277967
}
277943277968
},
277944277969
"title": "Route53",

schema_source/sam.schema.json

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,19 @@
11
{
22
"$schema": "http://json-schema.org/draft-04/schema#",
33
"definitions": {
4+
"AccessAssociation": {
5+
"additionalProperties": false,
6+
"properties": {
7+
"VpcEndpointId": {
8+
"$ref": "#/definitions/PassThroughProp"
9+
}
10+
},
11+
"required": [
12+
"VpcEndpointId"
13+
],
14+
"title": "AccessAssociation",
15+
"type": "object"
16+
},
417
"AlexaSkillEvent": {
518
"additionalProperties": false,
619
"properties": {
@@ -3573,6 +3586,9 @@
35733586
"samtranslator__internal__schema_source__aws_serverless_api__Domain": {
35743587
"additionalProperties": false,
35753588
"properties": {
3589+
"AccessAssociation": {
3590+
"$ref": "#/definitions/AccessAssociation"
3591+
},
35763592
"BasePath": {
35773593
"allOf": [
35783594
{
@@ -4292,6 +4308,9 @@
42924308
"markdownDescription": "Version of OpenApi to use\\. This can either be `2.0` for the Swagger specification, or one of the OpenApi 3\\.0 versions, like `3.0.1`\\. For more information about OpenAPI, see the [OpenAPI Specification](https://swagger.io/specification/)\\. \n AWS SAM creates a stage called `Stage` by default\\. Setting this property to any valid value will prevent the creation of the stage `Stage`\\. \n*Type*: String \n*Required*: No \n*AWS CloudFormation compatibility*: This property is unique to AWS SAM and doesn't have an AWS CloudFormation equivalent\\.",
42934309
"title": "OpenApiVersion"
42944310
},
4311+
"Policy": {
4312+
"$ref": "#/definitions/PassThroughProp"
4313+
},
42954314
"PropagateTags": {
42964315
"title": "Propagatetags",
42974316
"type": "boolean"
@@ -4672,6 +4691,12 @@
46724691
},
46734692
"SetIdentifier": {
46744693
"$ref": "#/definitions/PassThroughProp"
4694+
},
4695+
"VpcEndpointDomainName": {
4696+
"$ref": "#/definitions/PassThroughProp"
4697+
},
4698+
"VpcEndpointHostedZoneId": {
4699+
"$ref": "#/definitions/PassThroughProp"
46754700
}
46764701
},
46774702
"title": "Route53",

0 commit comments

Comments
 (0)