2121import json
2222
2323import numpy as np
24+ from six import with_metaclass
2425
2526from sagemaker .utils import DeferredError
2627
@@ -55,17 +56,44 @@ def ACCEPT(self):
5556 """The content types that are expected from the inference endpoint."""
5657
5758
58- class StringDeserializer (BaseDeserializer ):
59- """Deserialize data from an inference endpoint into a decoded string."""
59+ class SimpleBaseDeserializer (with_metaclass (abc .ABCMeta , BaseDeserializer )):
60+ """Abstract base class for creation of new deserializers.
61+
62+ This class extends the API of :class:~`sagemaker.deserializers.BaseDeserializer` with more
63+ user-friendly options for setting the ACCEPT content type header, in situations where it can be
64+ provided at init and freely updated.
65+ """
66+
67+ def __init__ (self , accept = "*/*" ):
68+ """Initialize a ``SimpleBaseDeserializer`` instance.
69+
70+ Args:
71+ accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
72+ is expected from the inference endpoint (default: "*/*").
73+ """
74+ super (SimpleBaseDeserializer , self ).__init__ ()
75+ self .accept = accept
76+
77+ @property
78+ def ACCEPT (self ):
79+ """The tuple of possible content types that are expected from the inference endpoint."""
80+ if isinstance (self .accept , str ):
81+ return (self .accept ,)
82+ return self .accept
6083
61- ACCEPT = ("application/json" ,)
6284
63- def __init__ (self , encoding = "UTF-8" ):
64- """Initialize the string encoding.
85+ class StringDeserializer (SimpleBaseDeserializer ):
86+ """Deserialize data from an inference endpoint into a decoded string."""
87+
88+ def __init__ (self , encoding = "UTF-8" , accept = "application/json" ):
89+ """Initialize a ``StringDeserializer`` instance.
6590
6691 Args:
6792 encoding (str): The string encoding to use (default: UTF-8).
93+ accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
94+ is expected from the inference endpoint (default: "application/json").
6895 """
96+ super (StringDeserializer , self ).__init__ (accept = accept )
6997 self .encoding = encoding
7098
7199 def deserialize (self , stream , content_type ):
@@ -84,11 +112,9 @@ def deserialize(self, stream, content_type):
84112 stream .close ()
85113
86114
87- class BytesDeserializer (BaseDeserializer ):
115+ class BytesDeserializer (SimpleBaseDeserializer ):
88116 """Deserialize a stream of bytes into a bytes object."""
89117
90- ACCEPT = ("*/*" ,)
91-
92118 def deserialize (self , stream , content_type ):
93119 """Read a stream of bytes returned from an inference endpoint.
94120
@@ -105,17 +131,23 @@ def deserialize(self, stream, content_type):
105131 stream .close ()
106132
107133
108- class CSVDeserializer (BaseDeserializer ):
109- """Deserialize a stream of bytes into a list of lists."""
134+ class CSVDeserializer (SimpleBaseDeserializer ):
135+ """Deserialize a stream of bytes into a list of lists.
110136
111- ACCEPT = ("text/csv" ,)
137+ Consider using :class:~`sagemaker.deserializers.NumpyDeserializer` or
138+ :class:~`sagemaker.deserializers.PandasDeserializer` instead, if you'd like to convert text/csv
139+ responses directly into other data types.
140+ """
112141
113- def __init__ (self , encoding = "utf-8" ):
114- """Initialize the string encoding .
142+ def __init__ (self , encoding = "utf-8" , accept = "text/csv" ):
143+ """Initialize a ``CSVDeserializer`` instance .
115144
116145 Args:
117146 encoding (str): The string encoding to use (default: "utf-8").
147+ accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
148+ is expected from the inference endpoint (default: "text/csv").
118149 """
150+ super (CSVDeserializer , self ).__init__ (accept = accept )
119151 self .encoding = encoding
120152
121153 def deserialize (self , stream , content_type ):
@@ -136,15 +168,13 @@ def deserialize(self, stream, content_type):
136168 stream .close ()
137169
138170
139- class StreamDeserializer (BaseDeserializer ):
140- """Returns the data and content-type received from an inference endpoint.
171+ class StreamDeserializer (SimpleBaseDeserializer ):
172+ """Directly return the data and content-type received from an inference endpoint.
141173
142174 It is the user's responsibility to close the data stream once they're done
143175 reading it.
144176 """
145177
146- ACCEPT = ("*/*" ,)
147-
148178 def deserialize (self , stream , content_type ):
149179 """Returns a stream of the response body and the MIME type of the data.
150180
@@ -158,20 +188,20 @@ def deserialize(self, stream, content_type):
158188 return stream , content_type
159189
160190
161- class NumpyDeserializer (BaseDeserializer ):
162- """Deserialize a stream of data in the .npy format."""
191+ class NumpyDeserializer (SimpleBaseDeserializer ):
192+ """Deserialize a stream of data in .npy or UTF-8 CSV/JSON format to a numpy array ."""
163193
164194 def __init__ (self , dtype = None , accept = "application/x-npy" , allow_pickle = True ):
165- """Initialize the dtype and allow_pickle arguments .
195+ """Initialize a ``NumpyDeserializer`` instance .
166196
167197 Args:
168198 dtype (str): The dtype of the data (default: None).
169- accept (str): The MIME type that is expected from the inference
170- endpoint (default: "application/x-npy").
199+ accept (union[ str, tuple[str]] ): The MIME type (or tuple of allowable MIME types) that
200+ is expected from the inference endpoint (default: "application/x-npy").
171201 allow_pickle (bool): Allow loading pickled object arrays (default: True).
172202 """
203+ super (NumpyDeserializer , self ).__init__ (accept = accept )
173204 self .dtype = dtype
174- self .accept = accept
175205 self .allow_pickle = allow_pickle
176206
177207 def deserialize (self , stream , content_type ):
@@ -198,21 +228,18 @@ def deserialize(self, stream, content_type):
198228
199229 raise ValueError ("%s cannot read content type %s." % (__class__ .__name__ , content_type ))
200230
201- @property
202- def ACCEPT (self ):
203- """The content types that are expected from the inference endpoint.
204-
205- To maintain backwards compatability with legacy images, the
206- NumpyDeserializer supports sending only one content type in the Accept
207- header.
208- """
209- return (self .accept ,)
210-
211231
212- class JSONDeserializer (BaseDeserializer ):
232+ class JSONDeserializer (SimpleBaseDeserializer ):
213233 """Deserialize JSON data from an inference endpoint into a Python object."""
214234
215- ACCEPT = ("application/json" ,)
235+ def __init__ (self , accept = "application/json" ):
236+ """Initialize a ``JSONDeserializer`` instance.
237+
238+ Args:
239+ accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
240+ is expected from the inference endpoint (default: "application/json").
241+ """
242+ super (JSONDeserializer , self ).__init__ (accept = accept )
216243
217244 def deserialize (self , stream , content_type ):
218245 """Deserialize JSON data from an inference endpoint into a Python object.
@@ -230,10 +257,17 @@ def deserialize(self, stream, content_type):
230257 stream .close ()
231258
232259
233- class PandasDeserializer (BaseDeserializer ):
260+ class PandasDeserializer (SimpleBaseDeserializer ):
234261 """Deserialize CSV or JSON data from an inference endpoint into a pandas dataframe."""
235262
236- ACCEPT = ("text/csv" , "application/json" )
263+ def __init__ (self , accept = ("text/csv" , "application/json" )):
264+ """Initialize a ``PandasDeserializer`` instance.
265+
266+ Args:
267+ accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
268+ is expected from the inference endpoint (default: ("text/csv","application/json")).
269+ """
270+ super (PandasDeserializer , self ).__init__ (accept = accept )
237271
238272 def deserialize (self , stream , content_type ):
239273 """Deserialize CSV or JSON data from an inference endpoint into a pandas
@@ -258,10 +292,17 @@ def deserialize(self, stream, content_type):
258292 raise ValueError ("%s cannot read content type %s." % (__class__ .__name__ , content_type ))
259293
260294
261- class JSONLinesDeserializer (BaseDeserializer ):
295+ class JSONLinesDeserializer (SimpleBaseDeserializer ):
262296 """Deserialize JSON lines data from an inference endpoint."""
263297
264- ACCEPT = ("application/jsonlines" ,)
298+ def __init__ (self , accept = "application/jsonlines" ):
299+ """Initialize a ``JSONLinesDeserializer`` instance.
300+
301+ Args:
302+ accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
303+ is expected from the inference endpoint (default: ("text/csv","application/json")).
304+ """
305+ super (JSONLinesDeserializer , self ).__init__ (accept = accept )
265306
266307 def deserialize (self , stream , content_type ):
267308 """Deserialize JSON lines data from an inference endpoint.
0 commit comments