From 7c09242f37ead436dec28a3cae93c4e45d43dec1 Mon Sep 17 00:00:00 2001 From: Keshav Chandak Date: Fri, 7 Feb 2025 11:46:12 +0530 Subject: [PATCH] Fixed pagination failing while listing collections --- src/sagemaker/session.py | 2 +- tests/integ/test_collection.py | 286 +++++++++++++++++---------------- 2 files changed, 150 insertions(+), 138 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 04a7326557..c6a2014ae5 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -5286,7 +5286,7 @@ def get_tagging_resources(self, tag_filters, resource_type_filters): resource_tag_response = self.resource_group_tagging_client.get_resources( TagFilters=tag_filters, ResourceTypeFilters=resource_type_filters, - NextToken=next_token, + PaginationToken=next_token, ) resource_list = resource_list + resource_tag_response["ResourceTagMappingList"] next_token = resource_tag_response.get("PaginationToken") diff --git a/tests/integ/test_collection.py b/tests/integ/test_collection.py index 2ee1d90e34..9a6db645cf 100644 --- a/tests/integ/test_collection.py +++ b/tests/integ/test_collection.py @@ -19,20 +19,22 @@ def test_create_collection_root_success(sagemaker_session): collection = Collection(sagemaker_session) collection_name = unique_name_from_base("test-collection") - collection.create(collection_name) - collection_filter = [ - { - "Name": "resource-type", - "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], - }, - ] - collection_details = sagemaker_session.list_group_resources( - group=collection_name, filters=collection_filter - ) - assert collection_details["ResponseMetadata"]["HTTPStatusCode"] == 200 - delete_response = collection.delete([collection_name]) - assert len(delete_response["deleted_collections"]) == 1 - assert len(delete_response["delete_collection_failures"]) == 0 + try: + collection.create(collection_name) + collection_filter = [ + { + "Name": "resource-type", + "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], + }, + ] + collection_details = sagemaker_session.list_group_resources( + group=collection_name, filters=collection_filter + ) + assert collection_details["ResponseMetadata"]["HTTPStatusCode"] == 200 + finally: + delete_response = collection.delete([collection_name]) + assert len(delete_response["deleted_collections"]) == 1 + assert len(delete_response["delete_collection_failures"]) == 0 def test_create_collection_nested_success(sagemaker_session): @@ -41,25 +43,27 @@ def test_create_collection_nested_success(sagemaker_session): child_collection_name = unique_name_from_base("test-collection-2") collection.create(collection_name) collection.create(collection_name=child_collection_name, parent_collection_name=collection_name) - collection_filter = [ - { - "Name": "resource-type", - "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], - }, - ] - collection_details = sagemaker_session.list_group_resources( - group=collection_name, filters=collection_filter - ) - # has one child i.e child collection - assert len(collection_details["Resources"]) == 1 - - collection_details = sagemaker_session.list_group_resources( - group=child_collection_name, filters=collection_filter - ) - collection_details["ResponseMetadata"]["HTTPStatusCode"] - delete_response = collection.delete([child_collection_name, collection_name]) - assert len(delete_response["deleted_collections"]) == 2 - assert len(delete_response["delete_collection_failures"]) == 0 + try: + collection_filter = [ + { + "Name": "resource-type", + "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], + }, + ] + collection_details = sagemaker_session.list_group_resources( + group=collection_name, filters=collection_filter + ) + # has one child i.e child collection + assert len(collection_details["Resources"]) == 1 + + collection_details = sagemaker_session.list_group_resources( + group=child_collection_name, filters=collection_filter + ) + collection_details["ResponseMetadata"]["HTTPStatusCode"] + finally: + delete_response = collection.delete([child_collection_name, collection_name]) + assert len(delete_response["deleted_collections"]) == 2 + assert len(delete_response["delete_collection_failures"]) == 0 def test_add_remove_model_groups_in_collection_success(sagemaker_session): @@ -70,40 +74,42 @@ def test_add_remove_model_groups_in_collection_success(sagemaker_session): collection = Collection(sagemaker_session) collection_name = unique_name_from_base("test-collection") collection.create(collection_name) - model_groups = [] - model_groups.append(model_group_name) - add_response = collection.add_model_groups( - collection_name=collection_name, model_groups=model_groups - ) - collection_filter = [ - { - "Name": "resource-type", - "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], - }, - ] - collection_details = sagemaker_session.list_group_resources( - group=collection_name, filters=collection_filter - ) - - assert len(add_response["failure"]) == 0 - assert len(add_response["added_groups"]) == 1 - assert len(collection_details["Resources"]) == 1 - - remove_response = collection.remove_model_groups( - collection_name=collection_name, model_groups=model_groups - ) - collection_details = sagemaker_session.list_group_resources( - group=collection_name, filters=collection_filter - ) - assert len(remove_response["failure"]) == 0 - assert len(remove_response["removed_groups"]) == 1 - assert len(collection_details["Resources"]) == 0 - - delete_response = collection.delete([collection_name]) - assert len(delete_response["deleted_collections"]) == 1 - sagemaker_session.sagemaker_client.delete_model_package_group( - ModelPackageGroupName=model_group_name - ) + try: + model_groups = [] + model_groups.append(model_group_name) + add_response = collection.add_model_groups( + collection_name=collection_name, model_groups=model_groups + ) + collection_filter = [ + { + "Name": "resource-type", + "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], + }, + ] + collection_details = sagemaker_session.list_group_resources( + group=collection_name, filters=collection_filter + ) + + assert len(add_response["failure"]) == 0 + assert len(add_response["added_groups"]) == 1 + assert len(collection_details["Resources"]) == 1 + + remove_response = collection.remove_model_groups( + collection_name=collection_name, model_groups=model_groups + ) + collection_details = sagemaker_session.list_group_resources( + group=collection_name, filters=collection_filter + ) + assert len(remove_response["failure"]) == 0 + assert len(remove_response["removed_groups"]) == 1 + assert len(collection_details["Resources"]) == 0 + + finally: + delete_response = collection.delete([collection_name]) + assert len(delete_response["deleted_collections"]) == 1 + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_group_name + ) def test_move_model_groups_in_collection_success(sagemaker_session): @@ -116,56 +122,58 @@ def test_move_model_groups_in_collection_success(sagemaker_session): destination_collection_name = unique_name_from_base("test-collection-destination") collection.create(source_collection_name) collection.create(destination_collection_name) - model_groups = [] - model_groups.append(model_group_name) - add_response = collection.add_model_groups( - collection_name=source_collection_name, model_groups=model_groups - ) - collection_filter = [ - { - "Name": "resource-type", - "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], - }, - ] - collection_details = sagemaker_session.list_group_resources( - group=source_collection_name, filters=collection_filter - ) - - assert len(add_response["failure"]) == 0 - assert len(add_response["added_groups"]) == 1 - assert len(collection_details["Resources"]) == 1 - - move_response = collection.move_model_group( - source_collection_name=source_collection_name, - model_group=model_group_name, - destination_collection_name=destination_collection_name, - ) - - assert move_response["moved_success"] == model_group_name - - collection_details = sagemaker_session.list_group_resources( - group=destination_collection_name, filters=collection_filter - ) - - assert len(collection_details["Resources"]) == 1 - - collection_details = sagemaker_session.list_group_resources( - group=source_collection_name, filters=collection_filter - ) - assert len(collection_details["Resources"]) == 0 - - remove_response = collection.remove_model_groups( - collection_name=destination_collection_name, model_groups=model_groups - ) - - assert len(remove_response["failure"]) == 0 - assert len(remove_response["removed_groups"]) == 1 - - delete_response = collection.delete([source_collection_name, destination_collection_name]) - assert len(delete_response["deleted_collections"]) == 2 - sagemaker_session.sagemaker_client.delete_model_package_group( - ModelPackageGroupName=model_group_name - ) + try: + model_groups = [] + model_groups.append(model_group_name) + add_response = collection.add_model_groups( + collection_name=source_collection_name, model_groups=model_groups + ) + collection_filter = [ + { + "Name": "resource-type", + "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], + }, + ] + collection_details = sagemaker_session.list_group_resources( + group=source_collection_name, filters=collection_filter + ) + + assert len(add_response["failure"]) == 0 + assert len(add_response["added_groups"]) == 1 + assert len(collection_details["Resources"]) == 1 + + move_response = collection.move_model_group( + source_collection_name=source_collection_name, + model_group=model_group_name, + destination_collection_name=destination_collection_name, + ) + + assert move_response["moved_success"] == model_group_name + + collection_details = sagemaker_session.list_group_resources( + group=destination_collection_name, filters=collection_filter + ) + + assert len(collection_details["Resources"]) == 1 + + collection_details = sagemaker_session.list_group_resources( + group=source_collection_name, filters=collection_filter + ) + assert len(collection_details["Resources"]) == 0 + + remove_response = collection.remove_model_groups( + collection_name=destination_collection_name, model_groups=model_groups + ) + + assert len(remove_response["failure"]) == 0 + assert len(remove_response["removed_groups"]) == 1 + + finally: + delete_response = collection.delete([source_collection_name, destination_collection_name]) + assert len(delete_response["deleted_collections"]) == 2 + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_group_name + ) def test_list_collection_success(sagemaker_session): @@ -176,23 +184,27 @@ def test_list_collection_success(sagemaker_session): collection = Collection(sagemaker_session) collection_name = unique_name_from_base("test-collection") collection.create(collection_name) - model_groups = [] - model_groups.append(model_group_name) - collection.add_model_groups(collection_name=collection_name, model_groups=model_groups) - child_collection_name = unique_name_from_base("test-collection") - collection.create(parent_collection_name=collection_name, collection_name=child_collection_name) - root_collections = collection.list_collection() - is_collection_found = False - for root_collection in root_collections: - if root_collection["Name"] == collection_name: - is_collection_found = True - assert is_collection_found - - collection_content = collection.list_collection(collection_name) - assert len(collection_content) == 2 - - collection.remove_model_groups(collection_name=collection_name, model_groups=model_groups) - collection.delete([child_collection_name, collection_name]) - sagemaker_session.sagemaker_client.delete_model_package_group( - ModelPackageGroupName=model_group_name - ) + try: + model_groups = [] + model_groups.append(model_group_name) + collection.add_model_groups(collection_name=collection_name, model_groups=model_groups) + child_collection_name = unique_name_from_base("test-collection") + collection.create( + parent_collection_name=collection_name, collection_name=child_collection_name + ) + root_collections = collection.list_collection() + is_collection_found = False + for root_collection in root_collections: + if root_collection["Name"] == collection_name: + is_collection_found = True + assert is_collection_found + + collection_content = collection.list_collection(collection_name) + assert len(collection_content) == 2 + + collection.remove_model_groups(collection_name=collection_name, model_groups=model_groups) + finally: + collection.delete([child_collection_name, collection_name]) + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_group_name + )