@@ -290,7 +290,7 @@ def many_outputs(a: int) -> (int, str):
290290 _ = map_task (many_outputs )
291291
292292
293- def test_parameter_order ():
293+ def test_partials_local_execute ():
294294 @task ()
295295 def task1 (a : int , b : float , c : str ) -> str :
296296 return f"{ a } - { b } - { c } "
@@ -303,20 +303,40 @@ def task2(b: float, c: str, a: int) -> str:
303303 def task3 (c : str , a : int , b : float ) -> str :
304304 return f"{ a } - { b } - { c } "
305305
306+ @task ()
307+ def task4 (c : list [str ], a : list [int ], b : list [float ]) -> list [str ]:
308+ return [f"{ a [i ]} - { b [i ]} - { c [i ]} " for i in range (len (a ))]
309+
306310 param_a = [1 , 2 , 3 ]
307311 param_b = [0.1 , 0.2 , 0.3 ]
308- param_c = "c"
312+ fixed_param_c = "c"
309313
310- m1 = map_task (functools .partial (task1 , c = param_c ))(a = param_a , b = param_b )
311- m2 = map_task (functools .partial (task2 , c = param_c ))(a = param_a , b = param_b )
312- m3 = map_task (functools .partial (task3 , c = param_c ))(a = param_a , b = param_b )
314+ m1 = map_task (functools .partial (task1 , c = fixed_param_c ))(a = param_a , b = param_b )
315+ m2 = map_task (functools .partial (task2 , c = fixed_param_c ))(a = param_a , b = param_b )
316+ m3 = map_task (functools .partial (task3 , c = fixed_param_c ))(a = param_a , b = param_b )
313317
314- m4 = ArrayNodeMapTask (task1 , bound_inputs_values = {"c" : param_c })(a = param_a , b = param_b )
315- m5 = ArrayNodeMapTask (task2 , bound_inputs_values = {"c" : param_c })(a = param_a , b = param_b )
316- m6 = ArrayNodeMapTask (task3 , bound_inputs_values = {"c" : param_c })(a = param_a , b = param_b )
318+ m4 = ArrayNodeMapTask (task1 , bound_inputs_values = {"c" : fixed_param_c })(a = param_a , b = param_b )
319+ m5 = ArrayNodeMapTask (task2 , bound_inputs_values = {"c" : fixed_param_c })(a = param_a , b = param_b )
320+ m6 = ArrayNodeMapTask (task3 , bound_inputs_values = {"c" : fixed_param_c })(a = param_a , b = param_b )
317321
318322 assert m1 == m2 == m3 == m4 == m5 == m6 == ["1 - 0.1 - c" , "2 - 0.2 - c" , "3 - 0.3 - c" ]
319323
324+ list_param_a = [[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 8 , 9 ]]
325+ list_param_b = [[0.1 , 0.2 , 0.3 ], [0.4 , 0.5 , 0.6 ], [0.7 , 0.8 , 0.9 ]]
326+ fixed_list_param_c = ["c" , "d" , "e" ]
327+
328+ m7 = map_task (functools .partial (task4 , c = fixed_list_param_c ))(a = list_param_a , b = list_param_b )
329+ m8 = ArrayNodeMapTask (task4 , bound_inputs_values = {"c" : fixed_list_param_c })(a = list_param_a , b = list_param_b )
330+
331+ assert m7 == m8 == [
332+ ['1 - 0.1 - c' , '2 - 0.2 - d' , '3 - 0.3 - e' ],
333+ ['4 - 0.4 - c' , '5 - 0.5 - d' , '6 - 0.6 - e' ],
334+ ['7 - 0.7 - c' , '8 - 0.8 - d' , '9 - 0.9 - e' ]
335+ ]
336+
337+ with pytest .raises (ValueError ):
338+ map_task (functools .partial (task1 , c = fixed_list_param_c ))(a = param_a , b = param_b )
339+
320340
321341def test_bounded_inputs_vars_order (serialization_settings ):
322342 @task ()
0 commit comments