|
| 1 | +from textwrap import dedent |
| 2 | +from typing import Optional |
| 3 | + |
| 4 | +import libcst as cst |
| 5 | +from libcst import matchers |
| 6 | + |
1 | 7 | from codemodder.codemods.libcst_transformer import ( |
2 | 8 | LibcstResultTransformer, |
3 | 9 | LibcstTransformerPipeline, |
4 | 10 | ) |
5 | | -from codemodder.codemods.semgrep import SemgrepRuleDetector |
6 | | -from codemodder.codemods.utils_mixin import NameResolutionMixin |
| 11 | +from codemodder.codemods.utils_mixin import NameAndAncestorResolutionMixin |
| 12 | +from codemodder.utils.utils import clean_simplestring |
7 | 13 | from core_codemods.api import CoreCodemod, Metadata, Reference, ReviewGuidance |
8 | 14 |
|
9 | 15 |
|
10 | | -class TempfileMktempTransformer(LibcstResultTransformer, NameResolutionMixin): |
| 16 | +class TempfileMktempTransformer( |
| 17 | + LibcstResultTransformer, NameAndAncestorResolutionMixin |
| 18 | +): |
11 | 19 | change_description = "Replaces `tempfile.mktemp` with `tempfile.mkstemp`." |
12 | 20 | _module_name = "tempfile" |
13 | 21 |
|
14 | | - def on_result_found(self, original_node, updated_node): |
15 | | - maybe_name = self.get_aliased_prefix_name(original_node, self._module_name) |
16 | | - if (maybe_name := maybe_name or self._module_name) == self._module_name: |
17 | | - self.add_needed_import(self._module_name) |
18 | | - self.remove_unused_import(original_node) |
19 | | - return self.update_call_target(updated_node, maybe_name, "mkstemp") |
| 22 | + def leave_SimpleStatementLine(self, original_node, updated_node): |
| 23 | + match original_node: |
| 24 | + case cst.SimpleStatementLine(body=[bsstmt]): |
| 25 | + return self.check_mktemp(original_node, bsstmt) |
| 26 | + return updated_node |
| 27 | + |
| 28 | + def check_mktemp( |
| 29 | + self, original_node: cst.SimpleStatementLine, bsstmt: cst.BaseSmallStatement |
| 30 | + ) -> cst.SimpleStatementLine | cst.FlattenSentinel: |
| 31 | + if maybe_tuple := self._is_assigned_to_mktemp(bsstmt): # type: ignore |
| 32 | + assign_name, call = maybe_tuple |
| 33 | + return self.report_and_change(call, assign_name) |
| 34 | + if maybe_tuple := self._mktemp_is_sink(bsstmt): |
| 35 | + wrapper_func_name, call = maybe_tuple |
| 36 | + return self.report_and_change(call, wrapper_func_name, assignment=False) |
| 37 | + return original_node |
| 38 | + |
| 39 | + def report_and_change( |
| 40 | + self, node: cst.Call, name: cst.Name, assignment=True |
| 41 | + ) -> cst.FlattenSentinel: |
| 42 | + self.report_change(node) |
| 43 | + self.add_needed_import(self._module_name) |
| 44 | + self.remove_unused_import(node) |
| 45 | + with_block = ( |
| 46 | + f"{name.value} = tf.name" if assignment else f"{name.value}(tf.name)" |
| 47 | + ) |
| 48 | + new_stmt = dedent( |
| 49 | + f""" |
| 50 | + with tempfile.NamedTemporaryFile({self._make_args(node)}) as tf: |
| 51 | + {with_block} |
| 52 | + """ |
| 53 | + ).rstrip() |
| 54 | + return cst.FlattenSentinel( |
| 55 | + [ |
| 56 | + cst.parse_statement(new_stmt), |
| 57 | + ] |
| 58 | + ) |
| 59 | + |
| 60 | + def _make_args(self, node: cst.Call) -> str: |
| 61 | + """Convert args passed to tempfile.mktemp() to string for args to tempfile.NamedTemporaryFile""" |
| 62 | + |
| 63 | + default = "delete=False" |
| 64 | + if not node.args: |
| 65 | + return default |
| 66 | + new_args = "" |
| 67 | + arg_keys = ("suffix", "prefix", "dir") |
| 68 | + for idx, arg in enumerate(node.args): |
| 69 | + cst.ensure_type(val := arg.value, cst.SimpleString) |
| 70 | + new_args += f'{arg_keys[idx]}="{clean_simplestring(val)}", ' |
| 71 | + return f"{new_args}{default}" |
| 72 | + |
| 73 | + def _is_assigned_to_mktemp( |
| 74 | + self, bsstmt: cst.BaseSmallStatement |
| 75 | + ) -> Optional[tuple[cst.Name, cst.Call]]: |
| 76 | + match bsstmt: |
| 77 | + case cst.Assign(value=value, targets=targets): |
| 78 | + maybe_value = self._is_mktemp_call(value) # type: ignore |
| 79 | + if maybe_value and all( |
| 80 | + map( |
| 81 | + lambda t: matchers.matches( |
| 82 | + t, matchers.AssignTarget(target=matchers.Name()) |
| 83 | + ), |
| 84 | + targets, # type: ignore |
| 85 | + ) |
| 86 | + ): |
| 87 | + # # Todo: handle multiple potential targets |
| 88 | + return (targets[0].target, maybe_value) |
| 89 | + case cst.AnnAssign(target=target, value=value): |
| 90 | + maybe_value = self._is_mktemp_call(value) # type: ignore |
| 91 | + if maybe_value and isinstance(target, cst.Name): # type: ignore |
| 92 | + return (target, maybe_value) |
| 93 | + return None |
| 94 | + |
| 95 | + def _is_mktemp_call(self, value) -> Optional[cst.Call]: |
| 96 | + match value: |
| 97 | + case cst.Call() if self.find_base_name(value.func) == "tempfile.mktemp": |
| 98 | + return value |
| 99 | + return None |
| 100 | + |
| 101 | + def _mktemp_is_sink( |
| 102 | + self, bsstmt: cst.BaseSmallStatement |
| 103 | + ) -> Optional[tuple[cst.Name, cst.Call]]: |
| 104 | + match bsstmt: |
| 105 | + case cst.Expr(value=cst.Call() as call): |
| 106 | + if not (args := call.args): |
| 107 | + return None |
| 108 | + |
| 109 | + # todo: handle more complex cases of mktemp in different arg pos |
| 110 | + match first_arg_call := args[0].value: |
| 111 | + case cst.Call(): |
| 112 | + if maybe_value := self._is_mktemp_call(first_arg_call): # type: ignore |
| 113 | + wrapper_func = call.func |
| 114 | + return (wrapper_func, maybe_value) |
| 115 | + return None |
20 | 116 |
|
21 | 117 |
|
22 | 118 | TempfileMktemp = CoreCodemod( |
23 | 119 | metadata=Metadata( |
24 | 120 | name="secure-tempfile", |
25 | 121 | summary="Upgrade and Secure Temp File Creation", |
26 | | - review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW, |
| 122 | + review_guidance=ReviewGuidance.MERGE_AFTER_REVIEW, |
27 | 123 | references=[ |
28 | 124 | Reference( |
29 | 125 | url="https://docs.python.org/3/library/tempfile.html#tempfile.mktemp" |
30 | 126 | ), |
31 | 127 | ], |
32 | 128 | ), |
33 | | - detector=SemgrepRuleDetector( |
34 | | - """ |
35 | | - rules: |
36 | | - - patterns: |
37 | | - - pattern: tempfile.mktemp(...) |
38 | | - - pattern-inside: | |
39 | | - import tempfile |
40 | | - ... |
41 | | - """ |
42 | | - ), |
43 | 129 | transformer=LibcstTransformerPipeline(TempfileMktempTransformer), |
44 | 130 | ) |
0 commit comments