Skip to content

Commit 51e0305

Browse files
nadiayaJonathan Esterhazy
authored andcommitted
add RLEstimator
1 parent 28ac404 commit 51e0305

File tree

17 files changed

+1663
-8
lines changed

17 files changed

+1663
-8
lines changed

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
CHANGELOG
33
=========
44

5+
1.16.0.dev
6+
==========
7+
8+
* feature: Estimators: Add RLEstimator to provide support for Reinforcement Learning.
9+
510
1.15.2
611
======
712

doc/index.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,15 @@ A managed environment for TensorFlow training and hosting on Amazon SageMaker
3939

4040
sagemaker.tensorflow
4141

42+
Reinforcement Learning
43+
----------------------
44+
A managed environment for Reinforcement Learning training and hosting on Amazon SageMaker
45+
46+
.. toctree::
47+
:maxdepth: 2
48+
49+
sagemaker.rl
50+
4251
SageMaker First-Party Algorithms
4352
--------------------------------
4453
Amazon provides implementations of some common machine learning algortithms optimized for GPU architecture and massive datasets.

setup.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import os
1616
import re
1717
from glob import glob
18+
import sys
1819

1920
from setuptools import setup, find_packages
2021

@@ -31,6 +32,15 @@ def read(fname):
3132
return open(os.path.join(os.path.dirname(__file__), fname)).read()
3233

3334

35+
# Declare minimal set for installation
36+
required_packages = ['boto3>=1.9.45', 'numpy>=1.9.0', 'protobuf>=3.1', 'scipy>=0.19.0',
37+
'urllib3>=1.21', 'PyYAML>=3.2', 'protobuf3-to-dict>=0.1.5',
38+
'docker-compose>=1.23.0']
39+
40+
# enum is introduced in Python 3.4. Installing enum back port
41+
if sys.version_info < (3, 4):
42+
required_packages.append('enum34>=1.1.6')
43+
3444
setup(name="sagemaker",
3545
version=get_version(),
3646
description="Open source library for training and deploying models on Amazon SageMaker.",
@@ -52,10 +62,7 @@ def read(fname):
5262
"Programming Language :: Python :: 3.5",
5363
],
5464

55-
# Declare minimal set for installation
56-
install_requires=['boto3>=1.9.45', 'numpy>=1.9.0', 'protobuf>=3.1', 'scipy>=0.19.0',
57-
'urllib3 >=1.21', 'PyYAML>=3.2', 'protobuf3-to-dict>=0.1.5',
58-
'docker-compose>=1.23.0'],
65+
install_requires=required_packages,
5966

6067
extras_require={
6168
'test': ['tox', 'flake8', 'pytest', 'pytest-cov', 'pytest-xdist',

src/sagemaker/fw_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,22 +178,24 @@ def framework_name_from_image(image_name):
178178
'<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-<fw>-<py_ver>-<device>:<fw_version>-<device>-<py_ver>'
179179
current:
180180
'<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-<fw>:<fw_version>-<device>-<py_ver>'
181+
current:
182+
'<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-rl-<fw>:<rl_toolkit><rl_version>-<device>-<py_ver>'
181183
182184
Returns:
183185
tuple: A tuple containing:
184186
str: The framework name
185187
str: The Python version
186188
str: The image tag
187189
"""
188-
# image name format: <account>.dkr.ecr.<region>.amazonaws.com/sagemaker-<framework>-<py_ver>-<device>:<tag>
189190
sagemaker_pattern = re.compile(r'^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)amazonaws.com(/)(.*:.*)$')
190191
sagemaker_match = sagemaker_pattern.match(image_name)
191192
if sagemaker_match is None:
192193
return None, None, None
193194
else:
194195
# extract framework, python version and image tag
195196
# We must support both the legacy and current image name format.
196-
name_pattern = re.compile('^sagemaker-(tensorflow|mxnet|chainer|pytorch|scikit-learn):(.*?)-(.*?)-(py2|py3)$')
197+
name_pattern = \
198+
re.compile('^sagemaker(?:-rl)?-(tensorflow|mxnet|chainer|pytorch|scikit-learn):(.*)-(.*?)-(py2|py3)$')
197199
legacy_name_pattern = re.compile('^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$')
198200
name_match = name_pattern.match(sagemaker_match.group(8))
199201
legacy_match = legacy_name_pattern.match(sagemaker_match.group(8))

src/sagemaker/rl/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from sagemaker.rl.estimator import (RLEstimator, RLToolkit, RLFramework,
16+
TOOLKIT_FRAMEWORK_VERSION_MAP)
17+
18+
__all__ = [RLEstimator, RLToolkit, RLFramework, TOOLKIT_FRAMEWORK_VERSION_MAP]

0 commit comments

Comments
 (0)