Skip to content

Commit 0a54f9b

Browse files
authored
refactor: use map function to implement map block (#1096)
Signed-off-by: Louis Mandel <[email protected]>
1 parent d61477f commit 0a54f9b

File tree

1 file changed

+44
-29
lines changed

1 file changed

+44
-29
lines changed

src/pdl/pdl_interpreter.py

Lines changed: 44 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import warnings
1313
from asyncio import AbstractEventLoop
1414
from functools import partial, reduce
15+
from itertools import count
1516
from os import getenv
1617

1718
warnings.filterwarnings("ignore", "Valid config keys have changed in V2")
@@ -21,6 +22,7 @@
2122
Any,
2223
Generator,
2324
Generic,
25+
Iterable,
2426
Optional,
2527
Sequence,
2628
Tuple,
@@ -952,50 +954,48 @@ def process_block_body(
952954
yield_result(result.result(), block.kind)
953955
trace = block.model_copy(update={"pdl__trace": iter_trace})
954956
case MapBlock():
955-
results = []
956957
background = DependentContext([])
957-
iter_trace = []
958958
iteration_state = state.with_yield_result(False)
959959
block, items, length = _evaluate_for_field(scope, block, loc)
960960
block, max_iterations = _evaluate_max_iterations_field(scope, block, loc)
961961
block = _evaluate_join_field(scope, block, loc)
962962
map_loc = append(loc, "map")
963963
iidx = 0
964964
try:
965-
saved_background = IndependentContext([])
966-
while True:
965+
if max_iterations is not None:
966+
index_iterator: Any = range(max_iterations)
967+
else:
968+
index_iterator = count()
969+
if items is not None and length is not None:
970+
items_iterator = (
971+
{k: elems[i] for k, elems in items.items()}
972+
for i in range(length)
973+
)
974+
else:
975+
items_iterator = ({} for _ in count())
976+
977+
def loop_body(iidx, items):
967978
iteration_scope = scope_init
968979
if block.index is not None:
969980
iteration_scope = iteration_scope | {block.index: iidx}
970-
if max_iterations is not None and iidx >= max_iterations:
971-
break
972-
if length is not None and iidx >= length:
973-
break
974-
iteration_state = iteration_state.with_iter(iidx)
975-
if items is not None:
976-
for k in items.keys():
977-
iteration_scope = iteration_scope | {k: items[k][iidx]}
978-
(
979-
iteration_result,
980-
iteration_background,
981-
iteration_scope,
982-
body_trace,
983-
) = process_block(
984-
iteration_state,
981+
iteration_scope = iteration_scope | items
982+
return process_block(
983+
iteration_state.with_iter(iidx),
985984
iteration_scope,
986985
block.map,
987986
map_loc,
988987
)
989-
saved_background = IndependentContext(
990-
[saved_background, iteration_background]
991-
)
992-
results.append(iteration_result)
993-
iter_trace.append(body_trace)
994-
iteration_state = iteration_state.with_pop()
995-
iidx = iidx + 1
988+
989+
# with ThreadPoolExecutor(max_workers=4) as executor:
990+
# map_output = executor.map(
991+
map_output = map( # pylint: disable=bad-builtin
992+
loop_body, index_iterator, items_iterator
993+
)
994+
results, _, _, traces = _split_map_output(map_output)
995+
# saved_background = IndependentContext(backgrounds)
996996
except PDLRuntimeError as exc:
997-
iter_trace.append(exc.pdl__trace)
998-
trace = block.model_copy(update={"pdl__trace": iter_trace})
997+
traces = [exc.pdl__trace] # type: ignore
998+
trace = block.model_copy(update={"pdl__trace": traces})
999999
raise PDLRuntimeError(
10001000
exc.message,
10011001
loc=exc.loc or map_loc,
@@ -1005,7 +1005,7 @@ def process_block_body(
10051005
# background = saved_background # commented because the block do not contribute to the background
10061006
if state.yield_result and not iteration_state.yield_result:
10071007
yield_result(result.result(), block.kind)
1008-
trace = block.model_copy(update={"pdl__trace": iter_trace})
1008+
trace = block.model_copy(update={"pdl__trace": traces})
10091009
case ReadBlock():
10101010
result, background, scope, trace = process_input(state, scope, block, loc)
10111011
if state.yield_result:
@@ -1065,6 +1065,21 @@ def process_block_body(
10651065
return result, background, scope, trace
10661066

10671067

1068+
def _split_map_output(
1069+
map_output: Iterable[Tuple[PdlLazy[Any], LazyMessages, ScopeType, BlockType]],
1070+
) -> Tuple[list[PdlLazy[Any]], list[LazyMessages], list[ScopeType], list[BlockType]]:
1071+
results = []
1072+
backgrounds = []
1073+
scopes = []
1074+
traces = []
1075+
for result, background, scope, trace in map_output:
1076+
results.append(result)
1077+
backgrounds.append(background)
1078+
scopes.append(scope)
1079+
traces.append(trace)
1080+
return results, backgrounds, scopes, traces
1081+
1082+
10681083
BlockTVarEvalFor = TypeVar("BlockTVarEvalFor", bound=RepeatBlock | MapBlock)
10691084

10701085

0 commit comments

Comments
 (0)