Skip to content

Commit 5635724

Browse files
committed
feat: index range err code
1 parent a56adc8 commit 5635724

File tree

4 files changed

+205
-0
lines changed

4 files changed

+205
-0
lines changed

mypy/checkexpr.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
freshen_all_functions_type_vars,
2727
freshen_function_type_vars,
2828
)
29+
from mypy.exprlength import get_static_expr_length
2930
from mypy.infer import ArgumentInferContext, infer_function_type_arguments, infer_type_arguments
3031
from mypy.literals import literal
3132
from mypy.maptype import map_instance_to_supertype
@@ -4471,6 +4472,7 @@ def visit_index_with_type(
44714472
# Allow special forms to be indexed and used to create union types
44724473
return self.named_type("typing._SpecialForm")
44734474
else:
4475+
self.static_index_range_check(left_type, e, index)
44744476
result, method_type = self.check_method_call_by_name(
44754477
"__getitem__",
44764478
left_type,
@@ -4493,6 +4495,43 @@ def min_tuple_length(self, left: TupleType) -> int:
44934495
return left.length() - 1 + unpack.type.min_len
44944496
return left.length() - 1
44954497

4498+
def static_index_range_check(self, left_type: Type, e: IndexExpr, index: Expression) -> None:
    """Report an error when a literal integer index is provably out of range.

    The check only fires when all of the following hold:

    * ``left_type`` is an instance of a positionally indexed builtin
      (``list``, ``str`` or ``bytes``);
    * ``index`` is an integer literal, optionally prefixed with a unary
      ``+`` or ``-``;
    * ``get_static_expr_length`` can determine the length of ``e.base``.

    Args:
        left_type: Type of the expression being indexed.
        e: The whole index expression; used as the error context and to
            reach the base expression.
        index: The index operand of ``e``.
    """
    # Only positionally indexed builtins qualify.  Sets are not subscriptable
    # and a dict subscript is a key rather than a position, so a length based
    # range check would be meaningless for them.
    if not (
        isinstance(left_type, Instance)
        and left_type.type.fullname in ("builtins.list", "builtins.str", "builtins.bytes")
    ):
        return

    # Try to extract an integer literal index, allowing a unary sign.
    # Could add more cases (e.g. LiteralType) if desired.
    idx_val: int | None = None
    if isinstance(index, IntExpr):
        idx_val = index.value
    elif (
        isinstance(index, UnaryExpr)
        and index.op in ("-", "+")
        and isinstance(index.expr, IntExpr)
    ):
        idx_val = -index.expr.value if index.op == "-" else index.expr.value
    if idx_val is None:
        return

    length = get_static_expr_length(e.base)
    if length is None:
        return

    # Negative indices count from the end, so the valid range is
    # -length <= index < length.
    if -length <= idx_val < length:
        return

    name = e.base.name if isinstance(e.base, NameExpr) else ""
    self.chk.fail(
        message_registry.SEQUENCE_INDEX_OUT_OF_RANGE.format(
            name=name or "<expr>", length=length
        ),
        e,
        code=message_registry.SEQUENCE_INDEX_OUT_OF_RANGE.code,
    )
4534+
44964535
def visit_tuple_index_helper(self, left: TupleType, n: int) -> Type | None:
44974536
unpack_index = find_unpack_in_list(left.items)
44984537
if unpack_index is None:

mypy/errorcodes.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,5 +322,12 @@ def __hash__(self) -> int:
322322
default_enabled=False,
323323
)
324324

325+
# Reported when a literal index falls outside the statically known length
# of the indexed expression.
INDEX_RANGE = ErrorCode(
    "index-range",
    "index out of statically known range",
    "Index Range",
    default_enabled=True,
)
331+
325332
# This copy will not include any error codes defined later in the plugins.
326333
mypy_error_codes = error_codes.copy()

mypy/exprlength.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
"""Static expression length analysis utilities for mypy.
2+
3+
Provides helpers for statically determining the length of expressions,
4+
when possible.
5+
"""
6+
7+
from typing import List, Optional, Tuple
8+
from mypy.nodes import (
9+
Expression,
10+
ExpressionStmt,
11+
ForStmt,
12+
GeneratorExpr,
13+
IfStmt,
14+
ListComprehension,
15+
ListExpr,
16+
TupleExpr,
17+
SetExpr,
18+
WhileStmt,
19+
DictExpr,
20+
MemberExpr,
21+
StrExpr,
22+
StarExpr,
23+
BytesExpr,
24+
CallExpr,
25+
NameExpr,
26+
TryStmt,
27+
Block,
28+
AssignmentStmt,
29+
is_IntExpr_list,
30+
)
31+
from mypy.nodes import WithStmt, FuncDef, OverloadedFuncDef, ClassDef, GlobalDecl, NonlocalDecl
32+
from mypy.nodes import ARG_POS
33+
34+
def get_static_expr_length(expr: Expression, context: Optional[Block] = None) -> Optional[int]:
    """Try to statically determine the length of an expression.

    Supported expressions are list and tuple displays (including star items
    whose own length is known), str and bytes literals, list comprehensions
    and generator expressions over a single unfiltered sequence, calls to
    ``range`` with integer literal positional arguments and, when ``context``
    is given, names bound exactly once earlier in that block.

    Args:
        expr: The expression whose length is wanted.
        context: Optional block used to resolve ``NameExpr`` assignments.
            Without it names are never resolved.

    Returns:
        The length if it can be determined at type-check time, otherwise None.
    """
    # NOTE: currently only used for indexing but could be extended to flag
    # fun things like list.pop or to allow len([1, 2, 3]) to type check as Literal[3]

    # List, tuple literals (with possible star expressions)
    if isinstance(expr, (ListExpr, TupleExpr)):
        total = 0
        for item in expr.items:
            if not isinstance(item, StarExpr):
                total += 1
                continue
            # A star item contributes as many elements as the unpacked
            # iterable holds; if that is unknown so is the overall length.
            unpacked = get_static_expr_length(item.expr, context)
            if unpacked is None:
                return None
            total += unpacked
        return total
    elif isinstance(expr, SetExpr):
        # TODO: set expressions are more complicated, you need to know the
        # actual value of each item in order to confidently state its length
        pass
    elif isinstance(expr, DictExpr):
        # TODO: same as with sets, dicts are more complicated since you need
        # to know the specific value of each key, and ensure they don't collide
        pass
    # String or bytes literal
    elif isinstance(expr, (StrExpr, BytesExpr)):
        return len(expr.value)
    elif isinstance(expr, ListComprehension):
        # If the generator's length is known, the list's length is known
        return get_static_expr_length(expr.generator, context)
    elif isinstance(expr, GeneratorExpr):
        # With a single sequence and no filter conditions every item of the
        # sequence is yielded, so the lengths match.  ``condlists`` holds one
        # (possibly empty) list of conditions per sequence, hence any().
        if len(expr.sequences) == 1 and not any(expr.condlists):
            return get_static_expr_length(expr.sequences[0], context)
    # range() with constant arguments
    elif isinstance(expr, CallExpr):
        callee = expr.callee
        if isinstance(callee, NameExpr) and callee.fullname == "builtins.range":
            args = expr.args
            if is_IntExpr_list(args) and all(kind == ARG_POS for kind in expr.arg_kinds):
                if len(args) == 1:
                    # range(stop)
                    stop = args[0].value
                    return max(0, stop)
                elif len(args) == 2:
                    # range(start, stop)
                    start, stop = args[0].value, args[1].value
                    return max(0, stop - start)
                elif len(args) == 3:
                    # range(start, stop, step)
                    start, stop, step = args[0].value, args[1].value, args[2].value
                    if step == 0:
                        # range() rejects a zero step at runtime
                        return None
                    # Ceiling division towards the direction of travel
                    n = (stop - start + (step - (1 if step > 0 else -1))) // step
                    return max(0, n)
    # We have a big spaghetti monster of special case logic to resolve name expressions
    elif isinstance(expr, NameExpr):
        # Try to resolve the value of a local variable if possible
        if context is None:
            # Cannot resolve without context
            return None
        # Every expression bound to the name before its use; None marks a
        # binding whose value we cannot pin down.
        bound: List[Optional[Expression]] = []

        # Iterate thru all statements in the block
        for stmt in context.body:
            if isinstance(
                stmt,
                (
                    IfStmt,
                    ForStmt,
                    WhileStmt,
                    TryStmt,
                    WithStmt,
                    FuncDef,
                    OverloadedFuncDef,
                    ClassDef,
                ),
            ):
                # These statements complicate things and render the whole block useless
                return None
            elif isinstance(stmt, (GlobalDecl, NonlocalDecl)) and expr.name in stmt.names:
                # We cannot assure the value of a global or nonlocal
                return None
            elif stmt.line >= expr.line:
                # We can stop our analysis at the line where the name is used
                break
            elif isinstance(stmt, AssignmentStmt):
                # TODO Write logic to recursively unwrap statements to see
                # if any internal statements mess with our var
                # TODO augmented assignments, ``del`` and item/slice
                # assignments are not recognised as mutations yet

                # Several lvalues mean a chained assignment (a = b = value):
                # each of them receives the complete rvalue.
                for lval in stmt.lvalues:
                    bound.extend(_binding_sources(lval, stmt.rvalue, expr.name))
            elif isinstance(stmt, ExpressionStmt):
                if isinstance(stmt.expr, CallExpr):
                    callee = stmt.expr.callee
                    for arg in stmt.expr.args:
                        if isinstance(arg, NameExpr) and arg.name == expr.name:
                            # our var was passed to a function as an input,
                            # it could be mutated now
                            return None
                    if (
                        isinstance(callee, MemberExpr)
                        and isinstance(callee.expr, NameExpr)
                        and callee.expr.name == expr.name
                    ):
                        # a method was called on our var (e.g. append)
                        return None

        # For now, we only attempt to resolve the length
        # when the name was only ever assigned to once
        if len(bound) != 1:
            return None
        source = bound[0]
        if source is None:
            return None
        return get_static_expr_length(source, context)
    # Otherwise, cannot determine
    # Could add more cases (e.g. dicts, sets) in the future
    return None


def _binding_sources(
    lval: Expression, rvalue: Optional[Expression], name: str
) -> List[Optional[Expression]]:
    """Collect what one assignment target binds to ``name``.

    Returns one entry per occurrence of ``name`` inside ``lval``: the
    expression assigned to it, or None when the name is rebound but the
    assigned value cannot be matched up (starred targets, or unpacking from
    something other than a display of equal size).
    """
    if isinstance(lval, NameExpr):
        return [rvalue] if lval.name == name else []
    if isinstance(lval, StarExpr):
        # ``a, *name = ...`` collects a list whose size is not tracked here
        return _binding_sources(lval.expr, None, name)
    if isinstance(lval, (TupleExpr, ListExpr)):
        targets = lval.items
        sources: List[Optional[Expression]] = [None] * len(targets)
        if (
            isinstance(rvalue, (TupleExpr, ListExpr))
            and len(rvalue.items) == len(targets)
            and not any(isinstance(item, StarExpr) for item in targets)
            and not any(isinstance(item, StarExpr) for item in rvalue.items)
        ):
            # Plain element-wise unpacking: pair targets with values
            sources = list(rvalue.items)
        found: List[Optional[Expression]] = []
        for target, source in zip(targets, sources):
            found.extend(_binding_sources(target, source, name))
        return found
    # Attribute and item targets do not rebind a plain name
    return []

mypy/message_registry.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,3 +367,8 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
367367
TYPE_ALIAS_WITH_AWAIT_EXPRESSION: Final = ErrorMessage(
368368
"Await expression cannot be used within a type alias", codes.SYNTAX
369369
)
370+
371+
# Emitted when a literal index is outside the statically known length of the
# indexed expression; formatted with ``name`` and ``length``.
SEQUENCE_INDEX_OUT_OF_RANGE: Final = ErrorMessage(
    "Sequence index out of range: {name!r} only has {length} items",
    code=codes.INDEX_RANGE,
)

0 commit comments

Comments
 (0)