|
21 | 21 | from sagemaker.fw_utils import tar_and_upload_dir |
22 | 22 | from sagemaker.fw_utils import parse_s3_url |
23 | 23 | from sagemaker.fw_utils import UploadedCode |
| 24 | +from sagemaker.local.local_session import LocalSession |
24 | 25 | from sagemaker.model import Model |
25 | 26 | from sagemaker.model import (SCRIPT_PARAM_NAME, DIR_PARAM_NAME, CLOUDWATCH_METRICS_PARAM_NAME, |
26 | 27 | CONTAINER_LOG_LEVEL_PARAM_NAME, JOB_NAME_PARAM_NAME, SAGEMAKER_REGION_PARAM_NAME) |
@@ -78,7 +79,17 @@ def __init__(self, role, train_instance_count, train_instance_type, |
78 | 79 | self.train_volume_size = train_volume_size |
79 | 80 | self.train_max_run = train_max_run |
80 | 81 | self.input_mode = input_mode |
81 | | - self.sagemaker_session = sagemaker_session or Session() |
| 82 | + |
| 83 | + if self.train_instance_type in ('local', 'local_gpu'): |
| 84 | + self.local_mode = True |
| 85 | + if self.train_instance_type == 'local_gpu' and self.train_instance_count > 1: |
| 86 | + raise RuntimeError("Distributed Training in Local GPU is not supported") |
| 87 | + |
| 88 | + self.sagemaker_session = LocalSession() |
| 89 | + else: |
| 90 | + self.local_mode = False |
| 91 | + self.sagemaker_session = sagemaker_session or Session() |
| 92 | + |
82 | 93 | self.base_job_name = base_job_name |
83 | 94 | self._current_job_name = None |
84 | 95 | self.output_path = output_path |
@@ -303,7 +314,7 @@ def start_new(cls, estimator, inputs): |
303 | 314 | """Create a new Amazon SageMaker training job from the estimator. |
304 | 315 |
|
305 | 316 | Args: |
306 | | - estimator (sagemaker.estimator.Framework): Estimator object created by the user. |
| 317 | + estimator (sagemaker.estimator.EstimatorBase): Estimator object created by the user. |
307 | 318 | inputs (str): Parameters used when called :meth:`~sagemaker.estimator.EstimatorBase.fit`. |
308 | 319 |
|
309 | 320 | Returns: |
|
0 commit comments