Skip to content

Commit 8f01c40

Browse files
committed
Minor processing changes (#282)
* fix: default 'command' in SKLearnProcessor * change: make to_request_dict private * doc: modifying error message
1 parent f89b7d1 commit 8f01c40

File tree

13 files changed

+67
-42
lines changed

13 files changed

+67
-42
lines changed

src/sagemaker/debugger.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ def __init__(
314314
self.hook_parameters = hook_parameters
315315
self.collection_configs = collection_configs
316316

317-
def to_request_dict(self):
317+
def _to_request_dict(self):
318318
"""Generates a request dictionary using the parameters provided
319319
when initializing the object.
320320
@@ -331,7 +331,8 @@ def to_request_dict(self):
331331

332332
if self.collection_configs is not None:
333333
debugger_hook_config_request["CollectionConfigurations"] = [
334-
collection_config.to_request_dict() for collection_config in self.collection_configs
334+
collection_config._to_request_dict()
335+
for collection_config in self.collection_configs
335336
]
336337

337338
return debugger_hook_config_request
@@ -353,7 +354,7 @@ def __init__(self, s3_output_path, container_local_output_path=None):
353354
self.s3_output_path = s3_output_path
354355
self.container_local_output_path = container_local_output_path
355356

356-
def to_request_dict(self):
357+
def _to_request_dict(self):
357358
"""Generates a request dictionary using the parameters provided
358359
when initializing the object.
359360
@@ -401,7 +402,7 @@ def __ne__(self, other):
401402
def __hash__(self):
402403
return hash((self.name, tuple(sorted((self.parameters or {}).items()))))
403404

404-
def to_request_dict(self):
405+
def _to_request_dict(self):
405406
"""Generates a request dictionary using the parameters provided
406407
when initializing the object.
407408

src/sagemaker/estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -965,12 +965,12 @@ def start_new(cls, estimator, inputs, experiment_config):
965965

966966
if estimator.debugger_hook_config:
967967
estimator.debugger_hook_config.collection_configs = estimator.collection_configs
968-
train_args["debugger_hook_config"] = estimator.debugger_hook_config.to_request_dict()
968+
train_args["debugger_hook_config"] = estimator.debugger_hook_config._to_request_dict()
969969

970970
if estimator.tensorboard_output_config:
971971
train_args[
972972
"tensorboard_output_config"
973-
] = estimator.tensorboard_output_config.to_request_dict()
973+
] = estimator.tensorboard_output_config._to_request_dict()
974974

975975
cls._add_spot_checkpoint_args(local_mode, estimator, train_args)
976976

src/sagemaker/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ def deploy(
460460

461461
data_capture_config_dict = None
462462
if data_capture_config is not None:
463-
data_capture_config_dict = data_capture_config.to_request_dict()
463+
data_capture_config_dict = data_capture_config._to_request_dict()
464464

465465
if update_endpoint:
466466
endpoint_config_name = self.sagemaker_session.create_endpoint_config(

src/sagemaker/model_monitor/data_capture_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def __init__(
7171
self.csv_content_types = csv_content_types or ["text/csv"]
7272
self.json_content_types = json_content_types or ["application/json"]
7373

74-
def to_request_dict(self):
74+
def _to_request_dict(self):
7575
"""Generates a request dictionary using the parameters provided to the class."""
7676
request_dict = {
7777
"EnableCapture": self.enable_capture,

src/sagemaker/model_monitor/model_monitoring.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def create_monitoring_schedule(
297297
constraints_s3_uri = constraints_object.file_s3_uri
298298

299299
monitoring_output_config = {
300-
"MonitoringOutputs": [normalized_monitoring_output.to_request_dict()]
300+
"MonitoringOutputs": [normalized_monitoring_output._to_request_dict()]
301301
}
302302

303303
if self.output_kms_key is not None:
@@ -310,14 +310,14 @@ def create_monitoring_schedule(
310310

311311
network_config_dict = None
312312
if self.network_config is not None:
313-
network_config_dict = self.network_config.to_request_dict()
313+
network_config_dict = self.network_config._to_request_dict()
314314

315315
self.sagemaker_session.create_monitoring_schedule(
316316
monitoring_schedule_name=self.monitoring_schedule_name,
317317
schedule_expression=schedule_cron_expression,
318318
statistics_s3_uri=statistics_s3_uri,
319319
constraints_s3_uri=constraints_s3_uri,
320-
monitoring_inputs=[normalized_endpoint_input.to_request_dict()],
320+
monitoring_inputs=[normalized_endpoint_input._to_request_dict()],
321321
monitoring_output_config=monitoring_output_config,
322322
instance_count=self.instance_count,
323323
instance_type=self.instance_type,
@@ -398,14 +398,14 @@ def update_monitoring_schedule(
398398
monitoring_inputs = None
399399
if endpoint_input is not None:
400400
monitoring_inputs = [
401-
self._normalize_endpoint_input(endpoint_input=endpoint_input).to_request_dict()
401+
self._normalize_endpoint_input(endpoint_input=endpoint_input)._to_request_dict()
402402
]
403403

404404
monitoring_output_config = None
405405
if output is not None:
406406
normalized_monitoring_output = self._normalize_monitoring_output(output=output)
407407
monitoring_output_config = {
408-
"MonitoringOutputs": [normalized_monitoring_output.to_request_dict()]
408+
"MonitoringOutputs": [normalized_monitoring_output._to_request_dict()]
409409
}
410410

411411
statistics_object, constraints_object = self._get_baseline_files(
@@ -459,7 +459,7 @@ def update_monitoring_schedule(
459459

460460
network_config_dict = None
461461
if self.network_config is not None:
462-
network_config_dict = self.network_config.to_request_dict()
462+
network_config_dict = self.network_config._to_request_dict()
463463

464464
self.sagemaker_session.update_monitoring_schedule(
465465
monitoring_schedule_name=self.monitoring_schedule_name,
@@ -1263,22 +1263,22 @@ def create_monitoring_schedule(
12631263
)
12641264

12651265
monitoring_output_config = {
1266-
"MonitoringOutputs": [normalized_monitoring_output.to_request_dict()]
1266+
"MonitoringOutputs": [normalized_monitoring_output._to_request_dict()]
12671267
}
12681268

12691269
if self.output_kms_key is not None:
12701270
monitoring_output_config["KmsKeyId"] = self.output_kms_key
12711271

12721272
network_config_dict = None
12731273
if self.network_config is not None:
1274-
network_config_dict = self.network_config.to_request_dict()
1274+
network_config_dict = self.network_config._to_request_dict()
12751275

12761276
self.sagemaker_session.create_monitoring_schedule(
12771277
monitoring_schedule_name=self.monitoring_schedule_name,
12781278
schedule_expression=schedule_cron_expression,
12791279
constraints_s3_uri=constraints_s3_uri,
12801280
statistics_s3_uri=statistics_s3_uri,
1281-
monitoring_inputs=[normalized_endpoint_input.to_request_dict()],
1281+
monitoring_inputs=[normalized_endpoint_input._to_request_dict()],
12821282
monitoring_output_config=monitoring_output_config,
12831283
instance_count=self.instance_count,
12841284
instance_type=self.instance_type,
@@ -1360,7 +1360,7 @@ def update_monitoring_schedule(
13601360
"""
13611361
monitoring_inputs = None
13621362
if endpoint_input is not None:
1363-
monitoring_inputs = [self._normalize_endpoint_input(endpoint_input).to_request_dict()]
1363+
monitoring_inputs = [self._normalize_endpoint_input(endpoint_input)._to_request_dict()]
13641364

13651365
record_preprocessor_script_s3_uri = None
13661366
if record_preprocessor_script is not None:
@@ -1381,7 +1381,7 @@ def update_monitoring_schedule(
13811381
output_s3_uri=output_s3_uri
13821382
)
13831383
monitoring_output_config = {
1384-
"MonitoringOutputs": [normalized_monitoring_output.to_request_dict()]
1384+
"MonitoringOutputs": [normalized_monitoring_output._to_request_dict()]
13851385
}
13861386
output_path = normalized_monitoring_output.source
13871387

@@ -1428,7 +1428,7 @@ def update_monitoring_schedule(
14281428

14291429
network_config_dict = None
14301430
if self.network_config is not None:
1431-
network_config_dict = self.network_config.to_request_dict()
1431+
network_config_dict = self.network_config._to_request_dict()
14321432

14331433
if role is not None:
14341434
self.role = role
@@ -2054,7 +2054,7 @@ def __init__(
20542054
self.s3_input_mode = s3_input_mode
20552055
self.s3_data_distribution_type = s3_data_distribution_type
20562056

2057-
def to_request_dict(self):
2057+
def _to_request_dict(self):
20582058
"""Generates a request dictionary using the parameters provided to the class."""
20592059
endpoint_input_request = {
20602060
"EndpointInput": {
@@ -2088,7 +2088,7 @@ def __init__(self, source, destination=None, s3_upload_mode="Continuous"):
20882088
self.destination = destination
20892089
self.s3_upload_mode = s3_upload_mode
20902090

2091-
def to_request_dict(self):
2091+
def _to_request_dict(self):
20922092
"""Generates a request dictionary using the parameters provided to the class.
20932093
20942094
Returns:

src/sagemaker/network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(self, enable_network_isolation=False, security_group_ids=None, subn
3434
self.security_group_ids = security_group_ids
3535
self.subnets = subnets
3636

37-
def to_request_dict(self):
37+
def _to_request_dict(self):
3838
"""Generates a request dictionary using the parameters provided to the class."""
3939
network_config_request = {"EnableNetworkIsolation": self.enable_network_isolation}
4040

src/sagemaker/pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def deploy(
148148

149149
data_capture_config_dict = None
150150
if data_capture_config is not None:
151-
data_capture_config_dict = data_capture_config.to_request_dict()
151+
data_capture_config_dict = data_capture_config._to_request_dict()
152152

153153
if update_endpoint:
154154
endpoint_config_name = self.sagemaker_session.create_endpoint_config(

src/sagemaker/predictor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def update_data_capture_config(self, data_capture_config):
210210

211211
data_capture_config_dict = None
212212
if data_capture_config is not None:
213-
data_capture_config_dict = data_capture_config.to_request_dict()
213+
data_capture_config_dict = data_capture_config._to_request_dict()
214214

215215
self.sagemaker_session.create_endpoint_config_from_existing(
216216
existing_config_name=endpoint_desc["EndpointConfigName"],

src/sagemaker/processing.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -403,8 +403,8 @@ def _get_user_script_name(self, code):
403403
"""
404404
if os.path.isdir(code) is None or not os.path.splitext(code)[1]:
405405
raise ValueError(
406-
"""code must be a file, not a directory. Please package your code inside of a .whl
407-
file and pass that in, instead.
406+
"""'code' must be a file, not a directory. Please pass a path to a file, not a
407+
directory.
408408
"""
409409
)
410410
return os.path.basename(code)
@@ -494,10 +494,10 @@ def start_new(cls, processor, inputs, outputs, experiment_config):
494494
process_request_args = {}
495495

496496
# Add arguments to the dictionary.
497-
process_request_args["inputs"] = [input.to_request_dict() for input in inputs]
497+
process_request_args["inputs"] = [input._to_request_dict() for input in inputs]
498498

499499
process_request_args["output_config"] = {
500-
"Outputs": [output.to_request_dict() for output in outputs]
500+
"Outputs": [output._to_request_dict() for output in outputs]
501501
}
502502
if processor.output_kms_key is not None:
503503
process_request_args["output_config"]["KmsKeyId"] = processor.output_kms_key
@@ -534,7 +534,7 @@ def start_new(cls, processor, inputs, outputs, experiment_config):
534534
process_request_args["environment"] = processor.env
535535

536536
if processor.network_config is not None:
537-
process_request_args["network_config"] = processor.network_config.to_request_dict()
537+
process_request_args["network_config"] = processor.network_config._to_request_dict()
538538
else:
539539
process_request_args["network_config"] = None
540540

@@ -614,7 +614,7 @@ def __init__(
614614
self.s3_data_distribution_type = s3_data_distribution_type
615615
self.s3_compression_type = s3_compression_type
616616

617-
def to_request_dict(self):
617+
def _to_request_dict(self):
618618
"""Generates a request dictionary using the parameters provided to the class."""
619619
# Create the request dictionary.
620620
s3_input_request = {
@@ -661,7 +661,7 @@ def __init__(self, source, destination=None, output_name=None, s3_upload_mode="E
661661
self.output_name = output_name
662662
self.s3_upload_mode = s3_upload_mode
663663

664-
def to_request_dict(self):
664+
def _to_request_dict(self):
665665
"""Generates a request dictionary using the parameters provided to the class."""
666666
# Create the request dictionary.
667667
s3_output_request = {

src/sagemaker/session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2738,7 +2738,7 @@ def endpoint_from_model_data(
27382738

27392739
data_capture_config_dict = None
27402740
if data_capture_config is not None:
2741-
data_capture_config_dict = data_capture_config.to_request_dict()
2741+
data_capture_config_dict = data_capture_config._to_request_dict()
27422742

27432743
if not _deployment_entity_exists(
27442744
lambda: self.sagemaker_client.describe_endpoint_config(EndpointConfigName=name)

0 commit comments

Comments
 (0)