Skip to content

Commit 62184c1

Browse files
committed
fix formatting
1 parent 622f706 commit 62184c1

File tree

2 files changed

+22
-6
lines changed

2 files changed

+22
-6
lines changed

src/sagemaker/jumpstart/hub/interfaces.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
670670
)
671671

672672
if self.training_supported:
673-
self.default_training_dataset_uri: Optional[str] = json_obj.get("DefaultTrainingDatasetUri")
673+
self.default_training_dataset_uri: Optional[str] = json_obj.get(
674+
"DefaultTrainingDatasetUri"
675+
)
674676
self.training_model_package_artifact_uri: Optional[str] = json_obj.get(
675677
"TrainingModelPackageArtifactUri"
676678
)

src/sagemaker/jumpstart/hub/utils.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,12 @@ def get_hub_model_version(
220220
sagemaker_session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION
221221

222222
try:
223-
hub_content_summaries = _list_hub_content_versions_helper(hub_name=hub_name, hub_content_name=hub_model_name, hub_content_type=hub_model_type, sagemaker_session=sagemaker_session)
223+
hub_content_summaries = _list_hub_content_versions_helper(
224+
hub_name=hub_name,
225+
hub_content_name=hub_model_name,
226+
hub_content_type=hub_model_type,
227+
sagemaker_session=sagemaker_session,
228+
)
224229
except Exception as ex:
225230
raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}")
226231

@@ -236,19 +241,28 @@ def get_hub_model_version(
236241
return marketplace_hub_content_version
237242
raise
238243

239-
def _list_hub_content_versions_helper(hub_name, hub_content_name, hub_content_type, sagemaker_session):
244+
245+
def _list_hub_content_versions_helper(
246+
hub_name, hub_content_name, hub_content_type, sagemaker_session
247+
):
240248
all_hub_content_summaries = []
241249
list_hub_content_versions_response = sagemaker_session.list_hub_content_versions(
242250
hub_name=hub_name, hub_content_name=hub_content_name, hub_content_type=hub_content_type
243251
)
244252
all_hub_content_summaries.extend(list_hub_content_versions_response.get("HubContentSummaries"))
245253
while "NextToken" in list_hub_content_versions_response:
246254
list_hub_content_versions_response = sagemaker_session.list_hub_content_versions(
247-
hub_name=hub_name, hub_content_name=hub_content_name, hub_content_type=hub_content_type, next_token=list_hub_content_versions_response["NextToken"]
248-
)
249-
all_hub_content_summaries.extend(list_hub_content_versions_response.get("HubContentSummaries"))
255+
hub_name=hub_name,
256+
hub_content_name=hub_content_name,
257+
hub_content_type=hub_content_type,
258+
next_token=list_hub_content_versions_response["NextToken"],
259+
)
260+
all_hub_content_summaries.extend(
261+
list_hub_content_versions_response.get("HubContentSummaries")
262+
)
250263
return all_hub_content_summaries
251264

265+
252266
def _get_hub_model_version_for_open_weight_version(
253267
hub_content_summaries: List[Any], hub_model_version: Optional[str] = None
254268
) -> str:

0 commit comments

Comments
 (0)