Skip to content

Commit 1ce19eb

Browse files
committed
[mypyc] feat: cache len for iterating over immutable types
1 parent 5a78607 commit 1ce19eb

File tree

4 files changed

+47
-11
lines changed

4 files changed

+47
-11
lines changed

mypyc/ir/rtypes.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,19 @@ def is_range_rprimitive(rtype: RType) -> bool:
628628

629629
def is_sequence_rprimitive(rtype: RType) -> bool:
630630
return isinstance(rtype, RPrimitive) and (
631-
is_list_rprimitive(rtype) or is_tuple_rprimitive(rtype) or is_str_rprimitive(rtype)
631+
is_list_rprimitive(rtype)
632+
or is_tuple_rprimitive(rtype)
633+
or is_str_rprimitive(rtype)
634+
or is_bytes_rprimitive(rtype)
635+
)
636+
637+
638+
def is_immutable_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
639+
return (
640+
is_str_rprimitive(rtype)
641+
or is_bytes_rprimitive(rtype)
642+
or is_tuple_rprimitive(rtype)
643+
or is_frozenset_rprimitive(rtype)
632644
)
633645

634646

mypyc/irbuild/builder.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
RType,
9292
RUnion,
9393
bitmap_rprimitive,
94+
bytes_rprimitive,
9495
c_pyssize_t_rprimitive,
9596
dict_rprimitive,
9697
int_rprimitive,
@@ -962,8 +963,12 @@ def get_sequence_type_from_type(self, target_type: Type) -> RType:
962963
elif isinstance(target_type, Instance):
963964
if target_type.type.fullname == "builtins.str":
964965
return str_rprimitive
965-
else:
966+
elif target_type.type.fullname == "builtins.bytes":
967+
return bytes_rprimitive
968+
try:
966969
return self.type_to_rtype(target_type.args[0])
970+
except IndexError:
971+
raise ValueError(f"{target_type!r} is not a valid sequence.") from None
967972
# This elif-blocks are needed for iterating over classes derived from NamedTuple.
968973
elif isinstance(target_type, TypeVarLikeType):
969974
return self.get_sequence_type_from_type(target_type.upper_bound)

mypyc/irbuild/for_helpers.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
int_rprimitive,
4949
is_dict_rprimitive,
5050
is_fixed_width_rtype,
51+
is_immutable_rprimitive,
5152
is_list_rprimitive,
5253
is_sequence_rprimitive,
5354
is_short_int_rprimitive,
@@ -205,9 +206,9 @@ def sequence_from_generator_preallocate_helper(
205206
there is no condition list in the generator and only one original sequence with
206207
one index is allowed.
207208
208-
e.g. (1) tuple(f(x) for x in a_list/a_tuple)
209-
(2) list(f(x) for x in a_list/a_tuple)
210-
(3) [f(x) for x in a_list/a_tuple]
209+
e.g. (1) tuple(f(x) for x in a_list/a_tuple/a_str/a_bytes)
210+
(2) list(f(x) for x in a_list/a_tuple/a_str/a_bytes)
211+
(3) [f(x) for x in a_list/a_tuple/a_str/a_bytes]
211212
RTuple as an original sequence is not supported yet.
212213
213214
Args:
@@ -224,7 +225,7 @@ def sequence_from_generator_preallocate_helper(
224225
"""
225226
if len(gen.sequences) == 1 and len(gen.indices) == 1 and len(gen.condlists[0]) == 0:
226227
rtype = builder.node_type(gen.sequences[0])
227-
if is_list_rprimitive(rtype) or is_tuple_rprimitive(rtype) or is_str_rprimitive(rtype):
228+
if is_sequence_rprimitive(rtype):
228229
sequence = builder.accept(gen.sequences[0])
229230
length = builder.builder.builtin_len(sequence, gen.line, use_pyssize_t=True)
230231
target_op = empty_op_llbuilder(length, gen.line)
@@ -785,17 +786,31 @@ class ForSequence(ForGenerator):
785786
Supports iterating in both forward and reverse.
786787
"""
787788

789+
length_reg: Value | AssignmentTarget | None
790+
788791
def init(self, expr_reg: Value, target_type: RType, reverse: bool) -> None:
792+
assert is_sequence_rprimitive(expr_reg.type), expr_reg
789793
builder = self.builder
790794
self.reverse = reverse
791795
# Define target to contain the expression, along with the index that will be used
792796
# for the for-loop. If we are inside of a generator function, spill these into the
793797
# environment class.
794798
self.expr_target = builder.maybe_spill(expr_reg)
799+
if is_immutable_rprimitive(expr_reg.type):
800+
# If the expression is an immutable type, we can load the length just once.
801+
self.length_reg = builder.maybe_spill(self.load_len(self.expr_target))
802+
else:
803+
# Otherwise, even if the length is known, we must recalculate the length
804+
# at every iteration for compatibility with python semantics.
805+
self.length_reg = None
795806
if not reverse:
796807
index_reg: Value = Integer(0, c_pyssize_t_rprimitive)
797808
else:
798-
index_reg = builder.builder.int_sub(self.load_len(self.expr_target), 1)
809+
if self.length_reg is not None:
810+
len_val = builder.read(self.length_reg)
811+
else:
812+
len_val = self.load_len(self.expr_target)
813+
index_reg = builder.builder.int_sub(len_val, 1)
799814
self.index_target = builder.maybe_spill_assignable(index_reg)
800815
self.target_type = target_type
801816

@@ -814,9 +829,13 @@ def gen_condition(self) -> None:
814829
second_check = BasicBlock()
815830
builder.add_bool_branch(comparison, second_check, self.loop_exit)
816831
builder.activate_block(second_check)
817-
# For compatibility with python semantics we recalculate the length
818-
# at every iteration.
819-
len_reg = self.load_len(self.expr_target)
832+
if self.length_reg is None:
833+
# For compatibility with python semantics we recalculate the length
834+
# at every iteration.
835+
len_reg = self.load_len(self.expr_target)
836+
else:
837+
# (unless input is immutable type).
838+
len_reg = builder.read(self.length_reg, line)
820839
comparison = builder.binary_op(builder.read(self.index_target, line), len_reg, "<", line)
821840
builder.add_bool_branch(comparison, self.body_block, self.loop_exit)
822841

mypyc/irbuild/specialize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def translate_tuple_from_generator_call(
288288
"""Special case for simplest tuple creation from a generator.
289289
290290
For example:
291-
tuple(f(x) for x in some_list/some_tuple/some_str)
291+
tuple(f(x) for x in some_list/some_tuple/some_str/some_bytes)
292292
'translate_safe_generator_call()' would take care of other cases
293293
if this fails.
294294
"""

0 commit comments

Comments
 (0)