Skip to content

Commit c2dd71f

Browse files
Update for_helpers.py
1 parent 9476471 commit c2dd71f

File tree

1 file changed

+36
-28
lines changed

1 file changed

+36
-28
lines changed

mypyc/irbuild/for_helpers.py

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,13 +1184,36 @@ def init(self, indexes: list[Lvalue], exprs: list[Expression]) -> None:
11841184
)
11851185
self.gens.append(gen)
11861186

1187+
self.conditions, self.cond_blocks = self.__sort_conditions()
1188+
11871189
def gen_condition(self) -> None:
1190+
for i, gen in enumerate(ordered + leftovers + for_iterable):
1191+
gen.gen_condition()
1192+
if i < len(self.gens) - 1:
1193+
self.builder.activate_block(self.cond_blocks[i])
1194+
1195+
def begin_body(self) -> None:
1196+
for gen in self.gens:
1197+
gen.begin_body()
1198+
1199+
def gen_step(self) -> None:
1200+
for gen in self.gens:
1201+
gen.gen_step()
1202+
1203+
def gen_cleanup(self) -> None:
1204+
for gen in self.gens:
1205+
gen.gen_cleanup()
1206+
1207+
def __sort_conditions(self) -> List[ForSequence]:
11881208
# We don't necessarily need to check the gens in order,
11891209
# we just need to know which gen ends first. Some gens
11901210
# are quicker to check than others, so we will check the
11911211
# specialized ForHelpers before we check any generic
11921212
# ForIterable
1213+
11931214
gens = self.gens
1215+
cond_blocks = self.cond_blocks[:]
1216+
cond_blocks.remove(self.body_block)
11941217

11951218
def check_type(obj: Any, typ: type[ForGenerator]) -> bool:
11961219
# ForEnumerate gen_condition is as fast as it's underlying generator's
@@ -1201,46 +1224,42 @@ def check_type(obj: Any, typ: type[ForGenerator]) -> bool:
12011224
)
12021225

12031226
# these are slowest, they invoke Python's iteration protocol
1204-
for_iterable = [g for g in gens if check_type(g, ForSequence)]
1227+
for_iterable = [(g, block) for g, block in zip(gens, cond_blocks) if check_type(g, ForSequence)]
12051228

12061229
# These aren't the slowest but they're slow, we need to pack an RTuple and then get and item and do a comparison
1207-
for_dict = [g for g in gens if check_type(g, ForDictionaryCommon)]
1230+
for_dict = [(g, block) for g, block in zip(gens, cond_blocks) if check_type(g, ForDictionaryCommon)]
12081231

12091232
# These are faster than ForIterable but not as fast as others (faster than ForDict?)
1210-
for_native = [g for g in gens if check_type(g, ForNativeGenerator)]
1233+
for_native = [(g, block) for g, block in zip(gens, cond_blocks) if check_type(g, ForNativeGenerator)]
12111234

12121235
# forward involves in the best case one pyssize_t comparison, else one length check + the comparison
12131236
# reverse is slightly slower than forward, with one extra check
12141237
for_sequence_reverse_with_len_check = [
1215-
g
1216-
for g in gens
1238+
(g, block) for g, block in zip(gens, cond_blocks)
12171239
if check_type(g, ForSequence)
12181240
and (
12191241
for_seq := cast(ForSequence, g.main_gen if isinstance(g, ForEnumerate) else g)
12201242
).reverse
12211243
and for_seq.length_reg is not None
12221244
]
12231245
for_sequence_reverse_no_len_check = [
1224-
g
1225-
for g in gens
1246+
(g, block) for g, block in zip(gens, cond_blocks)
12261247
if check_type(g, ForSequence)
12271248
and (
12281249
for_seq := cast(ForSequence, g.main_gen if isinstance(g, ForEnumerate) else g)
12291250
).reverse
12301251
and for_seq.length_reg is None
12311252
]
12321253
for_sequence_forward_with_len_check = [
1233-
g
1234-
for g in gens
1254+
(g, block) for g, block in zip(gens, cond_blocks)
12351255
if check_type(g, ForSequence)
12361256
and not (
12371257
for_seq := cast(ForSequence, g.main_gen if isinstance(g, ForEnumerate) else g)
12381258
).reverse
12391259
and for_seq.length_reg is not None
12401260
]
12411261
for_sequence_forward_no_len_check = [
1242-
g
1243-
for g in gens
1262+
(g, block) for g, block in zip(gens, cond_blocks)
12441263
if check_type(g, ForSequence)
12451264
and not (
12461265
for_seq := cast(ForSequence, g.main_gen if isinstance(g, ForEnumerate) else g)
@@ -1249,7 +1268,7 @@ def check_type(obj: Any, typ: type[ForGenerator]) -> bool:
12491268
]
12501269

12511270
# these are really fast, just a C int equality check
1252-
for_range = [g for g in gens if isinstance(g, ForRange)]
1271+
for_range = [(g, block) for g, block in zip(gens, cond_blocks) if isinstance(g, ForRange)]
12531272

12541273
ordered = (
12551274
for_range
@@ -1262,24 +1281,13 @@ def check_type(obj: Any, typ: type[ForGenerator]) -> bool:
12621281
)
12631282

12641283
# this is a failsafe for ForHelper classes which might have been added after this commit but not added to this function's code
1265-
leftovers = [g for g in gens if g not in ordered + for_iterable]
1266-
1267-
for i, gen in enumerate(ordered + leftovers + for_iterable):
1268-
gen.gen_condition()
1269-
if i < len(self.gens) - 1:
1270-
self.builder.activate_block(self.cond_blocks[i])
1271-
1272-
def begin_body(self) -> None:
1273-
for gen in self.gens:
1274-
gen.begin_body()
1284+
leftovers = [(g, block) for g, block in zip(gens, cond_blocks) if g not in ordered + for_iterable]
12751285

1276-
def gen_step(self) -> None:
1277-
for gen in self.gens:
1278-
gen.gen_step()
1286+
gens_and_blocks = ordered + leftovers + for_iterable
1287+
conditons = [g for (g, block) in gens_and_blocks]
1288+
cond_blocks = [block for (g, block) in gens_and_blocks] + [self.body_block]
12791289

1280-
def gen_cleanup(self) -> None:
1281-
for gen in self.gens:
1282-
gen.gen_cleanup()
1290+
return conditions, cond_blocks
12831291

12841292

12851293
def get_expr_length(builder: IRBuilder, expr: Expression) -> int | None:

0 commit comments

Comments
 (0)