@@ -56,6 +56,41 @@ def _check_access_error(self, err: ClientError):
5656 "https://docs.aws.amazon.com/sagemaker/latest/dg/modelcollections-permissions.html"
5757 )
5858
59+ def _add_model_group (self , model_package_group , tag_rule_key , tag_rule_value ):
60+ """To add a model package group to a collection
61+
62+ Args:
63+ model_package_group (str): The name of the model package group
64+ tag_rule_key (str): The tag key of the corresponing collection to be added into
65+ tag_rule_value (str): The tag value of the corresponing collection to be added into
66+ """
67+ model_group_details = self .sagemaker_session .sagemaker_client .describe_model_package_group (
68+ ModelPackageGroupName = model_package_group
69+ )
70+ self .sagemaker_session .sagemaker_client .add_tags (
71+ ResourceArn = model_group_details ["ModelPackageGroupArn" ],
72+ Tags = [
73+ {
74+ "Key" : tag_rule_key ,
75+ "Value" : tag_rule_value ,
76+ }
77+ ],
78+ )
79+
80+ def _remove_model_group (self , model_package_group , tag_rule_key ):
81+ """To remove a model package group from a collection
82+
83+ Args:
84+ model_package_group (str): The name of the model package group
85+ tag_rule_key (str): The tag key of the corresponing collection to be removed from
86+ """
87+ model_group_details = self .sagemaker_session .sagemaker_client .describe_model_package_group (
88+ ModelPackageGroupName = model_package_group
89+ )
90+ self .sagemaker_session .sagemaker_client .delete_tags (
91+ ResourceArn = model_group_details ["ModelPackageGroupArn" ], TagKeys = [tag_rule_key ]
92+ )
93+
5994 def create (self , collection_name : str , parent_collection_name : str = None ):
6095 """Creates a collection
6196
@@ -65,38 +100,22 @@ def create(self, collection_name: str, parent_collection_name: str = None):
65100 To be None if the collection is to be created on the root level
66101 """
67102
68- tag_rule_key = f"sagemaker:collection-path:{ time .time ()} "
103+ tag_rule_key = f"sagemaker:collection-path:{ int ( time .time () * 1000 )} "
69104 tags_on_collection = {
70105 "sagemaker:collection" : "true" ,
71106 "sagemaker:collection-path:root" : "true" ,
72107 }
73108 tag_rule_values = [collection_name ]
74109
75110 if parent_collection_name is not None :
76- try :
77- group_query = self .sagemaker_session .get_resource_group_query (
78- group = parent_collection_name
79- )
80- except ClientError as e :
81- error_code = e .response ["Error" ]["Code" ]
82-
83- if error_code == "NotFoundException" :
84- raise ValueError (f"Cannot find collection: { parent_collection_name } " )
85- self ._check_access_error (err = e )
86- raise
87- if group_query .get ("GroupQuery" ):
88- parent_tag_rule_query = json .loads (
89- group_query ["GroupQuery" ].get ("ResourceQuery" , {}).get ("Query" , "" )
90- )
91- parent_tag_rule = parent_tag_rule_query .get ("TagFilters" , [])[0 ]
92- if not parent_tag_rule :
93- raise "Invalid parent_collection_name"
94- parent_tag_value = parent_tag_rule ["Values" ][0 ]
95- tags_on_collection = {
96- parent_tag_rule ["Key" ]: parent_tag_value ,
97- "sagemaker:collection" : "true" ,
98- }
99- tag_rule_values = [f"{ parent_tag_value } /{ collection_name } " ]
111+ parent_tag_rules = self ._get_collection_tag_rule (collection_name = parent_collection_name )
112+ parent_tag_rule_key = parent_tag_rules ["tag_rule_key" ]
113+ parent_tag_value = parent_tag_rules ["tag_rule_value" ]
114+ tags_on_collection = {
115+ parent_tag_rule_key : parent_tag_value ,
116+ "sagemaker:collection" : "true" ,
117+ }
118+ tag_rule_values = [f"{ parent_tag_value } /{ collection_name } " ]
100119 try :
101120 resource_filters = [
102121 "AWS::SageMaker::ModelPackageGroup" ,
@@ -122,19 +141,17 @@ def create(self, collection_name: str, parent_collection_name: str = None):
122141 "Name" : collection_create_response ["Group" ]["Name" ],
123142 "Arn" : collection_create_response ["Group" ]["GroupArn" ],
124143 }
125-
126144 except ClientError as e :
127145 message = e .response ["Error" ]["Message" ]
128146 error_code = e .response ["Error" ]["Code" ]
129147
130148 if error_code == "BadRequestException" and "group already exists" in message :
131149 raise ValueError ("Collection with the given name already exists" )
132-
133150 self ._check_access_error (err = e )
134151 raise
135152
136153 def delete (self , collections : List [str ]):
137- """Deletes a lits of collection
154+ """Deletes a list of collection.
138155
139156 Args:
140157 collections (List[str]): List of collections to be deleted
@@ -152,6 +169,8 @@ def delete(self, collections: List[str]):
152169 "Values" : ["AWS::ResourceGroups::Group" , "AWS::SageMaker::ModelPackageGroup" ],
153170 },
154171 ]
172+
173+ # loops over the list of collection and deletes one at a time.
155174 for collection in collections :
156175 try :
157176 collection_details = self .sagemaker_session .list_group_resources (
@@ -180,3 +199,264 @@ def delete(self, collections: List[str]):
180199 "deleted_collections" : deleted_collection ,
181200 "delete_collection_failures" : delete_collection_failures ,
182201 }
202+
203+ def _get_collection_tag_rule (self , collection_name : str ):
204+ """Returns the tag rule key and value for a collection"""
205+
206+ if collection_name is not None :
207+ try :
208+ group_query = self .sagemaker_session .get_resource_group_query (group = collection_name )
209+ except ClientError as e :
210+ error_code = e .response ["Error" ]["Code" ]
211+
212+ if error_code == "NotFoundException" :
213+ raise ValueError (f"Cannot find collection: { collection_name } " )
214+ self ._check_access_error (err = e )
215+ raise
216+ if group_query .get ("GroupQuery" ):
217+ tag_rule_query = json .loads (
218+ group_query ["GroupQuery" ].get ("ResourceQuery" , {}).get ("Query" , "" )
219+ )
220+ tag_rule = tag_rule_query .get ("TagFilters" , [])[0 ]
221+ if not tag_rule :
222+ raise "Unsupported parent_collection_name"
223+ tag_rule_value = tag_rule ["Values" ][0 ]
224+ tag_rule_key = tag_rule ["Key" ]
225+
226+ return {
227+ "tag_rule_key" : tag_rule_key ,
228+ "tag_rule_value" : tag_rule_value ,
229+ }
230+ raise ValueError ("Collection name is required" )
231+
232+ def add_model_groups (self , collection_name : str , model_groups : List [str ]):
233+ """To add list of model package groups to a collection
234+
235+ Args:
236+ collection_name (str): The name of the collection
237+ model_groups List[str]: Model pckage group names list to be added into the collection
238+ """
239+ if len (model_groups ) > 10 :
240+ raise Exception ("Model groups can have a maximum length of 10" )
241+ tag_rules = self ._get_collection_tag_rule (collection_name = collection_name )
242+ tag_rule_key = tag_rules ["tag_rule_key" ]
243+ tag_rule_value = tag_rules ["tag_rule_value" ]
244+
245+ add_groups_success = []
246+ add_groups_failure = []
247+ if tag_rule_key is not None and tag_rule_value is not None :
248+ for model_group in model_groups :
249+ try :
250+ self ._add_model_group (
251+ model_package_group = model_group ,
252+ tag_rule_key = tag_rule_key ,
253+ tag_rule_value = tag_rule_value ,
254+ )
255+ add_groups_success .append (model_group )
256+ except ClientError as e :
257+ self ._check_access_error (err = e )
258+ message = e .response ["Error" ]["Message" ]
259+ add_groups_failure .append (
260+ {
261+ "model_group" : model_group ,
262+ "failure_reason" : message ,
263+ }
264+ )
265+ return {
266+ "added_groups" : add_groups_success ,
267+ "failure" : add_groups_failure ,
268+ }
269+
270+ def remove_model_groups (self , collection_name : str , model_groups : List [str ]):
271+ """To remove list of model package groups from a collection
272+
273+ Args:
274+ collection_name (str): The name of the collection
275+ model_groups List[str]: Model package group names list to be removed
276+ """
277+
278+ if len (model_groups ) > 10 :
279+ raise Exception ("Model groups can have a maximum length of 10" )
280+ tag_rules = self ._get_collection_tag_rule (collection_name = collection_name )
281+
282+ tag_rule_key = tag_rules ["tag_rule_key" ]
283+ tag_rule_value = tag_rules ["tag_rule_value" ]
284+
285+ remove_groups_success = []
286+ remove_groups_failure = []
287+ if tag_rule_key is not None and tag_rule_value is not None :
288+ for model_group in model_groups :
289+ try :
290+ self ._remove_model_group (
291+ model_package_group = model_group ,
292+ tag_rule_key = tag_rule_key ,
293+ )
294+ remove_groups_success .append (model_group )
295+ except ClientError as e :
296+ self ._check_access_error (err = e )
297+ message = e .response ["Error" ]["Message" ]
298+ remove_groups_failure .append (
299+ {
300+ "model_group" : model_group ,
301+ "failure_reason" : message ,
302+ }
303+ )
304+ return {
305+ "removed_groups" : remove_groups_success ,
306+ "failure" : remove_groups_failure ,
307+ }
308+
309+ def move_model_group (
310+ self , source_collection_name : str , model_group : str , destination_collection_name : str
311+ ):
312+ """To move a model package group from one collection to another
313+
314+ Args:
315+ source_collection_name (str): Collection name of the source
316+ model_group (str): Model package group names which is to be moved
317+ destination_collection_name (str): Collection name of the destination
318+ """
319+ remove_details = self .remove_model_groups (
320+ collection_name = source_collection_name , model_groups = [model_group ]
321+ )
322+ if len (remove_details ["failure" ]) == 1 :
323+ raise Exception (remove_details ["failure" ][0 ]["failure" ])
324+
325+ added_details = self .add_model_groups (
326+ collection_name = destination_collection_name , model_groups = [model_group ]
327+ )
328+
329+ if len (added_details ["failure" ]) == 1 :
330+ # adding the model group back to the source collection in case of an add failure
331+ self .add_model_groups (
332+ collection_name = source_collection_name , model_groups = [model_group ]
333+ )
334+ raise Exception (added_details ["failure" ][0 ]["failure" ])
335+
336+ return {
337+ "moved_success" : model_group ,
338+ }
339+
340+ def _convert_tag_collection_response (self , tag_collections : List [str ]):
341+ """Converts collection response from tag api to collection list response
342+
343+ Args:
344+ tag_collections List[dict]: Collections list response from tag api
345+ """
346+ collection_details = []
347+ for collection in tag_collections :
348+ collection_arn = collection ["ResourceARN" ]
349+ collection_name = collection_arn .split ("group/" )[1 ]
350+ collection_details .append (
351+ {
352+ "Name" : collection_name ,
353+ "Arn" : collection_arn ,
354+ "Type" : "Collection" ,
355+ }
356+ )
357+ return collection_details
358+
359+ def _convert_group_resource_response (
360+ self , group_resource_details : List [dict ], is_model_group : bool = False
361+ ):
362+ """Converts collection response from resource group api to collection list response
363+
364+ Args:
365+ group_resource_details (List[dict]): Collections list response from resource group api
366+ is_model_group (bool): If the reponse is of collection or model group type
367+ """
368+ collection_details = []
369+ if group_resource_details ["Resources" ]:
370+ for resource_group in group_resource_details ["Resources" ]:
371+ collection_arn = resource_group ["Identifier" ]["ResourceArn" ]
372+ collection_name = collection_arn .split ("group/" )[1 ]
373+ collection_details .append (
374+ {
375+ "Name" : collection_name ,
376+ "Arn" : collection_arn ,
377+ "Type" : resource_group ["Identifier" ]["ResourceType" ]
378+ if is_model_group
379+ else "Collection" ,
380+ }
381+ )
382+ return collection_details
383+
384+ def _get_full_list_resource (self , collection_name , collection_filter ):
385+ """Iterating to the full resource group list and returns appended paginated response
386+
387+ Args:
388+ collection_name (str): Name of the collection to get the details
389+ collection_filter (dict): Filter details to be passed to get the resource list
390+
391+ """
392+ list_group_response = self .sagemaker_session .list_group_resources (
393+ group = collection_name , filters = collection_filter
394+ )
395+ next_token = list_group_response .get ("NextToken" )
396+ while next_token is not None :
397+
398+ paginated_group_response = self .sagemaker_session .list_group_resources (
399+ group = collection_name ,
400+ filters = collection_filter ,
401+ next_token = next_token ,
402+ )
403+ list_group_response ["Resources" ] = (
404+ list_group_response ["Resources" ] + paginated_group_response ["Resources" ]
405+ )
406+ list_group_response ["ResourceIdentifiers" ] = (
407+ list_group_response ["ResourceIdentifiers" ]
408+ + paginated_group_response ["ResourceIdentifiers" ]
409+ )
410+ next_token = paginated_group_response .get ("NextToken" )
411+
412+ return list_group_response
413+
414+ def list_collection (self , collection_name : str = None ):
415+ """To all list the collections and content of the collections
416+
417+ In case there is no collection_name, it lists all the collections on the root level
418+
419+ Args:
420+ collection_name (str): The name of the collection to list the contents of
421+ """
422+ collection_content = []
423+ if collection_name is None :
424+ tag_filters = [
425+ {
426+ "Key" : "sagemaker:collection-path:root" ,
427+ "Values" : ["true" ],
428+ },
429+ ]
430+ resource_type_filters = ["resource-groups:group" ]
431+ tag_collections = self .sagemaker_session .get_tagging_resources (
432+ tag_filters = tag_filters , resource_type_filters = resource_type_filters
433+ )
434+
435+ return self ._convert_tag_collection_response (tag_collections )
436+
437+ collection_filter = [
438+ {
439+ "Name" : "resource-type" ,
440+ "Values" : ["AWS::ResourceGroups::Group" ],
441+ },
442+ ]
443+ list_group_response = self ._get_full_list_resource (
444+ collection_name = collection_name , collection_filter = collection_filter
445+ )
446+ collection_content = self ._convert_group_resource_response (list_group_response )
447+
448+ collection_filter = [
449+ {
450+ "Name" : "resource-type" ,
451+ "Values" : ["AWS::SageMaker::ModelPackageGroup" ],
452+ },
453+ ]
454+ list_group_response = self ._get_full_list_resource (
455+ collection_name = collection_name , collection_filter = collection_filter
456+ )
457+
458+ collection_content = collection_content + self ._convert_group_resource_response (
459+ list_group_response , True
460+ )
461+
462+ return collection_content
0 commit comments