Skip to content

Commit ef77cfb

Browse files
authored
Support var-len params in func passed to executor. (#1457)
Fixes #1456
1 parent 5becdfa commit ef77cfb

File tree

3 files changed

+82
-14
lines changed

3 files changed

+82
-14
lines changed

lithops/tests/test_map.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,31 @@ def test_lithops_return_futures_map_multiple(self):
125125
fexec.wait()
126126
result = fexec.get_result()
127127
assert result == [1, 2, 3, 1, 2, 3]
128+
129+
def test_lithops_return_futures_map_over_decorator(self):
130+
def doubled(f):
131+
def wrapper(*args, **kwargs):
132+
return 2 * f(*args, **kwargs)
133+
134+
return wrapper
135+
136+
@doubled
137+
def total(a, b, c=0, *args, d, e=5, **kwargs):
138+
return sum([a, b, c, d, e, *args, *kwargs.values()])
139+
140+
with lithops.FunctionExecutor(config=pytest.lithops_config) as fexec:
141+
fexec.map(
142+
total,
143+
[
144+
{"a": 1, "b": 2, "d": 3},
145+
{"args": (1, 2), "d": 3},
146+
{"args": (1, 2), "d": 3, "e": 4},
147+
{"args": (1, 2), "d": 3, "f": 4},
148+
{"a": 1, "b": 2, "d": 3, "e": 6, "f": 4},
149+
{"args": (1, 2), "d": 3, "e": 6, "f": 4},
150+
{"args": (1, 2), "kwargs": {"d": 3, "e": 6, "f": 4}},
151+
],
152+
)
153+
result = fexec.get_result()
154+
155+
assert result == [22, 22, 20, 30, 32, 32, 32]

lithops/utils.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -533,26 +533,40 @@ def verify_args(func, iterdata, extra_args):
533533
non_verify_args = ['ibm_cos', 'storage', 'id', 'rabbitmq']
534534
func_sig = inspect.signature(func)
535535

536-
new_parameters = list()
537-
for param in func_sig.parameters:
538-
if param not in non_verify_args:
539-
new_parameters.append(func_sig.parameters[param])
540-
536+
new_parameters = [
537+
param
538+
for name, param in func_sig.parameters.items()
539+
if name not in non_verify_args
540+
]
541541
new_func_sig = func_sig.replace(parameters=new_parameters)
542542

543-
new_data = list()
543+
# Detect presence of **kwargs (with any name)
544+
has_var_keyword = any(
545+
p.kind == inspect.Parameter.VAR_KEYWORD
546+
for p in new_func_sig.parameters.values()
547+
)
548+
549+
new_data = []
550+
544551
for elem in data:
545552
if isinstance(elem, dict):
546-
if set(list(new_func_sig.parameters.keys())) <= set(elem):
553+
# If the function accepts **kwargs (any name), we cannot reliably
554+
# enforce exact param name matching here, and we *want* to allow
555+
# passing through arbitrary dicts (e.g., original function args)
556+
# even when a decorator wrapper has **kwargs, etc.
557+
if has_var_keyword:
558+
new_data.append(elem)
559+
elif set(expected_keys := list(new_func_sig.parameters)) <= set(elem):
560+
# No **kwargs: enforce that the dict contains at least all
561+
# required user parameters (excluding reserved ones).
547562
new_data.append(elem)
548563
else:
549-
raise ValueError("Check the args names in the data. "
550-
"You provided these args: {}, and "
551-
"the args must be: {}"
552-
.format(list(elem.keys()),
553-
list(new_func_sig.parameters.keys())))
564+
raise ValueError(
565+
"Check the args names in the data. You provided these args: ",
566+
f"{list(elem)}, and the args must be: {expected_keys}",
567+
)
554568
elif isinstance(elem, tuple):
555-
new_elem = dict(new_func_sig.bind(*list(elem)).arguments)
569+
new_elem = dict(new_func_sig.bind(*elem).arguments)
556570
new_data.append(new_elem)
557571
else:
558572
# single value (list, string, integer, dict, etc)

lithops/worker/jobrunner.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,8 @@ def run(self):
234234
logger.info(f"Going to execute '{str(fn_name)}()'")
235235
print('---------------------- FUNCTION LOG ----------------------')
236236
function_start_tstamp = time.time()
237-
result = func(**data)
237+
args, kwargs = _prepare_args(func, data)
238+
result = func(*args, **kwargs)
238239
function_end_tstamp = time.time()
239240
print('----------------------------------------------------------')
240241
logger.info("Success function execution")
@@ -310,3 +311,28 @@ def run(self):
310311
self.stats.write("worker_result_upload_time", round(output_upload_end_tstamp - output_upload_start_tstamp, 8))
311312
self.jobrunner_conn.send("Finished")
312313
logger.info("Process finished")
314+
315+
316+
def _prepare_args(func, data):
317+
# Convert the "data" envelope into normal *args/**kwargs,
318+
# respecting the actual var-length parameter names of `func`.
319+
func_sig = inspect.signature(func)
320+
var_pos_name = None
321+
var_kw_name = None
322+
323+
for name, param in func_sig.parameters.items():
324+
if param.kind == inspect.Parameter.VAR_POSITIONAL:
325+
var_pos_name = name
326+
elif param.kind == inspect.Parameter.VAR_KEYWORD:
327+
var_kw_name = name
328+
329+
payload = dict(data)
330+
331+
# Extract var-positional argument value if present
332+
args = payload.pop(var_pos_name) or () if var_pos_name in payload else ()
333+
# Extract var-keyword argument value if present
334+
kwargs = payload.pop(var_kw_name) or {} if var_kw_name in payload else {}
335+
# Any remaining keys become normal keyword arguments
336+
kwargs.update(payload)
337+
338+
return args, kwargs

0 commit comments

Comments
 (0)