@@ -143,10 +143,10 @@ def load(cls, artifact_arn: str, sagemaker_session=None) -> "Artifact":
143143 return artifact
144144
145145 def downstream_trials (self , sagemaker_session = None ) -> list :
146- """Retrieve all trial runs which that use this artifact.
146+ """Use the lineage API to retrieve all downstream trials that use this artifact.
147147
148148 Args:
149- sagemaker_session (obj): Sagemaker Sesssion to use. If not provided a default session
149+ sagemaker_session (obj): Sagemaker Session to use. If not provided a default session
150150 will be created.
151151
152152 Returns:
@@ -159,6 +159,54 @@ def downstream_trials(self, sagemaker_session=None) -> list:
159159 )
160160 trial_component_arns : list = list (map (lambda x : x .destination_arn , outgoing_associations ))
161161
162+ return self ._get_trial_from_trial_component (trial_component_arns )
163+
164+ def downstream_trials_v2 (self ) -> list :
165+ """Use a lineage query to retrieve all downstream trials that use this artifact.
166+
167+ Returns:
168+ [Trial]: A list of SageMaker `Trial` objects.
169+ """
170+ return self ._trials (direction = LineageQueryDirectionEnum .DESCENDANTS )
171+
172+ def upstream_trials (self ) -> List :
173+ """Use the lineage query to retrieve all upstream trials that use this artifact.
174+
175+ Returns:
176+ [Trial]: A list of SageMaker `Trial` objects.
177+ """
178+ return self ._trials (direction = LineageQueryDirectionEnum .ASCENDANTS )
179+
180+ def _trials (
181+ self , direction : LineageQueryDirectionEnum = LineageQueryDirectionEnum .BOTH
182+ ) -> List :
183+ """Use the lineage query to retrieve all trials that use this artifact.
184+
185+ Args:
186+ direction (LineageQueryDirectionEnum, optional): The query direction.
187+
188+ Returns:
189+ [Trial]: A list of SageMaker `Trial` objects.
190+ """
191+ query_filter = LineageFilter (entities = [LineageEntityEnum .TRIAL_COMPONENT ])
192+ query_result = LineageQuery (self .sagemaker_session ).query (
193+ start_arns = [self .artifact_arn ],
194+ query_filter = query_filter ,
195+ direction = direction ,
196+ include_edges = False ,
197+ )
198+ trial_component_arns : list = list (map (lambda x : x .arn , query_result .vertices ))
199+ return self ._get_trial_from_trial_component (trial_component_arns )
200+
201+ def _get_trial_from_trial_component (self , trial_component_arns : list ) -> List :
202+ """Retrieve all upstream trial runs which that use the trial component arns.
203+
204+ Args:
205+ trial_component_arns (list): list of trial component arns
206+
207+ Returns:
208+ [Trial]: A list of SageMaker `Trial` objects.
209+ """
162210 if not trial_component_arns :
163211 # no outgoing associations for this artifact
164212 return []
@@ -170,7 +218,7 @@ def downstream_trials(self, sagemaker_session=None) -> list:
170218 num_search_batches = math .ceil (len (trial_component_arns ) % max_search_by_arn )
171219 trial_components : list = []
172220
173- sagemaker_session = sagemaker_session or _utils .default_session ()
221+ sagemaker_session = self . sagemaker_session or _utils .default_session ()
174222 sagemaker_client = sagemaker_session .sagemaker_client
175223
176224 for i in range (num_search_batches ):
@@ -335,6 +383,17 @@ def list(
335383 sagemaker_session = sagemaker_session ,
336384 )
337385
386+ def s3_uri_artifacts (self , s3_uri : str ) -> dict :
387+ """Retrieve a list of artifacts that use provided s3 uri.
388+
389+ Args:
390+ s3_uri (str): A S3 URI.
391+
392+ Returns:
393+ A list of ``Artifacts``
394+ """
395+ return self .sagemaker_session .sagemaker_client .list_artifacts (SourceUri = s3_uri )
396+
338397
339398class ModelArtifact (Artifact ):
340399 """A SageMaker lineage artifact representing a model.
@@ -349,7 +408,7 @@ def endpoints(self) -> list:
349408 """Get association summaries for endpoints deployed with this model.
350409
351410 Returns:
352- [AssociationSummary]: A list of associations repesenting the endpoints using the model.
411+ [AssociationSummary]: A list of associations representing the endpoints using the model.
353412 """
354413 endpoint_development_actions : Iterator = Association .list (
355414 source_arn = self .artifact_arn ,
@@ -522,3 +581,69 @@ def endpoint_contexts(
522581 for vertex in query_result .vertices :
523582 endpoint_contexts .append (vertex .to_lineage_object ())
524583 return endpoint_contexts
584+
585+ def upstream_datasets (self ) -> List [Artifact ]:
586+ """Use the lineage query to retrieve upstream artifacts that use this dataset artifact.
587+
588+ Returns:
589+ list of Artifacts: Artifacts representing an dataset.
590+ """
591+ return self ._datasets (direction = LineageQueryDirectionEnum .ASCENDANTS )
592+
593+ def downstream_datasets (self ) -> List [Artifact ]:
594+ """Use the lineage query to retrieve downstream artifacts that use this dataset.
595+
596+ Returns:
597+ list of Artifacts: Artifacts representing an dataset.
598+ """
599+ return self ._datasets (direction = LineageQueryDirectionEnum .DESCENDANTS )
600+
601+ def _datasets (
602+ self , direction : LineageQueryDirectionEnum = LineageQueryDirectionEnum .BOTH
603+ ) -> List [Artifact ]:
604+ """Use the lineage query to retrieve all artifacts that use this dataset.
605+
606+ Args:
607+ direction (LineageQueryDirectionEnum, optional): The query direction.
608+
609+ Returns:
610+ list of Artifacts: Artifacts representing an dataset.
611+ """
612+ query_filter = LineageFilter (
613+ entities = [LineageEntityEnum .ARTIFACT ], sources = [LineageSourceEnum .DATASET ]
614+ )
615+ query_result = LineageQuery (self .sagemaker_session ).query (
616+ start_arns = [self .artifact_arn ],
617+ query_filter = query_filter ,
618+ direction = direction ,
619+ include_edges = False ,
620+ )
621+ return [vertex .to_lineage_object () for vertex in query_result .vertices ]
622+
623+
624+ class ImageArtifact (Artifact ):
625+ """A SageMaker lineage artifact representing an image.
626+
627+ Common model specific lineage traversals to discover how the image is connected
628+ to other entities.
629+ """
630+
631+ def datasets (self , direction : LineageQueryDirectionEnum ) -> List [Artifact ]:
632+ """Use the lineage query to retrieve datasets that use this image artifact.
633+
634+ Args:
635+ direction (LineageQueryDirectionEnum): The query direction.
636+
637+ Returns:
638+ list of Artifacts: Artifacts representing a dataset.
639+ """
640+ query_filter = LineageFilter (
641+ entities = [LineageEntityEnum .ARTIFACT ], sources = [LineageSourceEnum .DATASET ]
642+ )
643+ query_result = LineageQuery (self .sagemaker_session ).query (
644+ start_arns = [self .artifact_arn ],
645+ query_filter = query_filter ,
646+ direction = direction ,
647+ include_edges = False ,
648+ )
649+ return [vertex .to_lineage_object () for vertex in query_result .vertices ]
0 commit comments