Skip to content

Commit 19fd6d8

Browse files
authored
refactor: Combine function/sfn Api event logics about adding Auth (#2755)
1 parent 36afb70 commit 19fd6d8

File tree

11 files changed

+172
-204
lines changed

11 files changed

+172
-204
lines changed

samtranslator/model/api/api_generator.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -753,7 +753,7 @@ def _add_auth(self) -> None:
753753
auth_properties.ResourcePolicy, "ResourcePolicy must be a map (ResourcePolicyStatement)."
754754
)
755755
for path in swagger_editor.iter_on_path():
756-
swagger_editor.add_resource_policy(auth_properties.ResourcePolicy, path, self.stage_name) # type: ignore[no-untyped-call]
756+
swagger_editor.add_resource_policy(auth_properties.ResourcePolicy, path, self.stage_name)
757757
if auth_properties.ResourcePolicy.get("CustomStatements"):
758758
swagger_editor.add_custom_statements(auth_properties.ResourcePolicy.get("CustomStatements")) # type: ignore[no-untyped-call]
759759

@@ -1120,14 +1120,10 @@ def _get_authorizers(self, authorizers_config, default_authorizer=None): # type
11201120
return authorizers
11211121
return None
11221122

1123-
if not isinstance(authorizers_config, dict):
1124-
raise InvalidResourceException(self.logical_id, "Authorizers must be a dictionary.")
1123+
sam_expect(authorizers_config, self.logical_id, "Auth.Authorizers").to_be_a_map()
11251124

11261125
for authorizer_name, authorizer in authorizers_config.items():
1127-
if not isinstance(authorizer, dict):
1128-
raise InvalidResourceException(
1129-
self.logical_id, "Authorizer %s must be a dictionary." % (authorizer_name)
1130-
)
1126+
sam_expect(authorizer, self.logical_id, f"Auth.Authorizers.{authorizer_name}").to_be_a_map()
11311127

11321128
authorizers[authorizer_name] = ApiGatewayAuthorizer( # type: ignore[no-untyped-call]
11331129
api_logical_id=self.logical_id,

samtranslator/model/api/http_api_generator.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -563,14 +563,10 @@ def _get_authorizers(
563563
if not authorizers_config:
564564
return authorizers
565565

566-
if not isinstance(authorizers_config, dict):
567-
raise InvalidResourceException(self.logical_id, "Authorizers must be a dictionary.")
566+
sam_expect(authorizers_config, self.logical_id, "Auth.Authorizers").to_be_a_map()
568567

569568
for authorizer_name, authorizer in authorizers_config.items():
570-
if not isinstance(authorizer, dict):
571-
raise InvalidResourceException(
572-
self.logical_id, "Authorizer %s must be a dictionary." % (authorizer_name)
573-
)
569+
sam_expect(authorizer, self.logical_id, f"Auth.Authorizers.{authorizer_name}").to_be_a_map()
574570

575571
if "OpenIdConnectUrl" in authorizer:
576572
raise InvalidResourceException(

samtranslator/model/eventsources/push.py

Lines changed: 121 additions & 91 deletions
Large diffs are not rendered by default.

samtranslator/model/stepfunctions/events.py

Lines changed: 25 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import Any, Dict, Optional
2+
from typing import Any, Dict, Optional, cast
33

44
from samtranslator.metrics.method_decorator import cw_timer
55
from samtranslator.model import Property, PropertyType, ResourceMacro, Resource
@@ -27,6 +27,7 @@ class EventSource(ResourceMacro):
2727
# line to avoid any potential behavior change.
2828
# TODO: Make `EventSource` an abstract class and not giving `principal` initial value.
2929
principal: str = None # type: ignore
30+
relative_id: str # overriding the Optional[str]: for event, relative id is not None
3031

3132
Target: Optional[Dict[str, str]]
3233

@@ -272,6 +273,11 @@ class Api(EventSource):
272273
"UnescapeMappingTemplate": Property(False, is_type(bool)),
273274
}
274275

276+
Path: str
277+
Method: str
278+
RestApiId: str
279+
Stage: Optional[str]
280+
Auth: Optional[Dict[str, Any]]
275281
UnescapeMappingTemplate: Optional[bool]
276282

277283
def resources_to_link(self, resources): # type: ignore[no-untyped-def]
@@ -289,7 +295,7 @@ def resources_to_link(self, resources): # type: ignore[no-untyped-def]
289295
permitted_stage = "*"
290296
stage_suffix = "AllStages"
291297
explicit_api = None
292-
rest_api_id = PushApi.get_rest_api_id_string(self.RestApiId) # type: ignore[attr-defined]
298+
rest_api_id = PushApi.get_rest_api_id_string(self.RestApiId)
293299
if isinstance(rest_api_id, str):
294300

295301
if (
@@ -314,7 +320,7 @@ def resources_to_link(self, resources): # type: ignore[no-untyped-def]
314320
"RestApiId property of Api event must reference a valid resource in the same template.",
315321
)
316322

317-
return {"explicit_api": explicit_api, "explicit_api_stage": {"suffix": stage_suffix}}
323+
return {"explicit_api": explicit_api, "api_id": rest_api_id, "explicit_api_stage": {"suffix": stage_suffix}}
318324

319325
@cw_timer(prefix=SFN_EVETSOURCE_METRIC_PREFIX)
320326
def to_cloudformation(self, resource, **kwargs): # type: ignore[no-untyped-def]
@@ -336,20 +342,21 @@ def to_cloudformation(self, resource, **kwargs): # type: ignore[no-untyped-def]
336342
intrinsics_resolver = kwargs.get("intrinsics_resolver")
337343
permissions_boundary = kwargs.get("permissions_boundary")
338344

339-
if self.Method is not None: # type: ignore[has-type]
345+
if self.Method is not None:
340346
# Convert to lower case so that user can specify either GET or get
341-
self.Method = self.Method.lower() # type: ignore[has-type]
347+
self.Method = self.Method.lower()
342348

343349
role = self._construct_role(resource, permissions_boundary) # type: ignore[no-untyped-call]
344350
resources.append(role)
345351

346352
explicit_api = kwargs["explicit_api"]
353+
api_id = kwargs["api_id"]
347354
if explicit_api.get("__MANAGE_SWAGGER"):
348-
self._add_swagger_integration(explicit_api, resource, role, intrinsics_resolver) # type: ignore[no-untyped-call]
355+
self._add_swagger_integration(explicit_api, api_id, resource, role, intrinsics_resolver) # type: ignore[no-untyped-call]
349356

350357
return resources
351358

352-
def _add_swagger_integration(self, api, resource, role, intrinsics_resolver): # type: ignore[no-untyped-def]
359+
def _add_swagger_integration(self, api, api_id, resource, role, intrinsics_resolver): # type: ignore[no-untyped-def]
353360
"""Adds the path and method for this Api event source to the Swagger body for the provided RestApi.
354361
355362
:param model.apigateway.ApiGatewayRestApi rest_api: the RestApi to which the path and method should be added.
@@ -362,12 +369,12 @@ def _add_swagger_integration(self, api, resource, role, intrinsics_resolver): #
362369

363370
editor = SwaggerEditor(swagger_body)
364371

365-
if editor.has_integration(self.Path, self.Method): # type: ignore[attr-defined]
372+
if editor.has_integration(self.Path, self.Method):
366373
# Cannot add the integration, if it is already present
367374
raise InvalidEventException(
368375
self.relative_id,
369376
'API method "{method}" defined multiple times for path "{path}".'.format(
370-
method=self.Method, path=self.Path # type: ignore[attr-defined]
377+
method=self.Method, path=self.Path
371378
),
372379
)
373380

@@ -382,78 +389,23 @@ def _add_swagger_integration(self, api, resource, role, intrinsics_resolver): #
382389
)
383390

384391
editor.add_state_machine_integration( # type: ignore[no-untyped-call]
385-
self.Path, # type: ignore[attr-defined]
392+
self.Path,
386393
self.Method,
387394
integration_uri,
388395
role.get_runtime_attr("arn"),
389396
request_template,
390397
condition=condition,
391398
)
392399

393-
# Note: Refactor and combine the section below with the Api eventsource for functions
394-
if self.Auth: # type: ignore[attr-defined]
395-
method_authorizer = self.Auth.get("Authorizer") # type: ignore[attr-defined]
396-
api_auth = api.get("Auth")
397-
api_auth = intrinsics_resolver.resolve_parameter_refs(api_auth)
398-
399-
if method_authorizer:
400-
api_authorizers = api_auth and api_auth.get("Authorizers")
401-
402-
if method_authorizer != "AWS_IAM":
403-
if method_authorizer != "NONE" and not api_authorizers:
404-
raise InvalidEventException(
405-
self.relative_id,
406-
"Unable to set Authorizer [{authorizer}] on API method [{method}] for path [{path}] "
407-
"because the related API does not define any Authorizers.".format(
408-
authorizer=method_authorizer, method=self.Method, path=self.Path # type: ignore[attr-defined]
409-
),
410-
)
411-
412-
if method_authorizer != "NONE" and not api_authorizers.get(method_authorizer):
413-
raise InvalidEventException(
414-
self.relative_id,
415-
"Unable to set Authorizer [{authorizer}] on API method [{method}] for path [{path}] "
416-
"because it wasn't defined in the API's Authorizers.".format(
417-
authorizer=method_authorizer, method=self.Method, path=self.Path # type: ignore[attr-defined]
418-
),
419-
)
420-
421-
if method_authorizer == "NONE":
422-
if not api_auth or not api_auth.get("DefaultAuthorizer"):
423-
raise InvalidEventException(
424-
self.relative_id,
425-
"Unable to set Authorizer on API method [{method}] for path [{path}] because 'NONE' "
426-
"is only a valid value when a DefaultAuthorizer on the API is specified.".format(
427-
method=self.Method, path=self.Path # type: ignore[attr-defined]
428-
),
429-
)
430-
431-
if self.Auth.get("AuthorizationScopes") and not isinstance(self.Auth.get("AuthorizationScopes"), list): # type: ignore[attr-defined]
432-
raise InvalidEventException(
433-
self.relative_id,
434-
"Unable to set Authorizer on API method [{method}] for path [{path}] because "
435-
"'AuthorizationScopes' must be a list of strings.".format(method=self.Method, path=self.Path), # type: ignore[attr-defined]
436-
)
400+
# self.Stage is not None as it is set in _get_permissions()
401+
# before calling this method.
402+
# TODO: refactor to remove this cast
403+
stage = cast(str, self.Stage)
437404

438-
apikey_required_setting = self.Auth.get("ApiKeyRequired") # type: ignore[attr-defined]
439-
apikey_required_setting_is_false = apikey_required_setting is not None and not apikey_required_setting
440-
if apikey_required_setting_is_false and (not api_auth or not api_auth.get("ApiKeyRequired")):
441-
raise InvalidEventException(
442-
self.relative_id,
443-
"Unable to set ApiKeyRequired [False] on API method [{method}] for path [{path}] "
444-
"because the related API does not specify any ApiKeyRequired.".format(
445-
method=self.Method, path=self.Path # type: ignore[attr-defined]
446-
),
447-
)
448-
449-
if method_authorizer or apikey_required_setting is not None:
450-
editor.add_auth_to_method(api=api, path=self.Path, method_name=self.Method, auth=self.Auth) # type: ignore[attr-defined, attr-defined, no-untyped-call]
451-
452-
if self.Auth.get("ResourcePolicy"): # type: ignore[attr-defined]
453-
resource_policy = self.Auth.get("ResourcePolicy") # type: ignore[attr-defined]
454-
editor.add_resource_policy(resource_policy=resource_policy, path=self.Path, stage=self.Stage) # type: ignore[attr-defined, attr-defined, no-untyped-call]
455-
if resource_policy.get("CustomStatements"):
456-
editor.add_custom_statements(resource_policy.get("CustomStatements")) # type: ignore[no-untyped-call]
405+
if self.Auth:
406+
PushApi.add_auth_to_swagger(
407+
self.Auth, api, api_id, self.relative_id, self.Method, self.Path, stage, editor, intrinsics_resolver
408+
)
457409

458410
api["DefinitionBody"] = editor.swagger
459411

samtranslator/swagger/swagger.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -666,7 +666,7 @@ def set_path_default_apikey_required(self, path: str) -> None:
666666
if security != existing_security:
667667
method_definition["security"] = security
668668

669-
def add_auth_to_method(self, path, method_name, auth, api): # type: ignore[no-untyped-def]
669+
def add_auth_to_method(self, path: str, method_name: str, auth: Dict[str, Any], api: Dict[str, Any]) -> None:
670670
"""
671671
Adds auth settings for this path/method. Auth settings currently consist of Authorizers and ApiKeyRequired
672672
but this method will eventually include setting other auth settings such as Resource Policy, etc.
@@ -872,7 +872,7 @@ def add_models(self, models): # type: ignore[no-untyped-def]
872872

873873
self.definitions[model_name.lower()] = schema
874874

875-
def add_resource_policy(self, resource_policy, path, stage): # type: ignore[no-untyped-def]
875+
def add_resource_policy(self, resource_policy: Optional[Dict[str, Any]], path: str, stage: PassThrough) -> None:
876876
"""
877877
Add resource policy definition to Swagger.
878878

samtranslator/validator/value_validator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""A plug-able validator to help raise exception when some value is unexpected."""
2-
from typing import Generic, Optional, TypeVar
2+
from typing import Any, Dict, Generic, Optional, TypeVar, cast
33

44
from samtranslator.model.exceptions import (
55
ExpectedType,
@@ -72,8 +72,8 @@ def to_not_be_none(self, message: Optional[str] = "") -> T:
7272
#
7373
# alias methods:
7474
#
75-
def to_be_a_map(self, message: Optional[str] = "") -> T:
76-
return self.to_be_a(ExpectedType.MAP, message)
75+
def to_be_a_map(self, message: Optional[str] = "") -> Dict[str, Any]:
76+
return cast(Dict[str, Any], self.to_be_a(ExpectedType.MAP, message))
7777

7878
def to_be_a_list(self, message: Optional[str] = "") -> T:
7979
return self.to_be_a(ExpectedType.LIST, message)

tests/model/stepfunctions/test_api_event.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ def setUp(self):
2222
self.state_machine.get_passthrough_resource_attributes.return_value = {}
2323

2424
def test_to_cloudformation_returns_role_resource(self):
25-
resources = self.api_event_source.to_cloudformation(resource=self.state_machine, explicit_api={})
25+
resources = self.api_event_source.to_cloudformation(
26+
resource=self.state_machine, explicit_api={}, api_id="MyRestApi"
27+
)
2628
self.assertEqual(len(resources), 1)
2729
self.assertEqual(resources[0].resource_type, "AWS::IAM::Role")
2830

@@ -63,7 +65,12 @@ def test_resources_to_link_with_explicit_api(self):
6365
self.api_event_source.RestApiId = {"Ref": "MyExplicitApi"}
6466
resources_to_link = self.api_event_source.resources_to_link(resources)
6567
self.assertEqual(
66-
resources_to_link, {"explicit_api": {"StageName": "Prod"}, "explicit_api_stage": {"suffix": "Prod"}}
68+
resources_to_link,
69+
{
70+
"explicit_api": {"StageName": "Prod"},
71+
"api_id": "MyExplicitApi",
72+
"explicit_api_stage": {"suffix": "Prod"},
73+
},
6774
)
6875

6976
def test_resources_to_link_with_undefined_explicit_api(self):
@@ -75,7 +82,9 @@ def test_resources_to_link_with_undefined_explicit_api(self):
7582
def test_resources_to_link_without_explicit_api(self):
7683
resources = {}
7784
resources_to_link = self.api_event_source.resources_to_link(resources)
78-
self.assertEqual(resources_to_link, {"explicit_api": None, "explicit_api_stage": {"suffix": "AllStages"}})
85+
self.assertEqual(
86+
resources_to_link, {"explicit_api": None, "api_id": None, "explicit_api_stage": {"suffix": "AllStages"}}
87+
)
7988

8089
def test_to_cloudformation_throws_when_no_resource(self):
8190
self.assertRaises(TypeError, self.api_event_source.to_cloudformation)

0 commit comments

Comments
 (0)