Skip to content

Commit b26bf2f

Browse files
Fix executor handling for collected list artifacts (#12707)
* Fix executor handling for collected list artifacts Signed-off-by: sduvvuri1603 <[email protected]> * Format executor loop per yapf Signed-off-by: sduvvuri1603 <[email protected]> * chore: normalize executor string quotes Signed-off-by: sduvvuri1603 <[email protected]> * refactor: dedupe list artifact resolver Signed-off-by: sduvvuri1603 <[email protected]> Co-authored-by: Cursor <[email protected]> * chore: docformat executor helper docstring Signed-off-by: sduvvuri1603 <[email protected]> Co-authored-by: Cursor <[email protected]> --------- Signed-off-by: sduvvuri1603 <[email protected]> Co-authored-by: Cursor <[email protected]>
1 parent 77b319f commit b26bf2f

File tree

2 files changed

+89
-4
lines changed

2 files changed

+89
-4
lines changed

sdk/python/kfp/dsl/executor.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,8 @@ def assign_input_and_output_artifacts(self) -> None:
7171
type_annotations.is_list_of_artifacts(annotation.__origin__)
7272
) or type_annotations.is_list_of_artifacts(annotation)
7373
if is_list_of_artifacts:
74-
# Get the annotation of the inner type of the list
75-
# to use when creating the artifacts
76-
inner_annotation = type_annotations.get_inner_type(
77-
annotation)
74+
inner_annotation = self._resolve_list_artifact_inner_type(
75+
input_name=name, list_annotation=annotation)
7876

7977
self.input_artifacts[name] = [
8078
self.make_artifact(
@@ -104,6 +102,41 @@ def assign_input_and_output_artifacts(self) -> None:
104102
self.output_artifacts[name] = output_artifact
105103
makedirs_recursively(output_artifact.path)
106104

105+
def _resolve_list_artifact_inner_type(self, input_name: str,
106+
list_annotation: Any) -> Any:
107+
"""Unwraps nested list annotations and validates the inner artifact
108+
type."""
109+
inner_annotation = type_annotations.get_inner_type(list_annotation)
110+
111+
if inner_annotation is None:
112+
raise TypeError(
113+
f"Input '{input_name}' expects a list of artifacts, but "
114+
f'received {list_annotation!r} without an inner type.')
115+
116+
while type_annotations.is_list_of_artifacts(inner_annotation):
117+
if isinstance(inner_annotation, tuple):
118+
raise TypeError(
119+
f"Input '{input_name}' expects a single artifact type but "
120+
f'received a union {inner_annotation!r}.')
121+
inner_annotation = type_annotations.get_inner_type(inner_annotation)
122+
if inner_annotation is None:
123+
raise TypeError(
124+
f"Input '{input_name}' expects a list of artifacts, but "
125+
f'could not determine the inner annotation from '
126+
f'{list_annotation!r}.')
127+
128+
if isinstance(inner_annotation, tuple):
129+
raise TypeError(
130+
f"Input '{input_name}' expects a single artifact type but "
131+
f'received a union {inner_annotation!r}.')
132+
if not type_annotations.is_artifact_class(inner_annotation):
133+
raise TypeError(
134+
f"Input '{input_name}' expects a list of artifacts, but "
135+
f'received {list_annotation!r} whose inner type '
136+
f'{inner_annotation!r} is not an artifact.')
137+
138+
return inner_annotation
139+
107140
def make_artifact(
108141
self,
109142
runtime_artifact: Dict,
@@ -113,6 +146,9 @@ def make_artifact(
113146
) -> Any:
114147
annotation = func.__annotations__.get(
115148
name) if annotation is None else annotation
149+
if type_annotations.is_list_of_artifacts(annotation):
150+
annotation = self._resolve_list_artifact_inner_type(
151+
input_name=name, list_annotation=annotation)
116152
if isinstance(annotation, type_annotations.InputPath):
117153
schema_title, _ = annotation.type.split('@')
118154
if schema_title in artifact_types._SCHEMA_TITLE_TO_TYPE:

sdk/python/kfp/dsl/executor_test.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,6 +1263,55 @@ def test_func(input_artifact: Input[Artifact]):
12631263

12641264
self.assertDictEqual(output_metadata, {})
12651265

1266+
def test_list_of_artifact_input(self):
1267+
executor_input = """\
1268+
{
1269+
"inputs": {
1270+
"artifacts": {
1271+
"input_datasets": {
1272+
"artifacts": [
1273+
{
1274+
"metadata": {"rows": 10},
1275+
"name": "datasets/0",
1276+
"type": {
1277+
"schemaTitle": "system.Dataset"
1278+
},
1279+
"uri": "gs://some-bucket/output/input_dataset_0"
1280+
},
1281+
{
1282+
"metadata": {"rows": 20},
1283+
"name": "datasets/1",
1284+
"type": {
1285+
"schemaTitle": "system.Dataset"
1286+
},
1287+
"uri": "gs://some-bucket/output/input_dataset_1"
1288+
}
1289+
]
1290+
}
1291+
}
1292+
},
1293+
"outputs": {
1294+
"outputFile": "%(test_dir)s/output_metadata.json"
1295+
}
1296+
}
1297+
"""
1298+
1299+
def test_func(input_datasets: Input[List[Dataset]]):
1300+
self.assertIsInstance(input_datasets, list)
1301+
self.assertLen(input_datasets, 2)
1302+
for index, artifact in enumerate(input_datasets):
1303+
self.assertIsInstance(artifact, Dataset)
1304+
self.assertEqual(artifact.name, f'datasets/{index}')
1305+
self.assertEqual(
1306+
artifact.uri,
1307+
f'gs://some-bucket/output/input_dataset_{index}')
1308+
self.assertEqual(artifact.metadata['rows'], (index + 1) * 10)
1309+
1310+
output_metadata = self.execute_and_load_output_metadata(
1311+
test_func, executor_input)
1312+
1313+
self.assertDictEqual(output_metadata, {})
1314+
12661315
def test_single_artifact_input_pythonic(self):
12671316
executor_input = """\
12681317
{

0 commit comments

Comments
 (0)