Skip to content

Commit d1d209f

Browse files
jaymccontobywf
authored andcommitted
Report progress (#64)
1 parent 1767fa9 commit d1d209f

File tree

11 files changed

+347
-43
lines changed

11 files changed

+347
-43
lines changed
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import json
2+
import logging
3+
from typing import Optional
4+
from uuid import uuid4
5+
6+
# boto3 doesn't have stub files
7+
from boto3 import Session # type: ignore
8+
9+
from .interface import BaseResourceModel, HandlerErrorCode, OperationStatus
10+
from .utils import KitchenSinkEncoder
11+
12+
LOG = logging.getLogger(__name__)
13+
14+
15+
def report_progress( # pylint: disable=too-many-arguments
16+
session: Session,
17+
bearer_token: str,
18+
error_code: Optional[HandlerErrorCode],
19+
operation_status: OperationStatus,
20+
current_operation_status: Optional[OperationStatus],
21+
resource_model: Optional[BaseResourceModel],
22+
status_message: str,
23+
) -> None:
24+
client = session.client("cloudformation")
25+
request = {
26+
"BearerToken": bearer_token,
27+
"OperationStatus": operation_status.name,
28+
"StatusMessage": status_message,
29+
"ClientRequestToken": str(uuid4()),
30+
}
31+
if resource_model:
32+
request["ResourceModel"] = json.dumps(
33+
resource_model._serialize(), # pylint: disable=protected-access
34+
cls=KitchenSinkEncoder,
35+
)
36+
if error_code:
37+
request["ErrorCode"] = error_code.name
38+
if current_operation_status:
39+
request["CurrentOperationStatus"] = current_operation_status.name
40+
response = client.record_handler_progress(**request)
41+
LOG.info(
42+
"Record Handler Progress with Request Id %s and Request: {%s}",
43+
response["ResponseMetadata"]["RequestId"],
44+
request,
45+
)

src/cloudformation_cli_python_lib/interface.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class ProgressEvent:
7575
status: OperationStatus
7676
errorCode: Optional[HandlerErrorCode] = None
7777
message: str = ""
78-
callbackContext: Optional[Mapping[str, Any]] = None
78+
callbackContext: Optional[MutableMapping[str, Any]] = None
7979
callbackDelaySeconds: int = 0
8080
resourceModel: Optional[BaseResourceModel] = None
8181
resourceModels: Optional[List[BaseResourceModel]] = None
@@ -90,9 +90,21 @@ def _serialize(
9090
# mutate to what's expected in the response
9191
if to_response:
9292
ser["bearerToken"] = bearer_token
93-
ser["operationStatus"] = ser.pop("status")
94-
if ser["callbackDelaySeconds"] == 0:
95-
del ser["callbackDelaySeconds"]
93+
ser["operationStatus"] = ser.pop("status").name
94+
if self.resourceModel:
95+
# pylint: disable=protected-access
96+
ser["resourceModel"] = self.resourceModel._serialize()
97+
if self.resourceModels:
98+
ser["resourceModels"] = [
99+
# pylint: disable=protected-access
100+
model._serialize()
101+
for model in self.resourceModels
102+
]
103+
del ser["callbackDelaySeconds"]
104+
if "callbackContext" in ser:
105+
del ser["callbackContext"]
106+
if self.errorCode:
107+
ser["errorCode"] = self.errorCode.name
96108
return ser
97109

98110
@classmethod

src/cloudformation_cli_python_lib/log_delivery.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
import time
3-
from typing import Any, Mapping
3+
from typing import Any, Mapping, Optional
44

55
# boto3 doesn't have stub files
66
import boto3 # type: ignore
@@ -28,6 +28,13 @@ def __init__(
2828
self.client = boto3.client("logs", **creds)
2929
self.sequence_token = ""
3030

31+
@classmethod
32+
def _get_existing_logger(cls) -> Optional["ProviderLogHandler"]:
33+
for handler in logging.getLogger().handlers:
34+
if isinstance(handler, cls):
35+
return handler
36+
return None
37+
3138
@classmethod
3239
def setup(cls, event_data: Mapping[str, Any]) -> None:
3340
try:
@@ -46,12 +53,23 @@ def setup(cls, event_data: Mapping[str, Any]) -> None:
4653
except KeyError:
4754
stream_name = f'{event_data["awsAccountId"]}-{event_data["region"]}'
4855

49-
# filter provider messages from platform
50-
ProviderFilter.PROVIDER = event_data["resourceType"].replace("::", "_").lower()
51-
logging.getLogger().handlers[0].addFilter(ProviderFilter())
52-
53-
# add log handler to root, so that provider gets plugin logs too
56+
log_handler = cls._get_existing_logger()
5457
if log_creds and log_group:
58+
if log_handler:
59+
# This is a re-used lambda container, log handler is already setup, so
60+
# we just refresh the client with new creds
61+
log_handler.client = boto3.client(
62+
"logs",
63+
aws_access_key_id=log_creds["accessKeyId"],
64+
aws_secret_access_key=log_creds["secretAccessKey"],
65+
aws_session_token=log_creds["sessionToken"],
66+
)
67+
return
68+
# filter provider messages from platform
69+
ProviderFilter.PROVIDER = (
70+
event_data["resourceType"].replace("::", "_").lower()
71+
)
72+
logging.getLogger().handlers[0].addFilter(ProviderFilter())
5573
log_handler = cls(
5674
group=log_group,
5775
stream=stream_name,
@@ -61,6 +79,7 @@ def setup(cls, event_data: Mapping[str, Any]) -> None:
6179
"aws_session_token": log_creds["sessionToken"],
6280
},
6381
)
82+
# add log handler to root, so that provider gets plugin logs too
6483
logging.getLogger().addHandler(log_handler)
6584

6685
def _create_log_group(self) -> None:

src/cloudformation_cli_python_lib/resource.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import boto3 # type: ignore
99

1010
from .boto3_proxy import SessionProxy, _get_boto_session
11+
from .callback import report_progress
1112
from .exceptions import InternalFailure, InvalidRequest, _HandlerError
1213
from .interface import (
1314
Action,
@@ -211,9 +212,40 @@ def __call__(
211212
ProviderLogHandler.setup(event_data)
212213
parsed = self._parse_request(event_data)
213214
caller_sess, platform_sess, request, action, callback, event = parsed
215+
# Acknowledge the task for first time invocation
216+
if not event.requestContext:
217+
report_progress(
218+
platform_sess,
219+
event.bearerToken,
220+
None,
221+
OperationStatus.IN_PROGRESS,
222+
OperationStatus.PENDING,
223+
None,
224+
"",
225+
)
226+
else:
227+
# If this invocation was triggered by a 're-invoke' CloudWatch Event,
228+
# clean it up
229+
CloudWatchScheduler(platform_sess).cleanup_cloudwatch_events(
230+
event.requestContext.get("cloudWatchEventsRuleName", ""),
231+
event.requestContext.get("cloudWatchEventsTargetId", ""),
232+
)
214233
invoke = True
215234
while invoke:
216235
progress = self._invoke_handler(caller_sess, request, action, callback)
236+
if progress.callbackContext:
237+
callback = progress.callbackContext
238+
event.requestContext["callbackContext"] = callback
239+
if event.action in MUTATING_ACTIONS:
240+
report_progress(
241+
platform_sess,
242+
event.bearerToken,
243+
progress.errorCode,
244+
progress.status,
245+
OperationStatus.IN_PROGRESS,
246+
progress.resourceModel,
247+
progress.message,
248+
)
217249
invoke = self.schedule_reinvocation(
218250
event, progress, context, platform_sess
219251
)

src/cloudformation_cli_python_lib/scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,4 @@ def _min_to_cron(minutes: int) -> str:
5555
schedule_time = datetime.now() + timedelta(minutes=minutes)
5656
# add another minute, as per java implementation
5757
schedule_time = schedule_time + timedelta(minutes=1)
58-
return schedule_time.strftime("cron('%M %H %d %m ? %Y')")
58+
return schedule_time.strftime("cron(%M %H %d %m ? %Y)")

src/cloudformation_cli_python_lib/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class RequestData:
4343
logicalResourceId: str
4444
resourceProperties: Mapping[str, Any]
4545
systemTags: Mapping[str, Any]
46-
stackTags: Mapping[str, Any]
46+
stackTags: Optional[Mapping[str, Any]] = None
4747
callerCredentials: Optional[Credentials] = field(default=None)
4848
providerCredentials: Optional[Credentials] = field(default=None)
4949
previousResourceProperties: Optional[Mapping[str, Any]] = None

tests/lib/callback_test.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# pylint: disable=redefined-outer-name,protected-access
2+
from unittest.mock import Mock, patch
3+
from uuid import uuid4
4+
5+
import boto3
6+
from cloudformation_cli_python_lib.callback import report_progress
7+
from cloudformation_cli_python_lib.interface import (
8+
BaseResourceModel,
9+
HandlerErrorCode,
10+
OperationStatus,
11+
)
12+
13+
14+
class MockSession:
15+
def __init__(self):
16+
self._cfn = Mock(boto3.client("cloudformation"), autospec=True)
17+
self._cfn.record_handler_progress.return_value = {
18+
"ResponseMetadata": {"RequestId": "mock_request"}
19+
}
20+
21+
def client(self, _name):
22+
return self._cfn
23+
24+
25+
def test_report_progress_minimal():
26+
session = MockSession()
27+
uuid = uuid4()
28+
with patch("cloudformation_cli_python_lib.callback.uuid4", return_value=uuid):
29+
report_progress(
30+
session, "123", None, OperationStatus.IN_PROGRESS, None, None, ""
31+
)
32+
session._cfn.record_handler_progress.assert_called_once_with(
33+
BearerToken="123",
34+
OperationStatus="IN_PROGRESS",
35+
StatusMessage="",
36+
ClientRequestToken=str(uuid),
37+
)
38+
39+
40+
def test_report_progress_full():
41+
session = MockSession()
42+
uuid = uuid4()
43+
with patch("cloudformation_cli_python_lib.callback.uuid4", return_value=uuid):
44+
report_progress(
45+
session,
46+
"123",
47+
HandlerErrorCode.InternalFailure,
48+
OperationStatus.FAILED,
49+
OperationStatus.IN_PROGRESS,
50+
BaseResourceModel(),
51+
"test message",
52+
)
53+
session._cfn.record_handler_progress.assert_called_once_with(
54+
BearerToken="123",
55+
OperationStatus="FAILED",
56+
CurrentOperationStatus="IN_PROGRESS",
57+
StatusMessage="test message",
58+
ResourceModel="{}",
59+
ErrorCode="InternalFailure",
60+
ClientRequestToken=str(uuid),
61+
)

tests/lib/interface_test.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
# pylint: disable=protected-access,redefined-outer-name
1+
# pylint: disable=protected-access,redefined-outer-name,abstract-method
22
import json
3+
from dataclasses import dataclass
34
from string import ascii_letters
45

56
import boto3
@@ -26,6 +27,12 @@ def client():
2627
)
2728

2829

30+
@dataclass
31+
class TestModel(BaseResourceModel):
32+
somekey: str
33+
someotherkey: str
34+
35+
2936
def test_base_resource_model__deserialize():
3037
with pytest.raises(NotImplementedError):
3138
BaseResourceModel()._deserialize({})
@@ -52,16 +59,48 @@ def test_progress_event_failed_is_json_serializable(error_code, message):
5259

5360

5461
@given(s.text(ascii_letters), s.text(ascii_letters))
55-
def test_progress_event_serialize_to_response(message, bearer_token):
62+
def test_progress_event_serialize_to_response_with_context(message, bearer_token):
63+
event = ProgressEvent(
64+
status=OperationStatus.SUCCESS, message=message, callbackContext={"a": "b"}
65+
)
66+
67+
assert event._serialize(to_response=True, bearer_token=bearer_token) == {
68+
"operationStatus": OperationStatus.SUCCESS.name, # pylint: disable=no-member
69+
"message": message,
70+
"bearerToken": bearer_token,
71+
}
72+
73+
74+
@given(s.text(ascii_letters), s.text(ascii_letters))
75+
def test_progress_event_serialize_to_response_with_model(message, bearer_token):
76+
model = TestModel("a", "b")
77+
event = ProgressEvent(
78+
status=OperationStatus.SUCCESS, message=message, resourceModel=model
79+
)
80+
81+
assert event._serialize(to_response=True, bearer_token=bearer_token) == {
82+
"operationStatus": OperationStatus.SUCCESS.name, # pylint: disable=no-member
83+
"message": message,
84+
"bearerToken": bearer_token,
85+
"resourceModel": {"somekey": "a", "someotherkey": "b"},
86+
}
87+
88+
89+
@given(s.text(ascii_letters), s.text(ascii_letters))
90+
def test_progress_event_serialize_to_response_with_models(message, bearer_token):
91+
models = [TestModel("a", "b"), TestModel("c", "d")]
5692
event = ProgressEvent(
57-
status=OperationStatus.SUCCESS, message=message, callbackDelaySeconds=1
93+
status=OperationStatus.SUCCESS, message=message, resourceModels=models
5894
)
5995

6096
assert event._serialize(to_response=True, bearer_token=bearer_token) == {
61-
"operationStatus": OperationStatus.SUCCESS.value,
97+
"operationStatus": OperationStatus.SUCCESS.name, # pylint: disable=no-member
6298
"message": message,
6399
"bearerToken": bearer_token,
64-
"callbackDelaySeconds": 1,
100+
"resourceModels": [
101+
{"somekey": "a", "someotherkey": "b"},
102+
{"somekey": "c", "someotherkey": "d"},
103+
],
65104
}
66105

67106

0 commit comments

Comments
 (0)