|
| 1 | +"""Static expression length analysis utilities for mypy. |
| 2 | +
|
| 3 | +Provides helpers for statically determining the length of expressions, |
| 4 | +when possible. |
| 5 | +""" |
| 6 | + |
| 7 | +from typing import List, Optional, Tuple |
| 8 | +from mypy.nodes import ( |
| 9 | + Expression, |
| 10 | + ExpressionStmt, |
| 11 | + ForStmt, |
| 12 | + GeneratorExpr, |
| 13 | + IfStmt, |
| 14 | + ListComprehension, |
| 15 | + ListExpr, |
| 16 | + TupleExpr, |
| 17 | + SetExpr, |
| 18 | + WhileStmt, |
| 19 | + DictExpr, |
| 20 | + MemberExpr, |
| 21 | + StrExpr, |
| 22 | + StarExpr, |
| 23 | + BytesExpr, |
| 24 | + CallExpr, |
| 25 | + NameExpr, |
| 26 | + TryStmt, |
| 27 | + Block, |
| 28 | + AssignmentStmt, |
| 29 | + is_IntExpr_list, |
| 30 | +) |
| 31 | +from mypy.nodes import WithStmt, FuncDef, OverloadedFuncDef, ClassDef, GlobalDecl, NonlocalDecl |
| 32 | +from mypy.nodes import ARG_POS |
| 33 | + |
| 34 | +def get_static_expr_length(expr: Expression, context: Optional[Block] = None) -> Optional[int]: |
| 35 | + """Try to statically determine the length of an expression. |
| 36 | +
|
| 37 | + Returns the length if it can be determined at type-check time, |
| 38 | + otherwise returns None. |
| 39 | +
|
| 40 | + If context is provided, will attempt to resolve NameExpr/Var assignments. |
| 41 | + """ |
| 42 | + # NOTE: currently only used for indexing but could be extended to flag |
| 43 | + # fun things like list.pop or to allow len([1, 2, 3]) to type check as Literal[3] |
| 44 | + |
| 45 | + # List, tuple literals (with possible star expressions) |
| 46 | + if isinstance(expr, (ListExpr, TupleExpr)): |
| 47 | + # if there are no star expressions, or we know the length of them, |
| 48 | + # we know the length of the expression |
| 49 | + stars = [get_static_expr_length(i, context) for i in expr.items if isinstance(i, StarExpr)] |
| 50 | + other = sum(not isinstance(i, StarExpr) for i in expr.items) |
| 51 | + return other + sum(star for star in stars if star is not None) |
| 52 | + elif isinstance(expr, SetExpr): |
| 53 | + # TODO: set expressions are more complicated, you need to know the |
| 54 | + # actual value of each item in order to confidently state its length |
| 55 | + pass |
| 56 | + elif isinstance(expr, DictExpr): |
| 57 | + # TODO: same as with sets, dicts are more complicated since you need |
| 58 | + # to know the specific value of each key, and ensure they don't collide |
| 59 | + pass |
| 60 | + # String or bytes literal |
| 61 | + elif isinstance(expr, (StrExpr, BytesExpr)): |
| 62 | + return len(expr.value) |
| 63 | + elif isinstance(expr, ListComprehension): |
| 64 | + # If the generator's length is known, the list's length is known |
| 65 | + return get_static_expr_length(expr.generator, context) |
| 66 | + elif isinstance(expr, GeneratorExpr): |
| 67 | + # If there is only one sequence and no conditions, and we know |
| 68 | + # the sequence length, we know the max number of items yielded |
| 69 | + # from the genexp and can pass that info forward |
| 70 | + if len(expr.sequences) == 1 and len(expr.condlists) == 0: |
| 71 | + return get_static_expr_length(expr.sequences[0], context) |
| 72 | + # range() with constant arguments |
| 73 | + elif isinstance(expr, CallExpr): |
| 74 | + callee = expr.callee |
| 75 | + if isinstance(callee, NameExpr) and callee.fullname == "builtins.range": |
| 76 | + args = expr.args |
| 77 | + if is_IntExpr_list(args) and all(kind == ARG_POS for kind in expr.arg_kinds): |
| 78 | + if len(args) == 1: |
| 79 | + # range(stop) |
| 80 | + stop = args[0].value |
| 81 | + return max(0, stop) |
| 82 | + elif len(args) == 2: |
| 83 | + # range(start, stop) |
| 84 | + start, stop = args[0].value, args[1].value |
| 85 | + return max(0, stop - start) |
| 86 | + elif len(args) == 3: |
| 87 | + # range(start, stop, step) |
| 88 | + start, stop, step = args[0].value, args[1].value, args[2].value |
| 89 | + if step == 0: |
| 90 | + return None |
| 91 | + n = (stop - start + (step - (1 if step > 0 else -1))) // step |
| 92 | + return max(0, n) |
| 93 | + # We have a big spaghetti monster of special case logic to resolve name expressions |
| 94 | + elif isinstance(expr, NameExpr): |
| 95 | + # Try to resolve the value of a local variable if possible |
| 96 | + if context is None: |
| 97 | + # Cannot resolve without context |
| 98 | + return None |
| 99 | + assignments: List[Tuple[AssignmentStmt, int]] = [] |
| 100 | + |
| 101 | + # Iterate thru all statements in the block |
| 102 | + for stmt in context.body: |
| 103 | + if isinstance(stmt, (IfStmt, ForStmt, WhileStmt, TryStmt, WithStmt, FuncDef, OverloadedFuncDef, ClassDef)): |
| 104 | + # These statements complicate things and render the whole block useless |
| 105 | + return None |
| 106 | + elif isinstance(stmt, (GlobalDecl, NonlocalDecl)) and expr.name in stmt.names: |
| 107 | + # We cannot assure the value of a global or nonlocal |
| 108 | + return None |
| 109 | + elif stmt.line >= expr.line: |
| 110 | + # We can stop our analysis at the line where the name is used |
| 111 | + break |
| 112 | + # Check for any assignments |
| 113 | + elif isinstance(stmt, AssignmentStmt): |
| 114 | + # First, exit if any assignment has a rhs expression that |
| 115 | + # could mutate the name |
| 116 | + # TODO Write logic to recursively unwrap statements to see |
| 117 | + # if any internal statements mess with our var |
| 118 | + |
| 119 | + # Iterate thru lvalues in the assignment |
| 120 | + for idx, lval in enumerate(stmt.lvalues): |
| 121 | + # Check if any of them matches our variable |
| 122 | + if isinstance(lval, NameExpr) and lval.name == expr.name: |
| 123 | + assignments.append((stmt, idx)) |
| 124 | + elif isinstance(stmt, ExpressionStmt): |
| 125 | + if isinstance(stmt.expr, CallExpr): |
| 126 | + callee = stmt.expr.callee |
| 127 | + for arg in stmt.expr.args: |
| 128 | + if isinstance(arg, NameExpr) and arg.name == expr.name: |
| 129 | + # our var was passed to a function as an input, |
| 130 | + # it could be mutated now |
| 131 | + return None |
| 132 | + if ( |
| 133 | + isinstance(callee, MemberExpr) |
| 134 | + and isinstance(callee.expr, NameExpr) |
| 135 | + and callee.expr.name == expr.name |
| 136 | + ): |
| 137 | + return None |
| 138 | + |
| 139 | + # For now, we only attempt to resolve the length |
| 140 | + # when the name was only ever assigned to once |
| 141 | + if len(assignments) != 1: |
| 142 | + return None |
| 143 | + stmt, idx = assignments[0] |
| 144 | + rvalue = stmt.rvalue |
| 145 | + # If single lvalue, just use rvalue |
| 146 | + if len(stmt.lvalues) == 1: |
| 147 | + return get_static_expr_length(rvalue, context) |
| 148 | + # If multiple lvalues, try to extract the corresponding value |
| 149 | + elif isinstance(rvalue, (TupleExpr, ListExpr)): |
| 150 | + if len(rvalue.items) == len(stmt.lvalues): |
| 151 | + return get_static_expr_length(rvalue.items[idx], context) |
| 152 | + # Otherwise, cannot determine |
| 153 | + # Could add more cases (e.g. dicts, sets) in the future |
| 154 | + return None |
0 commit comments