diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 73282c94be4e..de84a7a56a22 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -26,6 +26,7 @@ freshen_all_functions_type_vars, freshen_function_type_vars, ) +from mypy.exprlength import get_static_expr_length from mypy.infer import ArgumentInferContext, infer_function_type_arguments, infer_type_arguments from mypy.literals import literal from mypy.maptype import map_instance_to_supertype @@ -4450,6 +4451,7 @@ def visit_index_with_type( # Allow special forms to be indexed and used to create union types return self.named_type("typing._SpecialForm") else: + self.static_index_range_check(left_type, e, index) result, method_type = self.check_method_call_by_name( "__getitem__", left_type, @@ -4472,6 +4474,46 @@ def min_tuple_length(self, left: TupleType) -> int: return left.length() - 1 + unpack.type.min_len return left.length() - 1 + def static_index_range_check(self, left_type: Type, e: IndexExpr, index: Expression) -> None: + if isinstance(left_type, Instance) and left_type.type.fullname in ( + "builtins.list", + "builtins.tuple", + "builtins.str", + "builtins.bytes", + ): + idx_val = None + # Try to extract integer literal index + if isinstance(index, IntExpr): + idx_val = index.value + elif isinstance(index, UnaryExpr): + if index.op == "-": + operand = index.expr + if isinstance(operand, IntExpr): + idx_val = -operand.value + elif index.op == "+": + operand = index.expr + if isinstance(operand, IntExpr): + idx_val = operand.value + # Could add more cases (e.g. LiteralType) if desired + if idx_val is not None: + length = get_static_expr_length(e.base) + if length is not None: + # For negative indices, Python counts from the end + check_idx = idx_val + if check_idx < 0: + check_idx += length + if not (0 <= check_idx < length): + name = "" + if isinstance(e.base, NameExpr): + name = e.base.name + self.chk.fail( + message_registry.SEQUENCE_INDEX_OUT_OF_RANGE.format( + name=name or "", length=length + ), + e, + code=message_registry.SEQUENCE_INDEX_OUT_OF_RANGE.code, + ) + def visit_tuple_index_helper(self, left: TupleType, n: int) -> Type | None: unpack_index = find_unpack_in_list(left.items) if unpack_index is None: diff --git a/mypy/errorcodes.py b/mypy/errorcodes.py index bcfdbf6edc2b..875ffbca1b6e 100644 --- a/mypy/errorcodes.py +++ b/mypy/errorcodes.py @@ -326,5 +326,9 @@ def __hash__(self) -> int: default_enabled=False, ) +INDEX_RANGE: Final[ErrorCode] = ErrorCode( + "index-range", "index out of statically known range", "Index Range", True +) + # This copy will not include any error codes defined later in the plugins. mypy_error_codes = error_codes.copy() diff --git a/mypy/exprlength.py b/mypy/exprlength.py new file mode 100644 index 000000000000..286a2da563eb --- /dev/null +++ b/mypy/exprlength.py @@ -0,0 +1,174 @@ +"""Static expression length analysis utilities for mypy. + +Provides helpers for statically determining the length of expressions, +when possible. +""" + +from typing import List, Optional, Tuple + +from mypy.nodes import ( + ARG_POS, + AssignmentStmt, + Block, + BytesExpr, + CallExpr, + ClassDef, + DictExpr, + Expression, + ExpressionStmt, + ForStmt, + FuncDef, + GeneratorExpr, + GlobalDecl, + IfStmt, + ListComprehension, + ListExpr, + MemberExpr, + NameExpr, + NonlocalDecl, + OverloadedFuncDef, + SetExpr, + StarExpr, + StrExpr, + TryStmt, + TupleExpr, + WhileStmt, + WithStmt, + is_IntExpr_list, +) + + +def get_static_expr_length(expr: Expression, context: Optional[Block] = None) -> Optional[int]: + """Try to statically determine the length of an expression. + + Returns the length if it can be determined at type-check time, + otherwise returns None. + + If context is provided, will attempt to resolve NameExpr/Var assignments. + """ + # NOTE: currently only used for indexing but could be extended to flag + # fun things like list.pop or to allow len([1, 2, 3]) to type check as Literal[3] + + # List, tuple literals (with possible star expressions) + if isinstance(expr, (ListExpr, TupleExpr)): + stars = [get_static_expr_length(i, context) for i in expr.items if isinstance(i, StarExpr)] + if None not in stars: + # if there are no star expressions, or we know the + # length of them, we know the length of the expression + other = sum(not isinstance(i, StarExpr) for i in expr.items) + return other + sum(star for star in stars if star is not None) + elif isinstance(expr, SetExpr): + # TODO: set expressions are more complicated, you need to know the + # actual value of each item in order to confidently state its length + pass + elif isinstance(expr, DictExpr): + # TODO: same as with sets, dicts are more complicated since you need + # to know the specific value of each key, and ensure they don't collide + pass + # String or bytes literal + elif isinstance(expr, (StrExpr, BytesExpr)): + return len(expr.value) + elif isinstance(expr, ListComprehension): + # If the generator's length is known, the list's length is known + return get_static_expr_length(expr.generator, context) + elif isinstance(expr, GeneratorExpr): + # If there is only one sequence and no conditions, and we know + # the sequence length, we know the max number of items yielded + # from the genexp and can pass that info forward + if len(expr.sequences) == 1 and len(expr.condlists) == 0: + return get_static_expr_length(expr.sequences[0], context) + # range() with constant arguments + elif isinstance(expr, CallExpr): + callee = expr.callee + if isinstance(callee, NameExpr) and callee.fullname == "builtins.range": + args = expr.args + if is_IntExpr_list(args) and all(kind == ARG_POS for kind in expr.arg_kinds): + if len(args) == 1: + # range(stop) + stop = args[0].value + return max(0, stop) + elif len(args) == 2: + # range(start, stop) + start, stop = args[0].value, args[1].value + return max(0, stop - start) + elif len(args) == 3: + # range(start, stop, step) + start, stop, step = args[0].value, args[1].value, args[2].value + if step == 0: + return None + n = (stop - start + (step - (1 if step > 0 else -1))) // step + return max(0, n) + # We have a big spaghetti monster of special case logic to resolve name expressions + elif isinstance(expr, NameExpr): + # Try to resolve the value of a local variable if possible + if context is None: + # Cannot resolve without context + return None + assignments: List[Tuple[AssignmentStmt, int]] = [] + + # Iterate thru all statements in the block + for stmt in context.body: + if isinstance( + stmt, + ( + IfStmt, + ForStmt, + WhileStmt, + TryStmt, + WithStmt, + FuncDef, + OverloadedFuncDef, + ClassDef, + ), + ): + # These statements complicate things and render the whole block useless + return None + elif isinstance(stmt, (GlobalDecl, NonlocalDecl)) and expr.name in stmt.names: + # We cannot assure the value of a global or nonlocal + return None + elif stmt.line >= expr.line: + # We can stop our analysis at the line where the name is used + break + # Check for any assignments + elif isinstance(stmt, AssignmentStmt): + # First, exit if any assignment has a rhs expression that + # could mutate the name + # TODO Write logic to recursively unwrap statements to see + # if any internal statements mess with our var + + # Iterate thru lvalues in the assignment + for idx, lval in enumerate(stmt.lvalues): + # Check if any of them matches our variable + if isinstance(lval, NameExpr) and lval.name == expr.name: + assignments.append((stmt, idx)) + elif isinstance(stmt, ExpressionStmt): + if isinstance(stmt.expr, CallExpr): + callee = stmt.expr.callee + for arg in stmt.expr.args: + if isinstance(arg, NameExpr) and arg.name == expr.name: + # our var was passed to a function as an input, + # it could be mutated now + return None + if ( + isinstance(callee, MemberExpr) + and isinstance(callee.expr, NameExpr) + and callee.expr.name == expr.name + ): + return None + + # For now, we only attempt to resolve the length + # when the name was only ever assigned to once + if len(assignments) != 1: + return None + stmt, idx = assignments[0] + rvalue = stmt.rvalue + # If single lvalue, just use rvalue + if len(stmt.lvalues) == 1: + return get_static_expr_length(rvalue, context) + # If multiple lvalues, try to extract the corresponding value + elif isinstance(rvalue, (TupleExpr, ListExpr)): + if len(rvalue.items) == len(stmt.lvalues): + return get_static_expr_length(rvalue.items[idx], context) + # Otherwise, cannot determine + # Could add more cases (e.g. dicts, sets) in the future + return None diff --git a/mypy/message_registry.py b/mypy/message_registry.py index 09004322aee9..a6fcb9b4e267 100644 --- a/mypy/message_registry.py +++ b/mypy/message_registry.py @@ -370,3 +370,7 @@ def with_additional_msg(self, info: str) -> ErrorMessage: TYPE_ALIAS_WITH_AWAIT_EXPRESSION: Final = ErrorMessage( "Await expression cannot be used within a type alias", codes.SYNTAX ) + +SEQUENCE_INDEX_OUT_OF_RANGE = ErrorMessage( + "Sequence index out of range: {name!r} only has {length} items", code=codes.INDEX_RANGE +) diff --git a/mypy/nodes.py b/mypy/nodes.py index 040f3fc28dce..c74016721c67 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -8,7 +8,7 @@ from collections import defaultdict from collections.abc import Iterator, Sequence from enum import Enum, unique -from typing import TYPE_CHECKING, Any, Callable, Final, Optional, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Final, List, Optional, TypeVar, Union, cast from typing_extensions import TypeAlias as _TypeAlias, TypeGuard from mypy_extensions import trait @@ -1971,6 +1971,10 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: return visitor.visit_int_expr(self) +def is_IntExpr_list(items: List[Expression]) -> TypeGuard[List[IntExpr]]: + return all(isinstance(item, IntExpr) for item in items) + + # How mypy uses StrExpr and BytesExpr: # # b'x' -> BytesExpr diff --git a/mypyc/test-data/run-exceptions.test b/mypyc/test-data/run-exceptions.test index 1b180b933197..89ad2240431b 100644 --- a/mypyc/test-data/run-exceptions.test +++ b/mypyc/test-data/run-exceptions.test @@ -125,7 +125,7 @@ def g(b: bool) -> None: try: if b: x = [0] - x[1] + x[1] # type: ignore [index-range] else: raise Exception('hi') except: @@ -133,7 +133,7 @@ def g(b: bool) -> None: def r(x: int) -> None: if x == 0: - [0][1] + [0][1] # type: ignore [index-range] elif x == 1: raise Exception('hi') elif x == 2: @@ -263,7 +263,7 @@ Traceback (most recent call last): File "native.py", line 44, in i r(0) File "native.py", line 15, in r - [0][1] + [0][1] # type: ignore [index-range] IndexError: list index out of range == k == Traceback (most recent call last): @@ -281,7 +281,7 @@ Traceback (most recent call last): File "native.py", line 61, in k r(0) File "native.py", line 15, in r - [0][1] + [0][1] # type: ignore [index-range] IndexError: list index out of range == g == caught! @@ -330,7 +330,7 @@ Traceback (most recent call last): File "native.py", line 61, in k r(0) File "native.py", line 15, in r - [0][1] + [0][1] # type: ignore [index-range] IndexError: list index out of range == g == caught! @@ -371,7 +371,7 @@ def b(b1: int, b2: int) -> str: if b1 == 1: raise Exception('hi') elif b1 == 2: - [0][1] + [0][1] # type: ignore [index-range] elif b1 == 3: return 'try' except IndexError: diff --git a/mypyc/test-data/run-misc.test b/mypyc/test-data/run-misc.test index 129946a4c330..59a977f93088 100644 --- a/mypyc/test-data/run-misc.test +++ b/mypyc/test-data/run-misc.test @@ -25,7 +25,7 @@ def f(a: bool, b: bool) -> None: def g() -> None: try: - [0][1] + [0][1] # type: ignore [index-range] y = 1 except Exception: pass