1717import math
1818
1919from datetime import datetime
20- from typing import Iterator , Union , Any , Optional
20+ from typing import Iterator , Union , Any , Optional , List
2121
2222from sagemaker .apiutils import _base_types , _utils
2323from sagemaker .lineage import _api_types
2424from sagemaker .lineage ._api_types import ArtifactSource , ArtifactSummary
25- from sagemaker .lineage ._utils import get_module , _disassociate
25+ from sagemaker .lineage .query import (
26+ LineageQuery ,
27+ LineageFilter ,
28+ LineageSourceEnum ,
29+ LineageEntityEnum ,
30+ LineageQueryDirectionEnum ,
31+ )
32+ from sagemaker .lineage ._utils import get_module , _disassociate , get_resource_name_from_arn
2633from sagemaker .lineage .association import Association
2734
2835LOGGER = logging .getLogger ("sagemaker" )
@@ -333,11 +340,13 @@ class ModelArtifact(Artifact):
333340 """A SageMaker lineage artifact representing a model.
334341
335342 Common model specific lineage traversals to discover how the model is connected
336- to otherentities .
343+ to other entities .
337344 """
338345
346+ from sagemaker .lineage .context import Context
347+
339348 def endpoints (self ) -> list :
340- """Given a model artifact, get all associated endpoint context .
349+ """Get association summaries for endpoints deployed with this model .
341350
342351 Returns:
343352 [AssociationSummary]: A list of associations repesenting the endpoints using the model.
@@ -359,6 +368,104 @@ def endpoints(self) -> list:
359368 ]
360369 return endpoint_context_list
361370
371+ def endpoint_contexts (
372+ self , direction : LineageQueryDirectionEnum = LineageQueryDirectionEnum .DESCENDANTS
373+ ) -> List [Context ]:
374+ """Get contexts representing endpoints from the models's lineage.
375+
376+ Args:
377+ direction (LineageQueryDirectionEnum, optional): The query direction.
378+
379+ Returns:
380+ list of Contexts: Contexts representing an endpoint.
381+ """
382+ query_filter = LineageFilter (
383+ entities = [LineageEntityEnum .CONTEXT ], sources = [LineageSourceEnum .ENDPOINT ]
384+ )
385+ query_result = LineageQuery (self .sagemaker_session ).query (
386+ start_arns = [self .artifact_arn ],
387+ query_filter = query_filter ,
388+ direction = direction ,
389+ include_edges = False ,
390+ )
391+
392+ endpoint_contexts = []
393+ for vertex in query_result .vertices :
394+ endpoint_contexts .append (vertex .to_lineage_object ())
395+ return endpoint_contexts
396+
397+ def dataset_artifacts (
398+ self , direction : LineageQueryDirectionEnum = LineageQueryDirectionEnum .ASCENDANTS
399+ ) -> List [Artifact ]:
400+ """Get artifacts representing datasets from the model's lineage.
401+
402+ Args:
403+ direction (LineageQueryDirectionEnum, optional): The query direction.
404+
405+ Returns:
406+ list of Artifacts: Artifacts representing a dataset.
407+ """
408+ query_filter = LineageFilter (
409+ entities = [LineageEntityEnum .ARTIFACT ], sources = [LineageSourceEnum .DATASET ]
410+ )
411+ query_result = LineageQuery (self .sagemaker_session ).query (
412+ start_arns = [self .artifact_arn ],
413+ query_filter = query_filter ,
414+ direction = direction ,
415+ include_edges = False ,
416+ )
417+
418+ dataset_artifacts = []
419+ for vertex in query_result .vertices :
420+ dataset_artifacts .append (vertex .to_lineage_object ())
421+ return dataset_artifacts
422+
423+ def training_job_arns (
424+ self , direction : LineageQueryDirectionEnum = LineageQueryDirectionEnum .ASCENDANTS
425+ ) -> List [str ]:
426+ """Get ARNs for all training jobs that appear in the model's lineage.
427+
428+ Returns:
429+ list of str: Training job ARNs.
430+ """
431+ query_filter = LineageFilter (
432+ entities = [LineageEntityEnum .TRIAL_COMPONENT ], sources = [LineageSourceEnum .TRAINING_JOB ]
433+ )
434+ query_result = LineageQuery (self .sagemaker_session ).query (
435+ start_arns = [self .artifact_arn ],
436+ query_filter = query_filter ,
437+ direction = direction ,
438+ include_edges = False ,
439+ )
440+
441+ training_job_arns = []
442+ for vertex in query_result .vertices :
443+ trial_component_name = get_resource_name_from_arn (vertex .arn )
444+ trial_component = self .sagemaker_session .sagemaker_client .describe_trial_component (
445+ TrialComponentName = trial_component_name
446+ )
447+ training_job_arns .append (trial_component ["Source" ]["SourceArn" ])
448+ return training_job_arns
449+
450+ def pipeline_execution_arn (
451+ self , direction : LineageQueryDirectionEnum = LineageQueryDirectionEnum .ASCENDANTS
452+ ) -> str :
453+ """Get the ARN for the pipeline execution associated with this model (if any).
454+
455+ Returns:
456+ str: A pipeline execution ARN.
457+ """
458+ training_job_arns = self .training_job_arns (direction = direction )
459+ for training_job_arn in training_job_arns :
460+ tags = self .sagemaker_session .sagemaker_client .list_tags (ResourceArn = training_job_arn )[
461+ "Tags"
462+ ]
463+ for tag in tags :
464+ if tag ["Key" ] == "sagemaker:pipeline-execution-arn" :
465+ return tag ["Value" ]
466+
467+ return None
468+
362469
363470class DatasetArtifact (Artifact ):
364471 """A SageMaker Lineage artifact representing a dataset.
@@ -367,7 +474,9 @@ class DatasetArtifact(Artifact):
367474 connect to related entities.
368475 """
369476
370- def trained_models (self ) -> list :
477+ from sagemaker .lineage .context import Context
478+
479+ def trained_models (self ) -> List [Association ]:
371480 """Given a dataset artifact, get associated trained models.
372481
373482 Returns:
@@ -387,3 +496,29 @@ def trained_models(self) -> list:
387496 result .extend (models )
388497
389498 return result
499+
500+ def endpoint_contexts (
501+ self , direction : LineageQueryDirectionEnum = LineageQueryDirectionEnum .DESCENDANTS
502+ ) -> List [Context ]:
503+ """Get contexts representing endpoints from the dataset's lineage.
504+
505+ Args:
506+ direction (LineageQueryDirectionEnum, optional): The query direction.
507+
508+ Returns:
509+ list of Contexts: Contexts representing an endpoint.
510+ """
511+ query_filter = LineageFilter (
512+ entities = [LineageEntityEnum .CONTEXT ], sources = [LineageSourceEnum .ENDPOINT ]
513+ )
514+ query_result = LineageQuery (self .sagemaker_session ).query (
515+ start_arns = [self .artifact_arn ],
516+ query_filter = query_filter ,
517+ direction = direction ,
518+ include_edges = False ,
519+ )
520+
521+ endpoint_contexts = []
522+ for vertex in query_result .vertices :
523+ endpoint_contexts .append (vertex .to_lineage_object ())
524+ return endpoint_contexts
0 commit comments