1111# ANY KIND, either express or implied. See the License for the specific
1212# language governing permissions and limitations under the License.
1313
14- """This module contains code related to Amazon SageMaker Collection .
14+ """This module contains code related to Amazon SageMaker Collections in the Model Registry .
1515
16- These Classes helps in providing features to maintain and create collections
16+ Use these methods to help you create and maintain your Collections.
1717"""
1818
1919from __future__ import absolute_import
2727
2828
2929class Collection (object ):
30- """Sets up Amazon SageMaker Collection."""
30+ """Sets up an Amazon SageMaker Collection."""
3131
3232 def __init__ (self , sagemaker_session ):
3333 """Initializes a Collection instance.
3434
35- The collection provides a logical grouping for model groups
35+ A Collection is a logical grouping of Model Groups.
3636
3737 Args:
38- sagemaker_session (sagemaker.session.Session): Session object which
39- manages interactions with Amazon SageMaker APIs and any other
40- AWS services needed. If not specified, one is created using
38+ sagemaker_session (sagemaker.session.Session): A Session object which
39+ manages interactions between Amazon SageMaker APIs and other
40+ AWS services needed. If unspecified, a session is created using
4141 the default AWS configuration chain.
4242 """
43+
4344 self .sagemaker_session = sagemaker_session or Session ()
4445
4546 def _check_access_error (self , err : ClientError ):
46- """To check if the error is related to the access error and to provide the relavant message
47+ """Checks if the error is related to the access error and provide the relevant message.
4748
4849 Args:
49- err: The client error that needs to be checked
50+ err: The client error to check.
5051 """
52+
5153 error_code = err .response ["Error" ]["Code" ]
5254 if error_code == "AccessDeniedException" :
5355 raise Exception (
@@ -57,12 +59,12 @@ def _check_access_error(self, err: ClientError):
5759 )
5860
5961 def _add_model_group (self , model_package_group , tag_rule_key , tag_rule_value ):
60- """To add a model package group to a collection
62+ """Adds a Model Group to a Collection.
6163
6264 Args:
63- model_package_group (str): The name of the model package group
64- tag_rule_key (str): The tag key of the corresponing collection to be added into
65- tag_rule_value (str): The tag value of the corresponing collection to be added into
65+ model_package_group (str): The name of the Model Group.
66+ tag_rule_key (str): The tag key of the destination collection.
67+ tag_rule_value (str): The tag value of the destination collection.
6668 """
6769 model_group_details = self .sagemaker_session .sagemaker_client .describe_model_package_group (
6870 ModelPackageGroupName = model_package_group
@@ -78,11 +80,11 @@ def _add_model_group(self, model_package_group, tag_rule_key, tag_rule_value):
7880 )
7981
8082 def _remove_model_group (self , model_package_group , tag_rule_key ):
81- """To remove a model package group from a collection
83+ """Removes a Model Group from a Collection.
8284
8385 Args:
84- model_package_group (str): The name of the model package group
85- tag_rule_key (str): The tag key of the corresponing collection to be removed from
86+ model_package_group (str): The name of the Model Group
87+ tag_rule_key (str): The tag key of the Collection from which to remove the Model Group.
8688 """
8789 model_group_details = self .sagemaker_session .sagemaker_client .describe_model_package_group (
8890 ModelPackageGroupName = model_package_group
@@ -92,12 +94,12 @@ def _remove_model_group(self, model_package_group, tag_rule_key):
9294 )
9395
9496 def create (self , collection_name : str , parent_collection_name : str = None ):
95- """Creates a collection
97+ """Creates a Collection.
9698
9799 Args:
98- collection_name (str): The name of the collection to be created
99- parent_collection_name (str): The name of the parent collection .
100- To be None if the collection is to be created on the root level
100+ collection_name (str): The name of the Collection to create.
101+ parent_collection_name (str): The name of the parent Collection .
102+ Is `` None`` if the Collection is created at the root level.
101103 """
102104
103105 tag_rule_key = f"sagemaker:collection-path:{ int (time .time () * 1000 )} "
@@ -151,11 +153,11 @@ def create(self, collection_name: str, parent_collection_name: str = None):
151153 raise
152154
153155 def delete (self , collections : List [str ]):
154- """Deletes a list of collection .
156+ """Deletes a list of Collections .
155157
156158 Args:
157- collections (List[str]): List of collections to be deleted
158- Only deletes a collection if it is empty
159+ collections (List[str]): A list of Collections to delete.
160+ Only deletes a Collection if it is empty.
159161 """
160162
161163 if len (collections ) > 10 :
@@ -201,7 +203,7 @@ def delete(self, collections: List[str]):
201203 }
202204
203205 def _get_collection_tag_rule (self , collection_name : str ):
204- """Returns the tag rule key and value for a collection """
206+ """Returns the tag rule key and value for a Collection. """
205207
206208 if collection_name is not None :
207209 try :
@@ -230,11 +232,11 @@ def _get_collection_tag_rule(self, collection_name: str):
230232 raise ValueError ("Collection name is required" )
231233
232234 def add_model_groups (self , collection_name : str , model_groups : List [str ]):
233- """To add list of model package groups to a collection
235+ """Adds a list of Model Groups to a Collection.
234236
235237 Args:
236- collection_name (str): The name of the collection
237- model_groups List[str]: Model pckage group names list to be added into the collection
238+ collection_name (str): The name of the Collection.
239+ model_groups ( List[str]): The names of the Model Groups to add to the Collection.
238240 """
239241 if len (model_groups ) > 10 :
240242 raise Exception ("Model groups can have a maximum length of 10" )
@@ -268,11 +270,11 @@ def add_model_groups(self, collection_name: str, model_groups: List[str]):
268270 }
269271
270272 def remove_model_groups (self , collection_name : str , model_groups : List [str ]):
271- """To remove list of model package groups from a collection
273+ """Removes a list of Model Groups from a Collection.
272274
273275 Args:
274- collection_name (str): The name of the collection
275- model_groups List[str]: Model package group names list to be removed
276+ collection_name (str): The name of the Collection.
277+ model_groups ( List[str]): The names of the Model Groups to remove.
276278 """
277279
278280 if len (model_groups ) > 10 :
@@ -309,12 +311,12 @@ def remove_model_groups(self, collection_name: str, model_groups: List[str]):
309311 def move_model_group (
310312 self , source_collection_name : str , model_group : str , destination_collection_name : str
311313 ):
312- """To move a model package group from one collection to another
314+ """Moves a Model Group from one Collection to another.
313315
314316 Args:
315- source_collection_name (str): Collection name of the source
316- model_group (str): Model package group names which is to be moved
317- destination_collection_name (str): Collection name of the destination
317+ source_collection_name (str): The name of the source Collection.
318+ model_group (str): The name of the Model Group to move.
319+ destination_collection_name (str): The name of the destination Collection.
318320 """
319321 remove_details = self .remove_model_groups (
320322 collection_name = source_collection_name , model_groups = [model_group ]
@@ -327,7 +329,7 @@ def move_model_group(
327329 )
328330
329331 if len (added_details ["failure" ]) == 1 :
330- # adding the model group back to the source collection in case of an add failure
332+ # adding the Model Group back to the source collection in case of an add failure
331333 self .add_model_groups (
332334 collection_name = source_collection_name , model_groups = [model_group ]
333335 )
@@ -338,10 +340,10 @@ def move_model_group(
338340 }
339341
340342 def _convert_tag_collection_response (self , tag_collections : List [str ]):
341- """Converts collection response from tag api to collection list response
343+ """Converts a Collection response from the tag api to a Collection list response.
342344
343345 Args:
344- tag_collections List[dict]: Collections list response from tag api
346+ tag_collections List[dict]: The Collection list response from the tag api.
345347 """
346348 collection_details = []
347349 for collection in tag_collections :
@@ -359,11 +361,12 @@ def _convert_tag_collection_response(self, tag_collections: List[str]):
359361 def _convert_group_resource_response (
360362 self , group_resource_details : List [dict ], is_model_group : bool = False
361363 ):
362- """Converts collection response from resource group api to collection list response
364+ """Converts a Collection response from the resource group api to a Collection list response.
363365
364366 Args:
365- group_resource_details (List[dict]): Collections list response from resource group api
366- is_model_group (bool): If the reponse is of collection or model group type
367+ group_resource_details (List[dict]): The Collection list response from the
368+ resource group api.
369+ is_model_group (bool): Indicates if the response is of Collection or Model Group type.
367370 """
368371 collection_details = []
369372 if group_resource_details ["Resources" ]:
@@ -382,12 +385,11 @@ def _convert_group_resource_response(
382385 return collection_details
383386
384387 def _get_full_list_resource (self , collection_name , collection_filter ):
385- """Iterating to the full resource group list and returns appended paginated response
388+ """Iterates the full resource group list and returns the appended paginated response.
386389
387390 Args:
388- collection_name (str): Name of the collection to get the details
389- collection_filter (dict): Filter details to be passed to get the resource list
390-
391+ collection_name (str): The name of the Collection from which to get details.
392+ collection_filter (dict): Filter details to pass to get the resource list.
391393 """
392394 list_group_response = self .sagemaker_session .list_group_resources (
393395 group = collection_name , filters = collection_filter
@@ -412,12 +414,13 @@ def _get_full_list_resource(self, collection_name, collection_filter):
412414 return list_group_response
413415
414416 def list_collection (self , collection_name : str = None ):
415- """To all list the collections and content of the collections
417+ """Lists the contents of the specified Collection.
416418
417- In case there is no collection_name, it lists all the collections on the root level
419+ If there is no Collection with the name ``collection_name``, lists all the
420+ Collections at the root level.
418421
419422 Args:
420- collection_name (str): The name of the collection to list the contents of
423+ collection_name (str): The name of the Collection whose contents are listed.
421424 """
422425 collection_content = []
423426 if collection_name is None :
0 commit comments