Skip to content

Commit c943b76

Browse files
rubanh (Ruban Hussain) and knikure
authored
fix: SDK Defaults Config - Handle config injection for None Sessions (#3878)
Co-authored-by: Ruban Hussain <[email protected]> Co-authored-by: Kalyani Nikure <[email protected]>
1 parent 22cfd1a commit c943b76

File tree

7 files changed

+151
-22
lines changed

7 files changed

+151
-22
lines changed

src/sagemaker/jumpstart/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
JumpStartVersionedModelId,
3939
)
4040
from sagemaker.session import Session
41+
from sagemaker.config import load_sagemaker_config
4142
from sagemaker.utils import resolve_value_from_config
4243
from sagemaker.workflow import is_pipeline_variable
4344

@@ -455,6 +456,9 @@ def resolve_model_intelligent_default_field(
455456
over intelligent defaults. For all other fields, intelligent defaults takes precedence
456457
over the JumpStart default fields.
457458
"""
459+
# If sagemaker_session is None, get sagemaker_config from load_sagemaker_config()
460+
# to resolve value from config for the respective field_name parameter
461+
_sagemaker_config = load_sagemaker_config() if (sagemaker_session is None) else None
458462

459463
# We allow customers to define a role which takes precedence
460464
# over intelligent defaults
@@ -464,6 +468,7 @@ def resolve_model_intelligent_default_field(
464468
config_path=MODEL_EXECUTION_ROLE_ARN_PATH,
465469
default_value=default_value or sagemaker_session.get_caller_identity_arn(),
466470
sagemaker_session=sagemaker_session,
471+
sagemaker_config=_sagemaker_config,
467472
)
468473

469474
# JumpStart Models have certain default field values. We want
@@ -474,6 +479,7 @@ def resolve_model_intelligent_default_field(
474479
config_path=MODEL_ENABLE_NETWORK_ISOLATION_PATH,
475480
sagemaker_session=sagemaker_session,
476481
default_value=default_value,
482+
sagemaker_config=_sagemaker_config,
477483
)
478484
return resolved_val if resolved_val is not None else field_val
479485

@@ -494,6 +500,11 @@ def resolve_estimator_intelligent_default_field(
494500
over the JumpStart default fields.
495501
"""
496502

503+
# Workaround for config injection if sagemaker_session is None, since in
504+
# that case sagemaker_session will not be initialized until
505+
# `_init_sagemaker_session_if_does_not_exist` is called later
506+
_sagemaker_config = load_sagemaker_config() if (sagemaker_session is None) else None
507+
497508
# We allow customers to define a role which takes precedence
498509
# over intelligent defaults
499510
if field_name == "role":
@@ -502,6 +513,7 @@ def resolve_estimator_intelligent_default_field(
502513
config_path=TRAINING_JOB_ROLE_ARN_PATH,
503514
default_value=default_value or sagemaker_session.get_caller_identity_arn(),
504515
sagemaker_session=sagemaker_session,
516+
sagemaker_config=_sagemaker_config,
505517
)
506518

507519
# JumpStart Estimators have certain default field values. We want
@@ -513,6 +525,7 @@ def resolve_estimator_intelligent_default_field(
513525
config_path=TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH,
514526
sagemaker_session=sagemaker_session,
515527
default_value=default_value,
528+
sagemaker_config=_sagemaker_config,
516529
)
517530
return resolved_val if resolved_val is not None else field_val
518531

@@ -523,6 +536,7 @@ def resolve_estimator_intelligent_default_field(
523536
config_path=TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH,
524537
sagemaker_session=sagemaker_session,
525538
default_value=default_value,
539+
sagemaker_config=_sagemaker_config,
526540
)
527541
return resolved_val if resolved_val is not None else field_val
528542

src/sagemaker/model.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
MODEL_EXECUTION_ROLE_ARN_PATH,
4242
MODEL_PRIMARY_CONTAINER_ENVIRONMENT_PATH,
4343
ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH,
44+
load_sagemaker_config,
4445
)
4546
from sagemaker.session import Session
4647
from sagemaker.model_metrics import ModelMetrics
@@ -313,13 +314,25 @@ def __init__(
313314
self.name = name
314315
self._base_name = None
315316
self.sagemaker_session = sagemaker_session
317+
318+
# Workaround for config injection if sagemaker_session is None, since in
319+
# that case sagemaker_session will not be initialized until
320+
# `_init_sagemaker_session_if_does_not_exist` is called later
321+
self._sagemaker_config = (
322+
load_sagemaker_config() if (self.sagemaker_session is None) else None
323+
)
324+
316325
self.role = resolve_value_from_config(
317326
role,
318327
MODEL_EXECUTION_ROLE_ARN_PATH,
319328
sagemaker_session=self.sagemaker_session,
329+
sagemaker_config=self._sagemaker_config,
320330
)
321331
self.vpc_config = resolve_value_from_config(
322-
vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self.sagemaker_session
332+
vpc_config,
333+
MODEL_VPC_CONFIG_PATH,
334+
sagemaker_session=self.sagemaker_session,
335+
sagemaker_config=self._sagemaker_config,
323336
)
324337
self.endpoint_name = None
325338
self._is_compiled_model = False
@@ -330,16 +343,17 @@ def __init__(
330343
self._enable_network_isolation = resolve_value_from_config(
331344
enable_network_isolation,
332345
MODEL_ENABLE_NETWORK_ISOLATION_PATH,
346+
default_value=False,
333347
sagemaker_session=self.sagemaker_session,
348+
sagemaker_config=self._sagemaker_config,
334349
)
335350
self.env = resolve_value_from_config(
336351
env,
337352
MODEL_PRIMARY_CONTAINER_ENVIRONMENT_PATH,
338353
default_value={},
339354
sagemaker_session=self.sagemaker_session,
355+
sagemaker_config=self._sagemaker_config,
340356
)
341-
if self._enable_network_isolation is None:
342-
self._enable_network_isolation = False
343357
self.model_kms_key = model_kms_key
344358
self.image_config = image_config
345359
self.entry_point = entry_point
@@ -542,9 +556,9 @@ def _init_sagemaker_session_if_does_not_exist(self, instance_type=None):
542556
return
543557

544558
if instance_type in ("local", "local_gpu"):
545-
self.sagemaker_session = local.LocalSession()
559+
self.sagemaker_session = local.LocalSession(sagemaker_config=self._sagemaker_config)
546560
else:
547-
self.sagemaker_session = session.Session()
561+
self.sagemaker_session = session.Session(sagemaker_config=self._sagemaker_config)
548562

549563
def prepare_container_def(
550564
self,

src/sagemaker/pipeline.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
MODEL_VPC_CONFIG_PATH,
2323
MODEL_ENABLE_NETWORK_ISOLATION_PATH,
2424
MODEL_EXECUTION_ROLE_ARN_PATH,
25+
load_sagemaker_config,
2526
)
2627
from sagemaker.drift_check_baselines import DriftCheckBaselines
2728
from sagemaker.metadata_properties import MetadataProperties
@@ -90,19 +91,33 @@ def __init__(
9091
self.models = models
9192
self.predictor_cls = predictor_cls
9293
self.name = name
93-
self.sagemaker_session = sagemaker_session
9494
self.endpoint_name = None
95+
self.sagemaker_session = sagemaker_session
96+
97+
# If sagemaker_session is None, get sagemaker_config from load_sagemaker_config()
98+
# to resolve value from config for the respective parameter
99+
self._sagemaker_config = (
100+
load_sagemaker_config() if (self.sagemaker_session is None) else None
101+
)
102+
95103
self.role = resolve_value_from_config(
96-
role, MODEL_EXECUTION_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session
104+
role,
105+
MODEL_EXECUTION_ROLE_ARN_PATH,
106+
sagemaker_session=self.sagemaker_session,
107+
sagemaker_config=self._sagemaker_config,
97108
)
98109
self.vpc_config = resolve_value_from_config(
99-
vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self.sagemaker_session
110+
vpc_config,
111+
MODEL_VPC_CONFIG_PATH,
112+
sagemaker_session=self.sagemaker_session,
113+
sagemaker_config=self._sagemaker_config,
100114
)
101115
self.enable_network_isolation = resolve_value_from_config(
102116
direct_input=enable_network_isolation,
103117
config_path=MODEL_ENABLE_NETWORK_ISOLATION_PATH,
104118
default_value=False,
105119
sagemaker_session=self.sagemaker_session,
120+
sagemaker_config=self._sagemaker_config,
106121
)
107122

108123
if not self.role:
@@ -214,7 +229,7 @@ def deploy(
214229
is not None. Otherwise, return None.
215230
"""
216231
if not self.sagemaker_session:
217-
self.sagemaker_session = Session()
232+
self.sagemaker_session = Session(sagemaker_config=self._sagemaker_config)
218233

219234
containers = self.pipeline_container_def(instance_type)
220235

@@ -303,7 +318,7 @@ def _create_sagemaker_pipeline_model(self, instance_type):
303318
support or not.
304319
"""
305320
if not self.sagemaker_session:
306-
self.sagemaker_session = Session()
321+
self.sagemaker_session = Session(sagemaker_config=self._sagemaker_config)
307322

308323
containers = self.pipeline_container_def(instance_type)
309324

src/sagemaker/utils.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,6 +1046,7 @@ def resolve_value_from_config(
10461046
config_path: str = None,
10471047
default_value=None,
10481048
sagemaker_session=None,
1049+
sagemaker_config: dict = None,
10491050
):
10501051
"""Decides which value for the caller to use.
10511052
@@ -1059,19 +1060,30 @@ def resolve_value_from_config(
10591060
10601061
Args:
10611062
direct_input: The value that the caller of this method starts with. Usually this is an
1062-
input to the caller's class or method.
1063+
input to the caller's class or method.
10631064
config_path (str): A string denoting the path used to lookup the value in the
1064-
sagemaker config.
1065+
sagemaker config.
10651066
default_value: The value used if not present elsewhere.
10661067
sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for
1067-
SageMaker interactions (default: None).
1068+
SageMaker interactions (default: None).
1069+
sagemaker_config (dict): The sdk defaults config that is normally accessed through a
1070+
Session object by doing `session.sagemaker_config`. (default: None) This parameter will
1071+
be checked for the config value if (and only if) sagemaker_session is None. This
1072+
parameter exists for the rare cases where the user provided no Session but a default
1073+
Session cannot be initialized before config injection is needed. In that case,
1074+
the config dictionary may be loaded and passed here before a default Session object
1075+
is created.
10681076
10691077
Returns:
10701078
The value that should be used by the caller
10711079
"""
10721080

10731081
config_value = (
1074-
get_sagemaker_config_value(sagemaker_session, config_path) if config_path else None
1082+
get_sagemaker_config_value(
1083+
sagemaker_session, config_path, sagemaker_config=sagemaker_config
1084+
)
1085+
if config_path
1086+
else None
10751087
)
10761088
_log_sagemaker_config_single_substitution(direct_input, config_value, config_path)
10771089

@@ -1084,23 +1096,33 @@ def resolve_value_from_config(
10841096
return default_value
10851097

10861098

1087-
def get_sagemaker_config_value(sagemaker_session, key):
1099+
def get_sagemaker_config_value(sagemaker_session, key, sagemaker_config: dict = None):
10881100
"""Returns the value that corresponds to the provided key from the configuration file.
10891101
10901102
Args:
10911103
key: Key Path of the config file entry.
10921104
sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for
1093-
SageMaker interactions.
1105+
SageMaker interactions.
1106+
sagemaker_config (dict): The sdk defaults config that is normally accessed through a
1107+
Session object by doing `session.sagemaker_config`. (default: None) This parameter will
1108+
be checked for the config value if (and only if) sagemaker_session is None. This
1109+
parameter exists for the rare cases where no Session is provided but a default Session
1110+
cannot be initialized before config injection is needed. In that case, the config
1111+
dictionary may be loaded and passed here before a default Session object is created.
10941112
10951113
Returns:
10961114
object: The corresponding default value in the configuration file.
10971115
"""
1098-
if not sagemaker_session:
1116+
if sagemaker_session:
1117+
config_to_check = sagemaker_session.sagemaker_config
1118+
else:
1119+
config_to_check = sagemaker_config
1120+
1121+
if not config_to_check:
10991122
return None
11001123

1101-
if sagemaker_session.sagemaker_config:
1102-
validate_sagemaker_config(sagemaker_session.sagemaker_config)
1103-
config_value = get_config_value(key, sagemaker_session.sagemaker_config)
1124+
validate_sagemaker_config(config_to_check)
1125+
config_value = get_config_value(key, config_to_check)
11041126
# Copy the value so any modifications to the output will not modify the source config
11051127
return copy.deepcopy(config_value)
11061128

tests/unit/sagemaker/jumpstart/estimator/test_intelligent_defaults.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
metadata_inference_role = "th1234567iant role"
5858

5959

60-
def config_value_impl(sagemaker_session: Session, config_path: str):
60+
def config_value_impl(sagemaker_session: Session, config_path: str, sagemaker_config: dict):
6161
if config_path == TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH:
6262
return config_enable_network_isolation
6363

tests/unit/sagemaker/jumpstart/model/test_intelligent_defaults.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
metadata_enable_network_isolation = random.choice([True, False])
4343

4444

45-
def config_value_impl(sagemaker_session: Session, config_path: str):
45+
def config_value_impl(sagemaker_session: Session, config_path: str, sagemaker_config: dict):
4646
if config_path == MODEL_EXECUTION_ROLE_ARN_PATH:
4747
return config_role
4848

tests/unit/test_utils.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from sagemaker.utils import (
3636
retry_with_backoff,
3737
check_and_get_run_experiment_config,
38+
get_sagemaker_config_value,
3839
resolve_value_from_config,
3940
resolve_class_attribute_from_config,
4041
resolve_nested_dict_value_from_config,
@@ -1190,6 +1191,10 @@ def test_resolve_value_from_config():
11901191
sagemaker_session.sagemaker_config.update(
11911192
{"SageMaker": {"EndpointConfig": {"KmsKeyId": "CONFIG_VALUE"}}}
11921193
)
1194+
sagemaker_config = {
1195+
"SchemaVersion": "1.0",
1196+
"SageMaker": {"EndpointConfig": {"KmsKeyId": "CONFIG_VALUE"}},
1197+
}
11931198

11941199
# direct_input should be respected
11951200
assert (
@@ -1223,6 +1228,13 @@ def test_resolve_value_from_config():
12231228

12241229
assert resolve_value_from_config(None, None, None, sagemaker_session) is None
12251230

1231+
# Config value from sagemaker_config should be returned
1232+
# if no direct_input and sagemaker_session is None
1233+
assert (
1234+
resolve_value_from_config(None, config_key_path, None, None, sagemaker_config)
1235+
== "CONFIG_VALUE"
1236+
)
1237+
12261238
# Different falsy direct_inputs
12271239
assert resolve_value_from_config("", config_key_path, None, sagemaker_session) == ""
12281240

@@ -1239,6 +1251,58 @@ def test_resolve_value_from_config():
12391251
mock_info_logger.reset_mock()
12401252

12411253

1254+
def test_get_sagemaker_config_value():
1255+
mock_config_logger = Mock()
1256+
1257+
mock_info_logger = Mock()
1258+
mock_config_logger.info = mock_info_logger
1259+
# use a shorter name inside the test
1260+
sagemaker_session = MagicMock()
1261+
sagemaker_session.sagemaker_config = {"SchemaVersion": "1.0"}
1262+
config_key_path = "SageMaker.EndpointConfig.KmsKeyId"
1263+
sagemaker_session.sagemaker_config.update(
1264+
{"SageMaker": {"EndpointConfig": {"KmsKeyId": "CONFIG_VALUE"}}}
1265+
)
1266+
sagemaker_config = {
1267+
"SchemaVersion": "1.0",
1268+
"SageMaker": {"EndpointConfig": {"KmsKeyId": "CONFIG_VALUE"}},
1269+
}
1270+
1271+
# Tests that the function returns the correct value when the key exists in the sagemaker_session configuration.
1272+
assert (
1273+
get_sagemaker_config_value(
1274+
sagemaker_session=sagemaker_session, key=config_key_path, sagemaker_config=None
1275+
)
1276+
== "CONFIG_VALUE"
1277+
)
1278+
1279+
# Tests that the function correctly uses the sagemaker_config to get value for the requested
1280+
# config_key_path when sagemaker_session is None.
1281+
assert (
1282+
get_sagemaker_config_value(
1283+
sagemaker_session=None, key=config_key_path, sagemaker_config=sagemaker_config
1284+
)
1285+
== "CONFIG_VALUE"
1286+
)
1287+
1288+
# Tests that the function returns None when the key does not exist in the configuration.
1289+
invalid_key = "inavlid_key"
1290+
assert (
1291+
get_sagemaker_config_value(
1292+
sagemaker_session=sagemaker_session, key=invalid_key, sagemaker_config=sagemaker_config
1293+
)
1294+
is None
1295+
)
1296+
1297+
# Tests that the function returns None when sagemaker_session and sagemaker_config are None.
1298+
assert (
1299+
get_sagemaker_config_value(
1300+
sagemaker_session=None, key=config_key_path, sagemaker_config=None
1301+
)
1302+
is None
1303+
)
1304+
1305+
12421306
@patch("jsonschema.validate")
12431307
@pytest.mark.parametrize(
12441308
"existing_value, config_value, default_value",

0 commit comments

Comments
 (0)