Skip to content

Commit d7175bb

Browse files
committed
fixed formatting
1 parent 0aeca78 commit d7175bb

File tree

3 files changed

+26
-16
lines changed

3 files changed

+26
-16
lines changed

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -315,9 +315,14 @@ def _add_hub_access_config_to_kwargs_inputs(
315315
dataset_uri = kwargs.specs.default_training_dataset_uri
316316
if isinstance(kwargs.inputs, str):
317317
if dataset_uri is not None and dataset_uri == kwargs.inputs:
318-
kwargs.inputs = TrainingInput(s3_data=kwargs.inputs, hub_access_config=hub_access_config)
318+
kwargs.inputs = TrainingInput(
319+
s3_data=kwargs.inputs, hub_access_config=hub_access_config
320+
)
319321
elif isinstance(kwargs.inputs, TrainingInput):
320-
if dataset_uri is not None and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"]:
322+
if (
323+
dataset_uri is not None
324+
and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"]
325+
):
321326
kwargs.inputs.add_hub_access_config(hub_access_config=hub_access_config)
322327
elif isinstance(kwargs.inputs, dict):
323328
for k, v in kwargs.inputs.items():
@@ -327,7 +332,10 @@ def _add_hub_access_config_to_kwargs_inputs(
327332
training_input.add_hub_access_config(hub_access_config=hub_access_config)
328333
kwargs.inputs[k] = training_input
329334
elif isinstance(kwargs.inputs, TrainingInput):
330-
if dataset_uri is not None and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"]:
335+
if (
336+
dataset_uri is not None
337+
and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"]
338+
):
331339
kwargs.inputs[k].add_hub_access_config(hub_access_config=hub_access_config)
332340

333341
return kwargs

src/sagemaker/jumpstart/types.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1464,8 +1464,12 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
14641464
else None
14651465
)
14661466
self.model_subscription_link = json_obj.get("model_subscription_link")
1467-
self.default_training_dataset_key: Optional[str] = json_obj.get("default_training_dataset_key")
1468-
self.default_training_dataset_uri: Optional[str] = json_obj.get("default_training_dataset_uri")
1467+
self.default_training_dataset_key: Optional[str] = json_obj.get(
1468+
"default_training_dataset_key"
1469+
)
1470+
self.default_training_dataset_uri: Optional[str] = json_obj.get(
1471+
"default_training_dataset_uri"
1472+
)
14691473

14701474
def to_json(self) -> Dict[str, Any]:
14711475
"""Returns json representation of JumpStartMetadataBaseFields object."""

tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
)
3737
from tests.integ.sagemaker.jumpstart.utils import (
3838
get_sm_session,
39-
get_training_dataset_for_model_and_version
39+
get_training_dataset_for_model_and_version,
4040
)
4141

4242
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
@@ -81,13 +81,13 @@ def test_jumpstart_hub_estimator(setup, add_model_references):
8181
)
8282

8383
estimator.fit(
84-
inputs = {
84+
inputs={
8585
"training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/"
8686
f"{get_training_dataset_for_model_and_version(model_id, model_version)}",
8787
}
8888
)
8989

90-
# test that we can create a JumpStartEstimator from existing job with `attach`
90+
# test that we can create a JumpStartEstimator from existing job with `attach`
9191
estimator = JumpStartEstimator.attach(
9292
training_job_name=estimator.latest_training_job.name,
9393
model_id=model_id,
@@ -121,14 +121,13 @@ def test_jumpstart_hub_estimator_with_default_session(setup, add_model_reference
121121
)
122122

123123
estimator.fit(
124-
inputs = {
124+
inputs={
125125
"training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/"
126126
f"{get_training_dataset_for_model_and_version(model_id, model_version)}",
127127
}
128128
)
129129

130-
131-
# test that we can create a JumpStartEstimator from existing job with `attach`
130+
# test that we can create a JumpStartEstimator from existing job with `attach`
132131
estimator = JumpStartEstimator.attach(
133132
training_job_name=estimator.latest_training_job.name,
134133
model_id=model_id,
@@ -138,7 +137,7 @@ def test_jumpstart_hub_estimator_with_default_session(setup, add_model_reference
138137
# uses ml.p3.2xlarge instance
139138
predictor = estimator.deploy(
140139
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
141-
role=get_sm_session().get_caller_identity_arn()
140+
role=get_sm_session().get_caller_identity_arn(),
142141
)
143142

144143
response = predictor.predict(["hello", "world"])
@@ -159,10 +158,10 @@ def test_jumpstart_hub_gated_estimator_with_eula(setup, add_model_references):
159158

160159
estimator.fit(
161160
accept_eula=True,
162-
inputs = {
161+
inputs={
163162
"training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/"
164163
f"{get_training_dataset_for_model_and_version(model_id, model_version)}",
165-
}
164+
},
166165
)
167166

168167
estimator = JumpStartEstimator.attach(
@@ -196,14 +195,13 @@ def test_jumpstart_hub_gated_estimator_without_eula(setup, add_model_references)
196195
)
197196
with pytest.raises(Exception):
198197
estimator.fit(
199-
inputs = {
198+
inputs={
200199
"training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/"
201200
f"{get_training_dataset_for_model_and_version(model_id, model_version)}",
202201
}
203202
)
204203

205204

206-
207205
def test_instantiating_estimator(setup, add_model_references):
208206

209207
model_id = "catboost-regression-model"

0 commit comments

Comments
 (0)