Skip to content

Commit e3dea81

Browse files
iyerr3knakad
authored andcommitted
feature: add XGBoost Estimator as new framework (#980)
This PR adds a new XGBoost estimator to release XGBoost as a framework container.
1 parent 435947c commit e3dea81

File tree

12 files changed

+1210
-32
lines changed

12 files changed

+1210
-32
lines changed

README.rst

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,18 +49,19 @@ Table of Contents
4949
5. `Chainer SageMaker Estimators <#chainer-sagemaker-estimators>`__
5050
6. `PyTorch SageMaker Estimators <#pytorch-sagemaker-estimators>`__
5151
7. `Scikit-learn SageMaker Estimators <#scikit-learn-sagemaker-estimators>`__
52-
8. `SageMaker Reinforcement Learning Estimators <#sagemaker-reinforcement-learning-estimators>`__
53-
9. `SageMaker SparkML Serving <#sagemaker-sparkml-serving>`__
54-
10. `AWS SageMaker Estimators <#aws-sagemaker-estimators>`__
55-
11. `Using SageMaker AlgorithmEstimators <https://sagemaker.readthedocs.io/en/stable/overview.html#using-sagemaker-algorithmestimators>`__
56-
12. `Consuming SageMaker Model Packages <https://sagemaker.readthedocs.io/en/stable/overview.html#consuming-sagemaker-model-packages>`__
57-
13. `BYO Docker Containers with SageMaker Estimators <https://sagemaker.readthedocs.io/en/stable/overview.html#byo-docker-containers-with-sagemaker-estimators>`__
58-
14. `SageMaker Automatic Model Tuning <https://sagemaker.readthedocs.io/en/stable/overview.html#sagemaker-automatic-model-tuning>`__
59-
15. `SageMaker Batch Transform <https://sagemaker.readthedocs.io/en/stable/overview.html#sagemaker-batch-transform>`__
60-
16. `Secure Training and Inference with VPC <https://sagemaker.readthedocs.io/en/stable/overview.html#secure-training-and-inference-with-vpc>`__
61-
17. `BYO Model <https://sagemaker.readthedocs.io/en/stable/overview.html#byo-model>`__
62-
18. `Inference Pipelines <https://sagemaker.readthedocs.io/en/stable/overview.html#inference-pipelines>`__
63-
19. `SageMaker Workflow <#sagemaker-workflow>`__
52+
8. `XGBoost SageMaker Estimators <#xgboost-sagemaker-estimators>`__
53+
9. `SageMaker Reinforcement Learning Estimators <#sagemaker-reinforcement-learning-estimators>`__
54+
10. `SageMaker SparkML Serving <#sagemaker-sparkml-serving>`__
55+
11. `AWS SageMaker Estimators <#aws-sagemaker-estimators>`__
56+
12. `Using SageMaker AlgorithmEstimators <https://sagemaker.readthedocs.io/en/stable/overview.html#using-sagemaker-algorithmestimators>`__
57+
13. `Consuming SageMaker Model Packages <https://sagemaker.readthedocs.io/en/stable/overview.html#consuming-sagemaker-model-packages>`__
58+
14. `BYO Docker Containers with SageMaker Estimators <https://sagemaker.readthedocs.io/en/stable/overview.html#byo-docker-containers-with-sagemaker-estimators>`__
59+
15. `SageMaker Automatic Model Tuning <https://sagemaker.readthedocs.io/en/stable/overview.html#sagemaker-automatic-model-tuning>`__
60+
16. `SageMaker Batch Transform <https://sagemaker.readthedocs.io/en/stable/overview.html#sagemaker-batch-transform>`__
61+
17. `Secure Training and Inference with VPC <https://sagemaker.readthedocs.io/en/stable/overview.html#secure-training-and-inference-with-vpc>`__
62+
18. `BYO Model <https://sagemaker.readthedocs.io/en/stable/overview.html#byo-model>`__
63+
19. `Inference Pipelines <https://sagemaker.readthedocs.io/en/stable/overview.html#inference-pipelines>`__
64+
20. `SageMaker Workflow <#sagemaker-workflow>`__
6465

6566

6667
Installing the SageMaker Python SDK
@@ -247,6 +248,21 @@ For more information about Scikit-learn SageMaker Estimators, see `Using Scikit-
247248

248249
.. _Using Scikit-learn with the SageMaker Python SDK: https://sagemaker.readthedocs.io/en/stable/using_sklearn.html
249250

251+
XGBoost SageMaker Estimators
252+
----------------------------
253+
254+
With XGBoost SageMaker Estimators, you can train and host XGBoost models on Amazon SageMaker.
255+
256+
Supported versions of XGBoost: ``0.90-1``.
257+
258+
We recommend that you use the latest supported version, because that's where we focus most of our development efforts.
259+
260+
For more information about XGBoost, see https://xgboost.readthedocs.io/en/latest/
261+
262+
For more information about XGBoost SageMaker Estimators, see `Using XGBoost with the SageMaker Python SDK`_.
263+
264+
.. _Using XGBoost with the SageMaker Python SDK: https://sagemaker.readthedocs.io/en/stable/using_xgboost.html
265+
250266

251267
SageMaker Reinforcement Learning Estimators
252268
-------------------------------------------

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
from sagemaker.model import NEO_IMAGE_ACCOUNT
2727
from sagemaker.session import s3_input
2828
from sagemaker.utils import sagemaker_timestamp, get_ecr_image_uri_prefix
29+
from sagemaker.xgboost.estimator import get_xgboost_image_uri
30+
from sagemaker.xgboost.defaults import XGBOOST_LATEST_VERSION
2931

3032
logger = logging.getLogger(__name__)
3133

@@ -479,5 +481,15 @@ def get_image_uri(region_name, repo_name, repo_version=1):
479481
repo_name:
480482
repo_version:
481483
"""
484+
if repo_name == "xgboost":
485+
if repo_version in ["0.90", "0.90-1", "0.90-1-cpu-py3"]:
486+
return get_xgboost_image_uri(region_name, XGBOOST_LATEST_VERSION)
487+
logging.warning(
488+
"There is a more up to date SageMaker XGBoost image."
489+
"To use the newer image, please set 'repo_version'="
490+
"'0.90-1. For example:\n"
491+
"\tget_image_uri(region, 'xgboost', %s).",
492+
XGBOOST_LATEST_VERSION,
493+
)
482494
repo = "{}:{}".format(repo_name, repo_version)
483495
return "{}/{}".format(registry(region_name, repo_name), repo)

src/sagemaker/fw_registry.py

Lines changed: 95 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,101 @@
1717
from sagemaker.utils import get_ecr_image_uri_prefix
1818

1919
image_registry_map = {
20-
"us-west-1": {"sparkml-serving": "746614075791", "scikit-learn": "746614075791"},
21-
"us-west-2": {"sparkml-serving": "246618743249", "scikit-learn": "246618743249"},
22-
"us-east-1": {"sparkml-serving": "683313688378", "scikit-learn": "683313688378"},
23-
"us-east-2": {"sparkml-serving": "257758044811", "scikit-learn": "257758044811"},
24-
"ap-northeast-1": {"sparkml-serving": "354813040037", "scikit-learn": "354813040037"},
25-
"ap-northeast-2": {"sparkml-serving": "366743142698", "scikit-learn": "366743142698"},
26-
"ap-southeast-1": {"sparkml-serving": "121021644041", "scikit-learn": "121021644041"},
27-
"ap-southeast-2": {"sparkml-serving": "783357654285", "scikit-learn": "783357654285"},
28-
"ap-south-1": {"sparkml-serving": "720646828776", "scikit-learn": "720646828776"},
29-
"eu-west-1": {"sparkml-serving": "141502667606", "scikit-learn": "141502667606"},
30-
"eu-west-2": {"sparkml-serving": "764974769150", "scikit-learn": "764974769150"},
31-
"eu-central-1": {"sparkml-serving": "492215442770", "scikit-learn": "492215442770"},
32-
"ca-central-1": {"sparkml-serving": "341280168497", "scikit-learn": "341280168497"},
33-
"us-gov-west-1": {"sparkml-serving": "414596584902", "scikit-learn": "414596584902"},
34-
"us-iso-east-1": {"sparkml-serving": "833128469047", "scikit-learn": "833128469047"},
35-
"ap-east-1": {"sparkml-serving": "651117190479", "scikit-learn": "651117190479"},
36-
"sa-east-1": {"sparkml-serving": "737474898029", "scikit-learn": "737474898029"},
37-
"eu-north-1": {"sparkml-serving": "662702820516", "scikit-learn": "662702820516"},
38-
"eu-west-3": {"sparkml-serving": "659782779980", "scikit-learn": "659782779980"},
20+
"us-west-1": {
21+
"sparkml-serving": "746614075791",
22+
"scikit-learn": "746614075791",
23+
"xgboost": "746614075791",
24+
},
25+
"us-west-2": {
26+
"sparkml-serving": "246618743249",
27+
"scikit-learn": "246618743249",
28+
"xgboost": "246618743249",
29+
},
30+
"us-east-1": {
31+
"sparkml-serving": "683313688378",
32+
"scikit-learn": "683313688378",
33+
"xgboost": "683313688378",
34+
},
35+
"us-east-2": {
36+
"sparkml-serving": "257758044811",
37+
"scikit-learn": "257758044811",
38+
"xgboost": "257758044811",
39+
},
40+
"ap-northeast-1": {
41+
"sparkml-serving": "354813040037",
42+
"scikit-learn": "354813040037",
43+
"xgboost": "354813040037",
44+
},
45+
"ap-northeast-2": {
46+
"sparkml-serving": "366743142698",
47+
"scikit-learn": "366743142698",
48+
"xgboost": "366743142698",
49+
},
50+
"ap-southeast-1": {
51+
"sparkml-serving": "121021644041",
52+
"scikit-learn": "121021644041",
53+
"xgboost": "121021644041",
54+
},
55+
"ap-southeast-2": {
56+
"sparkml-serving": "783357654285",
57+
"scikit-learn": "783357654285",
58+
"xgboost": "783357654285",
59+
},
60+
"ap-south-1": {
61+
"sparkml-serving": "720646828776",
62+
"scikit-learn": "720646828776",
63+
"xgboost": "720646828776",
64+
},
65+
"eu-west-1": {
66+
"sparkml-serving": "141502667606",
67+
"scikit-learn": "141502667606",
68+
"xgboost": "141502667606",
69+
},
70+
"eu-west-2": {
71+
"sparkml-serving": "764974769150",
72+
"scikit-learn": "764974769150",
73+
"xgboost": "764974769150",
74+
},
75+
"eu-central-1": {
76+
"sparkml-serving": "492215442770",
77+
"scikit-learn": "492215442770",
78+
"xgboost": "492215442770",
79+
},
80+
"ca-central-1": {
81+
"sparkml-serving": "341280168497",
82+
"scikit-learn": "341280168497",
83+
"xgboost": "341280168497",
84+
},
85+
"us-gov-west-1": {
86+
"sparkml-serving": "414596584902",
87+
"scikit-learn": "414596584902",
88+
"xgboost": "414596584902",
89+
},
90+
"us-iso-east-1": {
91+
"sparkml-serving": "833128469047",
92+
"scikit-learn": "833128469047",
93+
"xgboost": "833128469047",
94+
},
95+
"ap-east-1": {
96+
"sparkml-serving": "651117190479",
97+
"scikit-learn": "651117190479",
98+
"xgboost": "651117190479",
99+
},
100+
"sa-east-1": {
101+
"sparkml-serving": "737474898029",
102+
"scikit-learn": "737474898029",
103+
"xgboost": "737474898029",
104+
},
105+
"eu-north-1": {
106+
"sparkml-serving": "662702820516",
107+
"scikit-learn": "662702820516",
108+
"xgboost": "662702820516",
109+
},
110+
"eu-west-3": {
111+
"sparkml-serving": "659782779980",
112+
"scikit-learn": "659782779980",
113+
"xgboost": "659782779980",
114+
},
39115
}
40116

41117

src/sagemaker/fw_utils.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@
4949
"framework_version is required for script mode estimator. "
5050
"Please add framework_version={} to your constructor to avoid this error."
5151
)
52+
UNSUPPORTED_FRAMEWORK_VERSION_ERROR = (
53+
"{} framework does not support version {}. Please use one of the following: {}."
54+
)
5255

5356
VALID_PY_VERSIONS = ["py2", "py3"]
5457
VALID_EIA_FRAMEWORKS = ["tensorflow", "tensorflow-serving", "mxnet", "mxnet-serving"]
@@ -359,7 +362,7 @@ def framework_name_from_image(image_name):
359362
# extract framework, python version and image tag
360363
# We must support both the legacy and current image name format.
361364
name_pattern = re.compile(
362-
r"^(?:sagemaker(?:-rl)?-)?(tensorflow|mxnet|chainer|pytorch|scikit-learn)(?:-)?(scriptmode|training)?:(.*)-(.*?)-(py2|py3)$" # noqa: E501 # pylint: disable=line-too-long
365+
r"^(?:sagemaker(?:-rl)?-)?(tensorflow|mxnet|chainer|pytorch|scikit-learn|xgboost)(?:-)?(scriptmode|training)?:(.*)-(.*?)-(py2|py3)$" # noqa: E501 # pylint: disable=line-too-long
363366
)
364367
legacy_name_pattern = re.compile(r"^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$")
365368

@@ -436,6 +439,25 @@ def empty_framework_version_warning(default_version, latest_version):
436439
return " ".join(msgs)
437440

438441

442+
def get_unsupported_framework_version_error(
443+
framework_name, unsupported_version, supported_versions
444+
):
445+
"""Return error message for unsupported framework version.
446+
447+
This should also return the supported versions for customers.
448+
449+
:param framework_name:
450+
:param unsupported_version:
451+
:param supported_versions:
452+
:return:
453+
"""
454+
return UNSUPPORTED_FRAMEWORK_VERSION_ERROR.format(
455+
framework_name,
456+
unsupported_version,
457+
", ".join('"{}"'.format(version) for version in supported_versions),
458+
)
459+
460+
439461
def python_deprecation_warning(framework):
440462
"""
441463
Args:

src/sagemaker/xgboost/README.rst

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
============================================
2+
XGBoost SageMaker Estimators and Models
3+
============================================
4+
5+
With XGBoost Estimators, you can train and host XGBoost models on Amazon SageMaker.
6+
7+
Supported versions of SageMaker XGBoost: ``0.90-1``
8+
9+
Note that the first part of the version refers to the upstream module version (aka, 0.90), while the second
10+
part refers to the SageMaker version for the container.
11+
12+
You can visit the XGBoost repository at https://github.com/dmlc/xgboost
13+
14+
For information about using XGBoost with the SageMaker Python SDK, see https://sagemaker.readthedocs.io/en/stable/using_xgboost.html.
15+
16+
XGBoost Training Examples
17+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
18+
19+
Amazon provides an example Jupyter notebook that demonstrate end-to-end training on Amazon SageMaker using XGBoost.
20+
Please refer to:
21+
22+
https://github.com/awslabs/amazon-sagemaker-examples/tree/master/sagemaker-python-sdk
23+
24+
These are also available in SageMaker Notebook Instance hosted Jupyter notebooks under the "sample notebooks" folder.
25+
26+
27+
SageMaker XGBoost Docker Containers
28+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
29+
30+
When training and deploying training scripts, SageMaker runs your Python script in a Docker container with several
31+
libraries installed. When creating the Estimator and calling deploy to create the SageMaker Endpoint, you can control
32+
the environment your script runs in.
33+
34+
SageMaker runs XGBoost Estimator scripts in either Python 2.7 or Python 3.5. You can select the Python version by
35+
passing a py_version keyword arg to the XGBoost Estimator constructor. Setting this to py3 (the default) will cause
36+
your training script to be run on Python 3.5. Setting this to py2 will cause your training script to be run on Python 2.7
37+
This Python version applies to both the Training Job, created by fit, and the Endpoint, created by deploy.
38+
39+
The XGBoost Docker images have the following dependencies installed:
40+
41+
+-----------------------------+-------------+
42+
| Dependencies | Version |
43+
+-----------------------------+-------------+
44+
| xgboost | 0.90.0 |
45+
+-----------------------------+-------------+
46+
| matplotlib | 3.0.3+ |
47+
+-----------------------------+-------------+
48+
| numpy | 1.16.4+ |
49+
+-----------------------------+-------------+
50+
| pandas | 0.24.2+ |
51+
+-----------------------------+-------------+
52+
| psutils | 5.6.3+ |
53+
+-----------------------------+-------------+
54+
| PyYAML | < 4.3 |
55+
+-----------------------------+-------------+
56+
| requests | < 2.21 |
57+
+-----------------------------+-------------+
58+
| retrying | 1.3.3 |
59+
+-----------------------------+-------------+
60+
| scikit-learn | 0.21.2+ |
61+
+-----------------------------+-------------+
62+
| scipy | 1.3.0+ |
63+
+-----------------------------+-------------+
64+
| sagemaker-containers | 2.5.1+ |
65+
+-----------------------------+-------------+
66+
| urllib3 | < 1.25 |
67+
+-----------------------------+-------------+
68+
| Python | 2.7 or 3.5 |
69+
+-----------------------------+-------------+
70+
71+
You can see the full list by calling ``pip freeze`` from the running Docker image.
72+
73+
The Docker images extend Ubuntu 16.04.
74+
75+
You can select version of XGBoost by passing a framework_version keyword arg to the XGBoost Estimator constructor.
76+
Currently supported versions are listed in the above table. You can also set framework_version to only specify major and
77+
minor version, which will cause your training script to be run on the latest supported patch version of that minor
78+
version.
79+
80+
Alternatively, you can build your own image by following the instructions in the SageMaker XGBoost containers
81+
repository, and passing ``image_name`` to the XGBoost Estimator constructor.
82+
83+
You can visit the SageMaker XGBoost containers repository here: https://github.com/aws/sagemaker-xgboost-container

src/sagemaker/xgboost/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Placeholder docstring"""
14+
from sagemaker.xgboost.defaults import XGBOOST_NAME, XGBOOST_LATEST_VERSION # noqa: F401
15+
from sagemaker.xgboost.estimator import XGBoost # noqa: F401
16+
from sagemaker.xgboost.model import XGBoostModel, XGBoostPredictor # noqa: F401

src/sagemaker/xgboost/defaults.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Placeholder docstring"""
14+
from __future__ import absolute_import
15+
16+
XGBOOST_NAME = "xgboost"
17+
XGBOOST_LATEST_VERSION = "0.90-1"
18+
XGBOOST_SUPPORTED_VERSIONS = [XGBOOST_LATEST_VERSION]

0 commit comments

Comments
 (0)