Skip to content

Commit b1b664c

Browse files
ajaykarpurknakad
authored andcommitted
fix: stop overwriting custom rules volume and type
1 parent c05abf0 commit b1b664c

15 files changed

+465
-438
lines changed

src/sagemaker/debugger.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
and alerts when it detects errors during training.
2121
"""
2222
from __future__ import absolute_import
23-
import smdebug_rulesconfig as rule_configs # noqa: F401 # pylint: disable=unused-import
23+
24+
# TODO-reinvent-2019 [knakad]: Uncomment this once PyPI integration is complete post-re:Invent-2019
25+
# import smdebug_rulesconfig as rule_configs # noqa: F401 # pylint: disable=unused-import
2426

2527

2628
RULES_ECR_REPO_NAME = "sagemaker-debugger-rules"
@@ -191,11 +193,11 @@ def custom(
191193
name,
192194
image_uri,
193195
instance_type,
196+
volume_size_in_gb,
194197
source=None,
195198
rule_to_invoke=None,
196199
container_local_output_path=None,
197200
s3_output_path=None,
198-
volume_size_in_gb=None,
199201
other_trials_s3_input_paths=None,
200202
rule_parameters=None,
201203
collections_to_save=None,
@@ -210,14 +212,15 @@ def custom(
210212
image_uri (str): The URI of the image to be used by the debugger rule.
211213
instance_type (str): Type of EC2 instance to use, for example,
212214
'ml.c4.xlarge'.
215+
volume_size_in_gb (int): Size in GB of the EBS volume
216+
to use for storing data.
213217
source (str): A source file containing a rule to invoke. If provided,
214-
you must also provide rule_to_invoke.
218+
you must also provide rule_to_invoke. This can either be an S3 uri or
219+
a local path.
215220
rule_to_invoke (str): The name of the rule to invoke within the source.
216221
If provided, you must also provide source.
217222
container_local_output_path (str): The path in the container.
218223
s3_output_path (str): The location in S3 to store the output.
219-
volume_size_in_gb (int): Size in GB of the EBS volume
220-
to use for storing data.
221224
other_trials_s3_input_paths ([str]): S3 input paths for other trials.
222225
rule_parameters (dict): A dictionary of parameters for the rule.
223226
collections_to_save ([sagemaker.debugger.CollectionConfig]): A list
@@ -293,7 +296,7 @@ class DebuggerHookConfig(object):
293296

294297
def __init__(
295298
self,
296-
s3_output_path,
299+
s3_output_path=None,
297300
container_local_output_path=None,
298301
hook_parameters=None,
299302
collection_configs=None,

src/sagemaker/estimator.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,9 @@
2727
import sagemaker
2828
from sagemaker import git_utils
2929
from sagemaker.analytics import TrainingJobAnalytics
30-
from sagemaker.debugger import (
31-
DebuggerHookConfig,
32-
TensorBoardOutputConfig,
33-
get_rule_container_image_uri,
34-
)
30+
from sagemaker.debugger import DebuggerHookConfig
31+
from sagemaker.debugger import TensorBoardOutputConfig # noqa: F401 # pylint: disable=unused-import
32+
from sagemaker.debugger import get_rule_container_image_uri
3533
from sagemaker.s3 import S3Uploader
3634

3735
from sagemaker.fw_utils import (
@@ -331,6 +329,9 @@ def _prepare_for_training(self, job_name=None):
331329
# Prepare rules and debugger configs for training.
332330
if self.rules and not self.debugger_hook_config:
333331
self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
332+
# If an object was provided without an S3 URI is not provided, default it for the customer.
333+
if self.debugger_hook_config and not self.debugger_hook_config.s3_output_path:
334+
self.debugger_hook_config.s3_output_path = self.output_path
334335
self._prepare_rules()
335336
self._prepare_collection_configs()
336337

@@ -340,17 +341,13 @@ def _prepare_rules(self):
340341
if self.rules is not None:
341342
# Iterate through each of the provided rules.
342343
for rule in self.rules:
343-
# Set the instance type and volume size using the Estimator's defaults.
344344
# Set the image URI using the default rule evaluator image and the region.
345345
if rule.image_uri == "DEFAULT_RULE_EVALUATOR_IMAGE":
346346
rule.image_uri = get_rule_container_image_uri(
347347
self.sagemaker_session.boto_region_name
348348
)
349349
rule.instance_type = None
350350
rule.volume_size_in_gb = None
351-
else:
352-
rule.instance_type = self.train_instance_type
353-
rule.volume_size_in_gb = self.train_volume_size
354351
# If source was provided as a rule parameter, upload to S3 and save the S3 uri.
355352
if "source_s3_uri" in (rule.rule_parameters or {}):
356353
parse_result = urlparse(rule.rule_parameters["source_s3_uri"])
@@ -384,6 +381,42 @@ def _prepare_collection_configs(self):
384381
if self.debugger_hook_config is not None:
385382
self.collection_configs.update(self.debugger_hook_config.collection_configs or [])
386383

384+
def get_debugger_artifacts_path(self):
385+
"""Gets the path to the DebuggerHookConfig output artifacts.
386+
387+
Returns:
388+
str: An S3 path to the output artifacts.
389+
"""
390+
self._ensure_latest_training_job(
391+
error_message="""Cannot get the Debugger artifacts path.
392+
The Estimator is not associated with a training job."""
393+
)
394+
if self.debugger_hook_config is not None:
395+
return os.path.join(
396+
self.debugger_hook_config.s3_output_path,
397+
self.latest_training_job.name,
398+
"debug-output",
399+
)
400+
return None
401+
402+
def get_tensorboard_artifacts_path(self):
403+
"""Gets the path to the TensorBoardOutputConfig output artifacts.
404+
405+
Returns:
406+
str: An S3 path to the output artifacts.
407+
"""
408+
self._ensure_latest_training_job(
409+
error_message="""Cannot get the TensorBoard artifacts path.
410+
The Estimator is not associated with a training job."""
411+
)
412+
if self.debugger_hook_config is not None:
413+
return os.path.join(
414+
self.tensorboard_output_config.s3_output_path,
415+
self.latest_training_job.name,
416+
"tensorboard-output",
417+
)
418+
return None
419+
387420
def fit(self, inputs=None, wait=True, logs="All", job_name=None, experiment_config=None):
388421
"""Train a model using the input training dataset.
389422
@@ -1626,10 +1659,6 @@ def _prepare_for_training(self, job_name=None):
16261659
# Set defaults for debugging.
16271660
if self.debugger_hook_config is None:
16281661
self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
1629-
if self.tensorboard_output_config is None:
1630-
self.tensorboard_output_config = TensorBoardOutputConfig(
1631-
s3_output_path=self.output_path
1632-
)
16331662

16341663
def _stage_user_code_in_s3(self):
16351664
"""Upload the user training script to s3 and return the location.

src/sagemaker/model_monitor/model_monitoring.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1052,7 +1052,7 @@ def suggest_baseline(
10521052
script. This can be a local path or an S3 uri.
10531053
output_s3_uri (str): Desired S3 destination Destination of the constraint_violations
10541054
and statistics json files.
1055-
Default: "s3://" + default_session_bucket + job_name + output
1055+
Default: "s3://<default_session_bucket>/<job_name>/output"
10561056
wait (bool): Whether the call should wait until the job completes (default: True).
10571057
logs (bool): Whether to show the logs produced by the job.
10581058
Only meaningful when wait is True (default: True).
@@ -1193,6 +1193,7 @@ def create_monitoring_schedule(
11931193
script. This can be a local path or an S3 uri.
11941194
output_s3_uri (str): Desired S3 destination of the constraint_violations and
11951195
statistics json files.
1196+
Default: "s3://<default_session_bucket>/<job_name>/output"
11961197
constraints (sagemaker.model_monitor.Constraints or str): If provided alongside
11971198
statistics, these will be used for monitoring the endpoint. This can be a
11981199
sagemaker.model_monitor.Constraints object or an s3_uri pointing to a constraints

src/sagemaker/sklearn/processing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ def __init__(
6666
Example: ["python3", "-v"]. If not provided, ["python3"] or ["python2"]
6767
will be chosen based on the py_version parameter.
6868
py_version (str): The python version to use, for example, 'py3'.
69-
volume_size_in_gb (int): Size in GB of the EBS volume
70-
to use for storing data during processing (default: 30).
69+
volume_size_in_gb (int): Size in GB of the EBS volume to
70+
use for storing data during processing (default: 30).
7171
volume_kms_key (str): A KMS key for the processing
7272
volume.
7373
output_kms_key (str): The KMS key id for all ProcessingOutputs.

0 commit comments

Comments
 (0)