Skip to content

Commit aedebb9

Browse files
jesterhazyJonathan Esterhazy
authored andcommitted
fix flaky tfs test
1 parent dd2529a commit aedebb9

File tree

1 file changed

+24
-15
lines changed

1 file changed

+24
-15
lines changed

tests/unit/test_tfs.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,8 @@
1515
import io
1616
import json
1717
import logging
18-
1918
import pytest
2019
from mock import Mock
21-
2220
from sagemaker.tensorflow import TensorFlow
2321
from sagemaker.tensorflow.predictor import csv_serializer
2422
from sagemaker.tensorflow.serving import Model, Predictor
@@ -167,12 +165,12 @@ def test_predictor_classify(sagemaker_session):
167165
mock_response(json.dumps(CLASSIFY_RESPONSE).encode('utf-8'), sagemaker_session)
168166
result = predictor.classify(CLASSIFY_INPUT)
169167

170-
assert_invoked(sagemaker_session,
171-
EndpointName='endpoint',
172-
ContentType=JSON_CONTENT_TYPE,
173-
Accept=JSON_CONTENT_TYPE,
174-
CustomAttributes='tfs-method=classify',
175-
Body=json.dumps(CLASSIFY_INPUT))
168+
assert_invoked_with_body_dict(sagemaker_session,
169+
EndpointName='endpoint',
170+
ContentType=JSON_CONTENT_TYPE,
171+
Accept=JSON_CONTENT_TYPE,
172+
CustomAttributes='tfs-method=classify',
173+
Body=json.dumps(CLASSIFY_INPUT))
176174

177175
assert CLASSIFY_RESPONSE == result
178176

@@ -183,12 +181,12 @@ def test_predictor_regress(sagemaker_session):
183181
mock_response(json.dumps(REGRESS_RESPONSE).encode('utf-8'), sagemaker_session)
184182
result = predictor.regress(REGRESS_INPUT)
185183

186-
assert_invoked(sagemaker_session,
187-
EndpointName='endpoint',
188-
ContentType=JSON_CONTENT_TYPE,
189-
Accept=JSON_CONTENT_TYPE,
190-
CustomAttributes='tfs-method=regress,tfs-model-name=model,tfs-model-version=123',
191-
Body=json.dumps(REGRESS_INPUT))
184+
assert_invoked_with_body_dict(sagemaker_session,
185+
EndpointName='endpoint',
186+
ContentType=JSON_CONTENT_TYPE,
187+
Accept=JSON_CONTENT_TYPE,
188+
CustomAttributes='tfs-method=regress,tfs-model-name=model,tfs-model-version=123',
189+
Body=json.dumps(REGRESS_INPUT))
192190

193191
assert REGRESS_RESPONSE == result
194192

@@ -208,12 +206,23 @@ def test_predictor_classify_bad_content_type():
208206

209207

210208
def assert_invoked(sagemaker_session, **kwargs):
209+
sagemaker_session.sagemaker_runtime_client.invoke_endpoint.assert_called_once_with(**kwargs)
210+
211+
212+
def assert_invoked_with_body_dict(sagemaker_session, **kwargs):
211213
call = sagemaker_session.sagemaker_runtime_client.invoke_endpoint.call_args
212214
cargs, ckwargs = call
213215
assert not cargs
214216
assert len(kwargs) == len(ckwargs)
215217
for k in ckwargs:
216-
assert kwargs[k] == ckwargs[k]
218+
if k != 'Body':
219+
assert kwargs[k] == ckwargs[k]
220+
else:
221+
actual_body = json.loads(ckwargs[k])
222+
expected_body = json.loads(kwargs[k])
223+
assert len(actual_body) == len(expected_body)
224+
for k2 in actual_body:
225+
assert actual_body[k2] == expected_body[k2]
217226

218227

219228
def mock_response(expected_response, sagemaker_session, content_type=JSON_CONTENT_TYPE):

0 commit comments

Comments
 (0)