Skip to content

Commit 2f667e9

Browse files
authored
remove check of settings lists in partials (#3187)
* remove check of settings lists in partials Signed-off-by: Paul Dittamo <pvdittamo@gmail.com> * update unit test Signed-off-by: Paul Dittamo <pvdittamo@gmail.com> * test Signed-off-by: Paul Dittamo <pvdittamo@gmail.com> * clean up Signed-off-by: Paul Dittamo <pvdittamo@gmail.com> --------- Signed-off-by: Paul Dittamo <pvdittamo@gmail.com>
1 parent 97ea801 commit 2f667e9

File tree

3 files changed

+28
-13
lines changed

3 files changed

+28
-13
lines changed

flytekit/core/array_node_map_task.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,6 @@ def __init__(
5555
"""
5656
self._partial = None
5757
if isinstance(python_function_task, functools.partial):
58-
# TODO: We should be able to support partial tasks with lists as inputs
59-
for arg in python_function_task.keywords.values():
60-
if isinstance(arg, list):
61-
raise ValueError("Cannot use a partial task with lists as inputs")
6258
self._partial = python_function_task
6359
actual_task = self._partial.func
6460
else:

tests/flytekit/unit/core/test_array_node_map_task.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def many_outputs(a: int) -> (int, str):
290290
_ = map_task(many_outputs)
291291

292292

293-
def test_parameter_order():
293+
def test_partials_local_execute():
294294
@task()
295295
def task1(a: int, b: float, c: str) -> str:
296296
return f"{a} - {b} - {c}"
@@ -303,20 +303,40 @@ def task2(b: float, c: str, a: int) -> str:
303303
def task3(c: str, a: int, b: float) -> str:
304304
return f"{a} - {b} - {c}"
305305

306+
@task()
307+
def task4(c: list[str], a: list[int], b: list[float]) -> list[str]:
308+
return [f"{a[i]} - {b[i]} - {c[i]}" for i in range(len(a))]
309+
306310
param_a = [1, 2, 3]
307311
param_b = [0.1, 0.2, 0.3]
308-
param_c = "c"
312+
fixed_param_c = "c"
309313

310-
m1 = map_task(functools.partial(task1, c=param_c))(a=param_a, b=param_b)
311-
m2 = map_task(functools.partial(task2, c=param_c))(a=param_a, b=param_b)
312-
m3 = map_task(functools.partial(task3, c=param_c))(a=param_a, b=param_b)
314+
m1 = map_task(functools.partial(task1, c=fixed_param_c))(a=param_a, b=param_b)
315+
m2 = map_task(functools.partial(task2, c=fixed_param_c))(a=param_a, b=param_b)
316+
m3 = map_task(functools.partial(task3, c=fixed_param_c))(a=param_a, b=param_b)
313317

314-
m4 = ArrayNodeMapTask(task1, bound_inputs_values={"c": param_c})(a=param_a, b=param_b)
315-
m5 = ArrayNodeMapTask(task2, bound_inputs_values={"c": param_c})(a=param_a, b=param_b)
316-
m6 = ArrayNodeMapTask(task3, bound_inputs_values={"c": param_c})(a=param_a, b=param_b)
318+
m4 = ArrayNodeMapTask(task1, bound_inputs_values={"c": fixed_param_c})(a=param_a, b=param_b)
319+
m5 = ArrayNodeMapTask(task2, bound_inputs_values={"c": fixed_param_c})(a=param_a, b=param_b)
320+
m6 = ArrayNodeMapTask(task3, bound_inputs_values={"c": fixed_param_c})(a=param_a, b=param_b)
317321

318322
assert m1 == m2 == m3 == m4 == m5 == m6 == ["1 - 0.1 - c", "2 - 0.2 - c", "3 - 0.3 - c"]
319323

324+
list_param_a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
325+
list_param_b = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
326+
fixed_list_param_c = ["c", "d", "e"]
327+
328+
m7 = map_task(functools.partial(task4, c=fixed_list_param_c))(a=list_param_a, b=list_param_b)
329+
m8 = ArrayNodeMapTask(task4, bound_inputs_values={"c": fixed_list_param_c})(a=list_param_a, b=list_param_b)
330+
331+
assert m7 == m8 == [
332+
['1 - 0.1 - c', '2 - 0.2 - d', '3 - 0.3 - e'],
333+
['4 - 0.4 - c', '5 - 0.5 - d', '6 - 0.6 - e'],
334+
['7 - 0.7 - c', '8 - 0.8 - d', '9 - 0.9 - e']
335+
]
336+
337+
with pytest.raises(ValueError):
338+
map_task(functools.partial(task1, c=fixed_list_param_c))(a=param_a, b=param_b)
339+
320340

321341
def test_bounded_inputs_vars_order(serialization_settings):
322342
@task()

tests/flytekit/unit/core/test_partials.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,6 @@ def wf_in(a: typing.List[int]):
142142
"map_task_fn",
143143
[
144144
legacy_map_task,
145-
array_node_map_task,
146145
],
147146
)
148147
def test_lists_cannot_be_used_in_partials(map_task_fn):

0 commit comments

Comments
 (0)