Skip to content

Commit 218342e

Browse files
authored
feat: parallel map reduce (#1102)
Signed-off-by: Louis Mandel <[email protected]>
1 parent 57e0225 commit 218342e

File tree

5 files changed

+57
-5
lines changed

5 files changed

+57
-5
lines changed

pdl-live-react/src/pdl_ast.d.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2491,6 +2491,11 @@ export interface MapBlock {
24912491
*
24922492
*/
24932493
join?: JoinText | JoinArray | JoinObject | JoinLastOf | JoinReduce
2494+
/**
2495+
* Maximal number of workers to execute the map in parallel. If it is set to `0`, the execution is sequential; otherwise it is given as argument to the `ThreadPoolExecutor`.
2496+
*
2497+
*/
2498+
maxWorkers?: number | null
24942499
pdl__trace?: PdlTrace
24952500
}
24962501
/**

src/pdl/pdl-schema.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2944,6 +2944,11 @@
29442944
},
29452945
"description": "Define how to combine the result of each iteration.\n "
29462946
},
2947+
"maxWorkers": {
2948+
"$ref": "#/$defs/OptionalInt",
2949+
"default": null,
2950+
"description": "Maximal number of workers to execute the map in parallel. If it is set to `0`, the execution is sequential; otherwise it is given as argument to the `ThreadPoolExecutor`.\n "
2951+
},
29472952
"pdl__trace": {
29482953
"anyOf": [
29492954
{

src/pdl/pdl_ast.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1022,6 +1022,9 @@ class MapBlock(StructuredBlock):
10221022
join: JoinType = JoinText()
10231023
"""Define how to combine the result of each iteration.
10241024
"""
1025+
maxWorkers: OptionalInt = None
1026+
"""Maximal number of workers to execute the map in parallel. If it is set to `0`, the execution is sequential; otherwise it is given as argument to the `ThreadPoolExecutor`.
1027+
"""
10251028
# Field for internal use
10261029
pdl__trace: Optional[list["BlockType"]] = None
10271030

src/pdl/pdl_interpreter.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# TODO: temporarily disabling warnings to mute a pydantic warning from liteLLM
1212
import warnings
1313
from asyncio import AbstractEventLoop
14+
from concurrent.futures import ThreadPoolExecutor
1415
from functools import partial, reduce
1516
from itertools import count
1617
from os import getenv
@@ -986,11 +987,18 @@ def loop_body(iidx, items):
986987
map_loc,
987988
)
988989

989-
# with ThreadPoolExecutor(max_workers=4) as executor:
990-
# map_output = executor.map(
991-
map_output = map( # pylint: disable=bad-builtin
992-
loop_body, index_iterator, items_iterator
993-
)
990+
map_output: Iterable[
991+
Tuple[PdlLazy[Any], LazyMessages, ScopeType, BlockType]
992+
]
993+
if block.maxWorkers == 0:
994+
map_output = map( # pylint: disable=bad-builtin
995+
loop_body, index_iterator, items_iterator
996+
)
997+
else:
998+
with ThreadPoolExecutor(block.maxWorkers) as executor:
999+
map_output = executor.map(
1000+
loop_body, index_iterator, items_iterator
1001+
)
9941002
results, _, _, traces = _split_map_output(map_output)
9951003
# saved_background = IndependentContext(backgrounds)
9961004
except PDLRuntimeError as exc:

tests/test_for.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,25 @@ def test_for_context():
324324
def test_for_reduce():
325325
for loop_kind1 in ["repeat", "map"]:
326326
prog_str = f"""
327+
defs:
328+
plus:
329+
function:
330+
x: number
331+
y: number
332+
return: ${{ x + y }}
333+
for:
334+
i: [1,2,3,4]
335+
{loop_kind1}: ${{ i }}
336+
join:
337+
reduce: ${{ plus }}
338+
"""
339+
result = exec_str(prog_str)
340+
assert result == 10
341+
342+
343+
def test_for_reduce_python():
344+
for loop_kind1 in ["repeat", "map"]:
345+
prog_str = f"""
327346
defs:
328347
plus:
329348
lang: python
@@ -338,3 +357,15 @@ def test_for_reduce():
338357
"""
339358
result = exec_str(prog_str)
340359
assert result == 10
360+
361+
362+
def test_map_max_workers():
363+
for max_workers in [0, "null", 2, 4]:
364+
prog_str = f"""
365+
for:
366+
i: [1,2,3,4]
367+
map: ${{ i }}
368+
maxWorkers: {max_workers}
369+
"""
370+
result = exec_str(prog_str)
371+
assert result == "1234"

0 commit comments

Comments
 (0)