1818import google .protobuf .json_format as json_format
1919from google .protobuf .message import DecodeError
2020from protobuf_to_dict import protobuf_to_dict
21- from tensorflow .core .framework import tensor_pb2 # pylint: disable=no-name-in-module
22- from tensorflow .python .framework import tensor_util # pylint: disable=no-name-in-module
2321
2422from sagemaker .content_types import CONTENT_TYPE_JSON , CONTENT_TYPE_OCTET_STREAM , CONTENT_TYPE_CSV
2523from sagemaker .predictor import json_serializer , csv_serializer
26- from tensorflow_serving .apis import predict_pb2 , classification_pb2 , inference_pb2 , regression_pb2
2724
28- _POSSIBLE_RESPONSES = [
29- predict_pb2 .PredictResponse ,
30- classification_pb2 .ClassificationResponse ,
31- inference_pb2 .MultiInferenceResponse ,
32- regression_pb2 .RegressionResponse ,
33- tensor_pb2 .TensorProto ,
34- ]
25+
26+ def _possible_responses ():
27+ """
28+ Returns: Possible available request types.
29+ """
30+ from tensorflow .core .framework import tensor_pb2 # pylint: disable=no-name-in-module
31+ from tensorflow_serving .apis import (
32+ predict_pb2 ,
33+ classification_pb2 ,
34+ inference_pb2 ,
35+ regression_pb2 ,
36+ )
37+
38+ return [
39+ predict_pb2 .PredictResponse ,
40+ classification_pb2 .ClassificationResponse ,
41+ inference_pb2 .MultiInferenceResponse ,
42+ regression_pb2 .RegressionResponse ,
43+ tensor_pb2 .TensorProto ,
44+ ]
45+
3546
3647REGRESSION_REQUEST = "RegressionRequest"
3748MULTI_INFERENCE_REQUEST = "MultiInferenceRequest"
@@ -88,7 +99,7 @@ def __call__(self, stream, content_type):
8899 finally :
89100 stream .close ()
90101
91- for possible_response in _POSSIBLE_RESPONSES :
102+ for possible_response in _possible_responses () :
92103 try :
93104 response = possible_response ()
94105 response .ParseFromString (data )
@@ -114,6 +125,9 @@ def __call__(self, data):
114125 Args:
115126 data:
116127 """
128+
129+ from tensorflow .core .framework import tensor_pb2 # pylint: disable=no-name-in-module
130+
117131 if isinstance (data , tensor_pb2 .TensorProto ):
118132 return json_format .MessageToJson (data )
119133 return json_serializer (data )
@@ -139,7 +153,7 @@ def __call__(self, stream, content_type):
139153 finally :
140154 stream .close ()
141155
142- for possible_response in _POSSIBLE_RESPONSES :
156+ for possible_response in _possible_responses () :
143157 try :
144158 return protobuf_to_dict (json_format .Parse (data , possible_response ()))
145159 except (UnicodeDecodeError , DecodeError , json_format .ParseError ):
@@ -164,6 +178,10 @@ def __call__(self, data):
164178 data:
165179 """
166180 to_serialize = data
181+
182+ from tensorflow .core .framework import tensor_pb2 # pylint: disable=no-name-in-module
183+ from tensorflow .python .framework import tensor_util # pylint: disable=no-name-in-module
184+
167185 if isinstance (data , tensor_pb2 .TensorProto ):
168186 to_serialize = tensor_util .MakeNdarray (data )
169187 return csv_serializer (to_serialize )
0 commit comments