refer to TensorFlow's `Save and Restore <https://www.tensorflow.org/guide/saved_model>`_ documentation for other ways to control the
inference-time behavior of your SavedModels.

Providing Python scripts for pre/post-processing
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

You can add your customized Python code to process your input and output data:

.. code::

    from sagemaker.tensorflow.serving import Model

    model = Model(entry_point='inference.py',
                  model_data='s3://mybucket/model.tar.gz',
                  role='MySageMakerRole')

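Once the ``Model`` object is created, it can be deployed and invoked like any other
SageMaker model. The following is a minimal sketch; the instance type and the sample
payload are placeholders rather than values required by the container:

.. code::

    # Deploy the model behind a real-time endpoint (instance type is an example).
    predictor = model.deploy(initial_instance_count=1, instance_type='ml.c5.xlarge')

    # The request body is passed through your input_handler (or handler) before it
    # reaches TensorFlow Serving, and the response through output_handler.
    result = predictor.predict({'instances': [1.0, 2.0, 5.0]})
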
How to implement the pre- and/or post-processing handler(s)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Your entry point file should implement either a pair of ``input_handler``
and ``output_handler`` functions or a single ``handler`` function.
Note that if a ``handler`` function is implemented, ``input_handler``
and ``output_handler`` are ignored.

To implement pre- and/or post-processing handler(s), use the Context
object that the Python service creates. The Context object is a namedtuple with the following attributes:

- ``model_name (string)``: the name of the model to use for
  inference. For example, 'half-plus-three'

- ``model_version (string)``: version of the model. For example, '5'

- ``method (string)``: inference method. For example, 'predict',
  'classify', or 'regress'. For more information on methods, see the
  `Classify and Regress
  API <https://www.tensorflow.org/tfx/serving/api_rest#classify_and_regress_api>`__
  and the `Predict
  API <https://www.tensorflow.org/tfx/serving/api_rest#predict_api>`__

- ``rest_uri (string)``: the TFS REST URI generated by the Python
  service. For example,
  'http://localhost:8501/v1/models/half_plus_three:predict'

- ``grpc_uri (string)``: the GRPC port number generated by the Python
  service. For example, '9000'

- ``custom_attributes (string)``: content of the
  'X-Amzn-SageMaker-Custom-Attributes' header from the original
  request. For example,
  'tfs-model-name=half_plus_three,tfs-method=predict'

- ``request_content_type (string)``: the original request content type,
  defaulted to 'application/json' if not provided

- ``accept_header (string)``: the original request accept type,
  defaulted to 'application/json' if not provided

- ``content_length (int)``: content length of the original request

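A handler can branch on these attributes before calling TensorFlow Serving. The
fragment below is purely illustrative (it is not part of the examples that follow)
and simply rejects any method other than 'predict':

.. code::

    def input_handler(data, context):
        # Illustrative only: allow the TFS 'predict' method and forward the
        # request body unchanged; other methods are rejected early.
        if context.method != 'predict':
            raise ValueError('{{"error": "unsupported method {}"}}'.format(context.method))
        return data.read().decode('utf-8')
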
The following code example implements ``input_handler`` and
``output_handler``. By providing these, the Python service posts the
request to the TFS REST URI with the data pre-processed by ``input_handler``
and passes the response to ``output_handler`` for post-processing.

.. code::

    import json


    def input_handler(data, context):
        """Pre-process request input before it is sent to TensorFlow Serving REST API.

        Args:
            data (obj): the request data, in format of dict or string
            context (Context): an object containing request and configuration details

        Returns:
            (str): a JSON string containing the request body to send to TensorFlow Serving
        """
        if context.request_content_type == 'application/json':
            # pass through json (assumes it's correctly formed)
            d = data.read().decode('utf-8')
            return d if len(d) else ''

        if context.request_content_type == 'text/csv':
            # very simple csv handler
            return json.dumps({
                'instances': [float(x) for x in data.read().decode('utf-8').split(',')]
            })

        raise ValueError('{{"error": "unsupported content type {}"}}'.format(
            context.request_content_type or "unknown"))


    def output_handler(data, context):
        """Post-process TensorFlow Serving output before it is returned to the client.

        Args:
            data (obj): the TensorFlow Serving response
            context (Context): an object containing request and configuration details

        Returns:
            (bytes, string): data to return to client, response content type
        """
        if data.status_code != 200:
            raise ValueError(data.content.decode('utf-8'))

        response_content_type = context.accept_header
        prediction = data.content
        return prediction, response_content_type

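With these handlers in place, a client can send a CSV payload and receive a JSON
prediction. A sketch using ``boto3`` against an already-deployed endpoint (the
endpoint name and payload below are hypothetical):

.. code::

    import boto3

    runtime = boto3.client('sagemaker-runtime')

    response = runtime.invoke_endpoint(
        EndpointName='my-tfs-endpoint',   # placeholder endpoint name
        ContentType='text/csv',           # routed to the text/csv branch of input_handler
        Accept='application/json',        # becomes context.accept_header in output_handler
        Body='1.0,2.0,5.0')

    print(response['Body'].read().decode('utf-8'))
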
You might want to have complete control over the request.
For example, you might want to make a TFS request (REST or GRPC) to the first model,
inspect the results, and then make a request to a second model. In this case, implement
the ``handler`` method instead of the ``input_handler`` and ``output_handler`` methods, as demonstrated
in the following code:

.. code::

    import json
    import requests


    def handler(data, context):
        """Handle request.

        Args:
            data (obj): the request data
            context (Context): an object containing request and configuration details

        Returns:
            (bytes, string): data to return to client, (optional) response content type
        """
        processed_input = _process_input(data, context)
        response = requests.post(context.rest_uri, data=processed_input)
        return _process_output(response, context)


    def _process_input(data, context):
        if context.request_content_type == 'application/json':
            # pass through json (assumes it's correctly formed)
            d = data.read().decode('utf-8')
            return d if len(d) else ''

        if context.request_content_type == 'text/csv':
            # very simple csv handler
            return json.dumps({
                'instances': [float(x) for x in data.read().decode('utf-8').split(',')]
            })

        raise ValueError('{{"error": "unsupported content type {}"}}'.format(
            context.request_content_type or "unknown"))


    def _process_output(data, context):
        if data.status_code != 200:
            raise ValueError(data.content.decode('utf-8'))

        response_content_type = context.accept_header
        prediction = data.content
        return prediction, response_content_type

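As a rough sketch of the two-model scenario described above, a ``handler`` could chain
requests as follows. This assumes both models are served by the same TFS instance
(see the multi-model section below), reuses ``_process_input`` from the previous example,
and uses a made-up second model name, 'ranker':

.. code::

    import json
    import requests


    def handler(data, context):
        # First request goes to the model this context was configured with.
        first_response = requests.post(context.rest_uri, data=_process_input(data, context))
        if first_response.status_code != 200:
            raise ValueError(first_response.content.decode('utf-8'))

        # Reshape the first model's output into a request for a second,
        # hypothetical model named 'ranker', reusing the URI pattern
        # 'http://localhost:8501/v1/models/<model_name>:predict'.
        first_result = json.loads(first_response.content.decode('utf-8'))
        second_payload = json.dumps({'instances': first_result['predictions']})
        second_uri = context.rest_uri.replace(context.model_name, 'ranker')

        second_response = requests.post(second_uri, data=second_payload)
        return second_response.content, context.accept_header
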
You can also bring in external dependencies to help with your data
processing. There are two ways to do this:

1. If you include ``requirements.txt`` in your ``source_dir`` or in
   your dependencies, the container installs the Python dependencies at runtime using ``pip install -r``:

   .. code::

       from sagemaker.tensorflow.serving import Model

       model = Model(entry_point='inference.py',
                     dependencies=['requirements.txt'],
                     model_data='s3://mybucket/model.tar.gz',
                     role='MySageMakerRole')

2. If you are working in a network-isolation situation or if you don't
   want to install dependencies at runtime every time your endpoint starts or a batch
   transform job runs, you might want to put
   pre-downloaded dependencies under a ``lib`` directory and pass this
   directory as a dependency. The container adds the modules to the Python
   path. Note that if both ``lib`` and ``requirements.txt``
   are present in the model archive, the ``requirements.txt`` is ignored:

   .. code::

       from sagemaker.tensorflow.serving import Model

       model = Model(entry_point='inference.py',
                     dependencies=['/path/to/folder/named/lib'],
                     model_data='s3://mybucket/model.tar.gz',
                     role='MySageMakerRole')

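One possible way to pre-package the ``lib`` directory (a sketch, not a required
workflow) is to install your pinned requirements into it with pip's ``--target``
option before building the model archive:

.. code::

    # Run locally before creating model.tar.gz; 'lib' ends up next to inference.py.
    pip install -r requirements.txt --target lib
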
Deploying more than one model to your Endpoint
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
