
Commit 7198003

doc: document model.tar.gz structure for MXNet and PyTorch (#1446)
1 parent 1c8fc7e commit 7198003

File tree

2 files changed: +118 -8 lines


doc/using_mxnet.rst

Lines changed: 63 additions & 5 deletions
@@ -420,6 +420,9 @@ Inference arrays or lists are serialized and sent to the MXNet model server by a
``predict`` returns the result of inference against your model.
By default, the inference result is either a Python list or dictionary.

Elastic Inference
=================

MXNet on Amazon SageMaker supports `Elastic Inference <https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html>`_, which lets you add inference acceleration to a hosted endpoint for a fraction of the cost of using a full GPU instance.
To attach an Elastic Inference accelerator to your endpoint, pass the accelerator type to the ``accelerator_type`` parameter of your ``deploy`` call.

@@ -429,6 +432,61 @@
        initial_instance_count=1,
        accelerator_type='ml.eia1.medium')
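
For orientation, a complete ``deploy`` call might look like the following sketch, where ``mxnet_estimator`` and the host instance type are assumptions:

.. code:: python

    # A sketch: deploy a trained MXNet estimator with an Elastic Inference
    # accelerator attached to the endpoint.
    predictor = mxnet_estimator.deploy(
        initial_instance_count=1,
        instance_type='ml.m4.xlarge',       # host instance type (assumed)
        accelerator_type='ml.eia1.medium',  # Elastic Inference accelerator
    )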

Model Directory Structure
=========================

In general, if you use the same version of MXNet for both training and inference with the SageMaker Python SDK,
the SDK should take care of ensuring that the contents of your ``model.tar.gz`` file are organized correctly.

For versions 1.4 and higher
---------------------------

For MXNet versions 1.4 and higher, the contents of ``model.tar.gz`` should be organized as follows:

- Model files in the top-level directory
- Inference script (and any other source files) in a directory named ``code/`` (for more about the inference script, see `The SageMaker MXNet Model Server <#the-sagemaker-mxnet-model-server>`_)
- Optional requirements file located at ``code/requirements.txt`` (for more about requirements files, see `Use third-party libraries <#use-third-party-libraries>`_)

For example:

.. code::

    model.tar.gz/
    |- model-symbol.json
    |- model-shapes.json
    |- model-0000.params
    |- code/
        |- inference.py
        |- requirements.txt  # only for versions 1.6.0 and higher

In this example, ``model-symbol.json``, ``model-shapes.json``, and ``model-0000.params`` are the model files saved from training,
``inference.py`` is the inference script, and ``requirements.txt`` is a requirements file.

The ``MXNet`` and ``MXNetModel`` classes repack ``model.tar.gz`` to include the inference script (and related files),
as long as ``framework_version`` is set to 1.4 or higher.
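
As a sketch of how this typically comes together when deploying from an existing ``model.tar.gz``, with a hypothetical S3 path and IAM role:

.. code:: python

    from sagemaker.mxnet import MXNetModel

    # A sketch: the SDK repacks model.tar.gz so that inference.py ends up
    # under code/ inside the archive.
    model = MXNetModel(
        model_data='s3://my-bucket/model.tar.gz',  # hypothetical S3 location
        role='SageMakerRole',                      # hypothetical IAM role
        entry_point='inference.py',
        framework_version='1.6.0',
    )
    predictor = model.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')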

For versions 1.3 and lower
--------------------------

For MXNet versions 1.3 and lower, ``model.tar.gz`` should contain only the model files,
while your inference script and optional requirements file are packed in a separate tarball, named ``sourcedir.tar.gz`` by default.

For example:

.. code::

    model.tar.gz/
    |- model-symbol.json
    |- model-shapes.json
    |- model-0000.params

    sourcedir.tar.gz/
    |- script.py
    |- requirements.txt  # only for versions 0.12.1-1.3.0

In this example, ``model-symbol.json``, ``model-shapes.json``, and ``model-0000.params`` are the model files saved from training,
``script.py`` is the inference script, and ``requirements.txt`` is a requirements file.

The SageMaker MXNet Model Server
================================

@@ -512,7 +570,7 @@ Defining how to handle these requests can be done in one of two ways:
- writing your own ``transform_fn`` for handling input processing, prediction, and output processing

Use ``input_fn``, ``predict_fn``, and ``output_fn``
----------------------------------------------------
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The SageMaker MXNet model server breaks request handling into three steps:

@@ -565,7 +623,7 @@ In the following sections, we describe the default implementations of ``input_fn
We describe the input arguments and expected return types of each, so you can define your own implementations.

Process Model Input
-^^^^^^^^^^^^^^^^^^^
+~~~~~~~~~~~~~~~~~~~

When an ``InvokeEndpoint`` operation is made against an endpoint running an MXNet model server, the model server receives two pieces of information:

@@ -607,7 +665,7 @@ If you provide your own implementation of input_fn, you should abide by the ``in
        pass

Predict from a Deployed Model
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

After the inference request has been deserialized by ``input_fn``, the MXNet model server invokes ``predict_fn``.
As with the other functions, you can define your own ``predict_fn`` or use the model server's default.
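
A minimal sketch of a custom ``predict_fn``, assuming ``input_fn`` returned an ``mx.nd.NDArray`` and the model is a Gluon block (argument names here are illustrative):

.. code:: python

    # A sketch, not the server's default implementation.
    def predict_fn(input_object, model):
        # `input_object` is the return value of input_fn;
        # `model` is the loaded MXNet model.
        return model(input_object)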

@@ -640,7 +698,7 @@ If you implement your own prediction function, you should take care to ensure th
``output_fn``, this should be an ``NDArrayIter``.

Process Model Output
-^^^^^^^^^^^^^^^^^^^^
+~~~~~~~~~~~~~~~~~~~~

After invoking ``predict_fn``, the model server invokes ``output_fn``, passing in the return value from ``predict_fn`` and the ``InvokeEndpoint`` requested response content type.

@@ -656,7 +714,7 @@ The function should return an array of bytes serialized to the expected content
The default implementation expects ``prediction`` to be an ``NDArray`` and can serialize the result to either JSON or CSV. It accepts response content types of "application/json" and "text/csv".
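
A sketch of a custom ``output_fn`` that mirrors the default JSON behavior described above (argument names are illustrative):

.. code:: python

    import json

    # A sketch: serialize an mx.nd.NDArray prediction to JSON.
    def output_fn(prediction, response_content_type):
        if response_content_type == 'application/json':
            return json.dumps(prediction.asnumpy().tolist())
        raise ValueError('Unsupported content type: {}'.format(response_content_type))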

Use ``transform_fn``
---------------------
+^^^^^^^^^^^^^^^^^^^^

If you would rather not structure your code around the three methods described above, you can instead define your own ``transform_fn`` to handle inference requests.
An error is thrown if a ``transform_fn`` is present in conjunction with any ``input_fn``, ``predict_fn``, and/or ``output_fn``.
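
A sketch of a standalone ``transform_fn`` that performs all three steps in one function, assuming JSON input and output (argument names and the return convention are illustrative):

.. code:: python

    import json

    import mxnet as mx

    # A sketch: input processing, prediction, and output processing in one place.
    def transform_fn(model, request_body, content_type, accept):
        data = mx.nd.array(json.loads(request_body))                # input processing
        prediction = model(data)                                    # prediction
        return json.dumps(prediction.asnumpy().tolist()), accept    # output processing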

doc/using_pytorch.rst

Lines changed: 55 additions & 3 deletions
@@ -4,11 +4,11 @@ Using PyTorch with the SageMaker Python SDK

With PyTorch Estimators and Models, you can train and host PyTorch models on Amazon SageMaker.

-Supported versions of PyTorch: ``0.4.0``, ``1.0.0``, ``1.1.0``, ``1.2.0``, ``1.3.1``.
+Supported versions of PyTorch: ``0.4.0``, ``1.0.0``, ``1.1.0``, ``1.2.0``, ``1.3.1``, ``1.4.0``.

Supported versions of PyTorch for Elastic Inference: ``1.3.1``.

-We recommend that you use the latest supported version, because that's where we focus most of our development efforts.
+We recommend that you use the latest supported version because that's where we focus our development efforts.

You can visit the PyTorch repository at https://github.com/pytorch/pytorch.

@@ -264,6 +264,9 @@ You use the SageMaker PyTorch model server to host your PyTorch model when you c
Estimator. The model server runs inside a SageMaker Endpoint, which your call to ``deploy`` creates.
You can access the name of the Endpoint through the ``name`` property on the returned ``Predictor``.

Elastic Inference
=================

PyTorch on Amazon SageMaker supports `Elastic Inference <https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html>`_, which lets you add inference acceleration to a hosted endpoint for a fraction of the cost of using a full GPU instance.
To attach an Elastic Inference accelerator to your endpoint, pass the accelerator type to the ``accelerator_type`` parameter of your ``deploy`` call.

@@ -273,6 +276,55 @@
        initial_instance_count=1,
        accelerator_type='ml.eia2.medium')

Model Directory Structure
=========================

In general, if you use the same version of PyTorch for both training and inference with the SageMaker Python SDK,
the SDK should take care of ensuring that the contents of your ``model.tar.gz`` file are organized correctly.

For versions 1.2 and higher
---------------------------

For PyTorch versions 1.2 and higher, the contents of ``model.tar.gz`` should be organized as follows:

- Model files in the top-level directory
- Inference script (and any other source files) in a directory named ``code/`` (for more about the inference script, see `The SageMaker PyTorch Model Server <#the-sagemaker-pytorch-model-server>`_)
- Optional requirements file located at ``code/requirements.txt`` (for more about requirements files, see `Using third-party libraries <#using-third-party-libraries>`_)

For example:

.. code::

    model.tar.gz/
    |- model.pth
    |- code/
        |- inference.py
        |- requirements.txt  # only for versions 1.3.1 and higher

In this example, ``model.pth`` is the model file saved from training, ``inference.py`` is the inference script, and ``requirements.txt`` is a requirements file.

The ``PyTorch`` and ``PyTorchModel`` classes repack ``model.tar.gz`` to include the inference script (and related files),
as long as ``framework_version`` is set to 1.2 or higher.
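
As a sketch, deploying from an existing ``model.tar.gz`` with an inference script might look like the following, with a hypothetical S3 path and IAM role:

.. code:: python

    from sagemaker.pytorch import PyTorchModel

    # A sketch: with framework_version 1.2 or higher, the SDK repacks
    # model.tar.gz so that inference.py ends up under code/.
    model = PyTorchModel(
        model_data='s3://my-bucket/model.tar.gz',  # hypothetical S3 location
        role='SageMakerRole',                      # hypothetical IAM role
        entry_point='inference.py',
        framework_version='1.4.0',
    )
    predictor = model.deploy(initial_instance_count=1, instance_type='ml.c5.xlarge')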

For versions 1.1 and lower
--------------------------

For PyTorch versions 1.1 and lower, ``model.tar.gz`` should contain only the model files,
while your inference script and optional requirements file are packed in a separate tarball, named ``sourcedir.tar.gz`` by default.

For example:

.. code::

    model.tar.gz/
    |- model.pth

    sourcedir.tar.gz/
    |- script.py
    |- requirements.txt

In this example, ``model.pth`` is the model file saved from training, ``script.py`` is the inference script, and ``requirements.txt`` is a requirements file.

The SageMaker PyTorch Model Server
==================================

@@ -435,7 +487,7 @@ The example below shows a custom ``input_fn`` for preparing pickled torch.Tensor

Get Predictions from a PyTorch Model
-------------------------------------
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

After the inference request has been deserialized by ``input_fn``, the SageMaker PyTorch model server invokes
``predict_fn`` on the return value of ``input_fn``.
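
A minimal sketch of a custom ``predict_fn``, assuming ``input_fn`` returned a ``torch.Tensor``:

.. code:: python

    import torch

    # A sketch: move the model and input to the same device and run
    # inference without tracking gradients.
    def predict_fn(input_data, model):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model.to(device).eval()
        with torch.no_grad():
            return model(input_data.to(device))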
