15
15
import io
16
16
import json
17
17
import logging
18
-
19
18
import pytest
20
19
from mock import Mock
21
-
22
20
from sagemaker .tensorflow import TensorFlow
23
21
from sagemaker .tensorflow .predictor import csv_serializer
24
22
from sagemaker .tensorflow .serving import Model , Predictor
@@ -167,12 +165,12 @@ def test_predictor_classify(sagemaker_session):
167
165
mock_response (json .dumps (CLASSIFY_RESPONSE ).encode ('utf-8' ), sagemaker_session )
168
166
result = predictor .classify (CLASSIFY_INPUT )
169
167
170
- assert_invoked (sagemaker_session ,
171
- EndpointName = 'endpoint' ,
172
- ContentType = JSON_CONTENT_TYPE ,
173
- Accept = JSON_CONTENT_TYPE ,
174
- CustomAttributes = 'tfs-method=classify' ,
175
- Body = json .dumps (CLASSIFY_INPUT ))
168
+ assert_invoked_with_body_dict (sagemaker_session ,
169
+ EndpointName = 'endpoint' ,
170
+ ContentType = JSON_CONTENT_TYPE ,
171
+ Accept = JSON_CONTENT_TYPE ,
172
+ CustomAttributes = 'tfs-method=classify' ,
173
+ Body = json .dumps (CLASSIFY_INPUT ))
176
174
177
175
assert CLASSIFY_RESPONSE == result
178
176
@@ -183,12 +181,12 @@ def test_predictor_regress(sagemaker_session):
183
181
mock_response (json .dumps (REGRESS_RESPONSE ).encode ('utf-8' ), sagemaker_session )
184
182
result = predictor .regress (REGRESS_INPUT )
185
183
186
- assert_invoked (sagemaker_session ,
187
- EndpointName = 'endpoint' ,
188
- ContentType = JSON_CONTENT_TYPE ,
189
- Accept = JSON_CONTENT_TYPE ,
190
- CustomAttributes = 'tfs-method=regress,tfs-model-name=model,tfs-model-version=123' ,
191
- Body = json .dumps (REGRESS_INPUT ))
184
+ assert_invoked_with_body_dict (sagemaker_session ,
185
+ EndpointName = 'endpoint' ,
186
+ ContentType = JSON_CONTENT_TYPE ,
187
+ Accept = JSON_CONTENT_TYPE ,
188
+ CustomAttributes = 'tfs-method=regress,tfs-model-name=model,tfs-model-version=123' ,
189
+ Body = json .dumps (REGRESS_INPUT ))
192
190
193
191
assert REGRESS_RESPONSE == result
194
192
@@ -208,12 +206,23 @@ def test_predictor_classify_bad_content_type():
208
206
209
207
210
208
def assert_invoked (sagemaker_session , ** kwargs ):
209
+ sagemaker_session .sagemaker_runtime_client .invoke_endpoint .assert_called_once_with (** kwargs )
210
+
211
+
212
+ def assert_invoked_with_body_dict (sagemaker_session , ** kwargs ):
211
213
call = sagemaker_session .sagemaker_runtime_client .invoke_endpoint .call_args
212
214
cargs , ckwargs = call
213
215
assert not cargs
214
216
assert len (kwargs ) == len (ckwargs )
215
217
for k in ckwargs :
216
- assert kwargs [k ] == ckwargs [k ]
218
+ if k != 'Body' :
219
+ assert kwargs [k ] == ckwargs [k ]
220
+ else :
221
+ actual_body = json .loads (ckwargs [k ])
222
+ expected_body = json .loads (kwargs [k ])
223
+ assert len (actual_body ) == len (expected_body )
224
+ for k2 in actual_body :
225
+ assert actual_body [k2 ] == expected_body [k2 ]
217
226
218
227
219
228
def mock_response (expected_response , sagemaker_session , content_type = JSON_CONTENT_TYPE ):
0 commit comments