12
12
import warnings
13
13
from asyncio import AbstractEventLoop
14
14
from functools import partial , reduce
15
+ from itertools import count
15
16
from os import getenv
16
17
17
18
warnings .filterwarnings ("ignore" , "Valid config keys have changed in V2" )
21
22
Any ,
22
23
Generator ,
23
24
Generic ,
25
+ Iterable ,
24
26
Optional ,
25
27
Sequence ,
26
28
Tuple ,
@@ -952,50 +954,48 @@ def process_block_body(
952
954
yield_result (result .result (), block .kind )
953
955
trace = block .model_copy (update = {"pdl__trace" : iter_trace })
954
956
case MapBlock ():
955
- results = []
956
957
background = DependentContext ([])
957
- iter_trace = []
958
958
iteration_state = state .with_yield_result (False )
959
959
block , items , length = _evaluate_for_field (scope , block , loc )
960
960
block , max_iterations = _evaluate_max_iterations_field (scope , block , loc )
961
961
block = _evaluate_join_field (scope , block , loc )
962
962
map_loc = append (loc , "map" )
963
963
iidx = 0
964
964
try :
965
- saved_background = IndependentContext ([])
966
- while True :
965
+ if max_iterations is not None :
966
+ index_iterator : Any = range (max_iterations )
967
+ else :
968
+ index_iterator = count ()
969
+ if items is not None and length is not None :
970
+ items_iterator = (
971
+ {k : elems [i ] for k , elems in items .items ()}
972
+ for i in range (length )
973
+ )
974
+ else :
975
+ items_iterator = ({} for _ in count ())
976
+
977
+ def loop_body (iidx , items ):
967
978
iteration_scope = scope_init
968
979
if block .index is not None :
969
980
iteration_scope = iteration_scope | {block .index : iidx }
970
- if max_iterations is not None and iidx >= max_iterations :
971
- break
972
- if length is not None and iidx >= length :
973
- break
974
- iteration_state = iteration_state .with_iter (iidx )
975
- if items is not None :
976
- for k in items .keys ():
977
- iteration_scope = iteration_scope | {k : items [k ][iidx ]}
978
- (
979
- iteration_result ,
980
- iteration_background ,
981
- iteration_scope ,
982
- body_trace ,
983
- ) = process_block (
984
- iteration_state ,
981
+ iteration_scope = iteration_scope | items
982
+ return process_block (
983
+ iteration_state .with_iter (iidx ),
985
984
iteration_scope ,
986
985
block .map ,
987
986
map_loc ,
988
987
)
989
- saved_background = IndependentContext (
990
- [saved_background , iteration_background ]
991
- )
992
- results .append (iteration_result )
993
- iter_trace .append (body_trace )
994
- iteration_state = iteration_state .with_pop ()
995
- iidx = iidx + 1
988
+
989
+ # with ThreadPoolExecutor(max_workers=4) as executor:
990
+ # map_output = executor.map(
991
+ map_output = map ( # pylint: disable=bad-builtin
992
+ loop_body , index_iterator , items_iterator
993
+ )
994
+ results , _ , _ , traces = _split_map_output (map_output )
995
+ # saved_background = IndependentContext(backgrounds)
996
996
except PDLRuntimeError as exc :
997
- iter_trace . append ( exc .pdl__trace )
998
- trace = block .model_copy (update = {"pdl__trace" : iter_trace })
997
+ traces = [ exc .pdl__trace ] # type: ignore
998
+ trace = block .model_copy (update = {"pdl__trace" : traces })
999
999
raise PDLRuntimeError (
1000
1000
exc .message ,
1001
1001
loc = exc .loc or map_loc ,
@@ -1005,7 +1005,7 @@ def process_block_body(
1005
1005
# background = saved_background # commented because the block do not contribute to the background
1006
1006
if state .yield_result and not iteration_state .yield_result :
1007
1007
yield_result (result .result (), block .kind )
1008
- trace = block .model_copy (update = {"pdl__trace" : iter_trace })
1008
+ trace = block .model_copy (update = {"pdl__trace" : traces })
1009
1009
case ReadBlock ():
1010
1010
result , background , scope , trace = process_input (state , scope , block , loc )
1011
1011
if state .yield_result :
@@ -1065,6 +1065,21 @@ def process_block_body(
1065
1065
return result , background , scope , trace
1066
1066
1067
1067
1068
+ def _split_map_output (
1069
+ map_output : Iterable [Tuple [PdlLazy [Any ], LazyMessages , ScopeType , BlockType ]],
1070
+ ) -> Tuple [list [PdlLazy [Any ]], list [LazyMessages ], list [ScopeType ], list [BlockType ]]:
1071
+ results = []
1072
+ backgrounds = []
1073
+ scopes = []
1074
+ traces = []
1075
+ for result , background , scope , trace in map_output :
1076
+ results .append (result )
1077
+ backgrounds .append (background )
1078
+ scopes .append (scope )
1079
+ traces .append (trace )
1080
+ return results , backgrounds , scopes , traces
1081
+
1082
+
1068
1083
BlockTVarEvalFor = TypeVar ("BlockTVarEvalFor" , bound = RepeatBlock | MapBlock )
1069
1084
1070
1085
0 commit comments