Skip to content

Commit fbb3d3f

Browse files
authored
Use import call modifier pattern for url-sandbox (#750)
1 parent a37b023 commit fbb3d3f

File tree

4 files changed

+71
-170
lines changed

4 files changed

+71
-170
lines changed

src/codemodder/codemods/import_modifier_codemod.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from codemodder.codemods.api import LibcstResultTransformer
88
from codemodder.codemods.imported_call_modifier import ImportedCallModifier
9-
from codemodder.dependency import Dependency
9+
from codemodder.dependency import Dependency, Security
1010

1111

1212
class MappingImportedCallModifier(ImportedCallModifier[Mapping[str, str]]):
@@ -15,7 +15,7 @@ def update_attribute(self, true_name, original_node, updated_node, new_args):
1515
return updated_node
1616

1717
import_name = self.matching_functions[true_name]
18-
AddImportsVisitor.add_needed_import(self.context, import_name)
18+
self.add_import(import_name)
1919
RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node)
2020
return updated_node.with_changes(
2121
args=new_args,
@@ -30,7 +30,7 @@ def update_simple_name(self, true_name, original_node, updated_node, new_args):
3030
return updated_node
3131

3232
import_name = self.matching_functions[true_name]
33-
AddImportsVisitor.add_needed_import(self.context, import_name)
33+
self.add_import(import_name)
3434
RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node)
3535
return updated_node.with_changes(
3636
args=new_args,
@@ -40,6 +40,9 @@ def update_simple_name(self, true_name, original_node, updated_node, new_args):
4040
),
4141
)
4242

43+
def add_import(self, import_name):
44+
AddImportsVisitor.add_needed_import(self.context, import_name)
45+
4346

4447
class ImportModifierCodemod(LibcstResultTransformer, metaclass=ABCMeta):
4548
call_modifier: type[MappingImportedCallModifier] = MappingImportedCallModifier
@@ -67,3 +70,14 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module:
6770
self.add_dependency(dependency)
6871

6972
return result_tree
73+
74+
75+
class SecurityCallModifier(MappingImportedCallModifier):
76+
def add_import(self, import_name: str) -> None:
77+
AddImportsVisitor.add_needed_import(
78+
self.context, module=Security.requirement.name, obj=import_name
79+
)
80+
81+
82+
class SecurityImportModifierCodemod(ImportModifierCodemod, metaclass=ABCMeta):
83+
call_modifier: type[SecurityCallModifier] = SecurityCallModifier

src/codemodder/codemods/imported_call_modifier.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,9 @@ def update_simple_name(
6868
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call):
6969
pos_to_match = self.node_position(original_node)
7070
line_number = pos_to_match.start.line
71-
if self.filter_by_path_includes_or_excludes(pos_to_match):
71+
if self.node_is_selected(
72+
original_node
73+
) and self.filter_by_path_includes_or_excludes(pos_to_match):
7274
true_name = self.find_base_name(original_node.func)
7375
if (
7476
self.is_direct_call_from_imported_module(original_node)

src/core_codemods/url_sandbox.py

Lines changed: 29 additions & 159 deletions
Original file line numberDiff line numberDiff line change
@@ -1,156 +1,26 @@
1-
from typing import List, Optional, Union
1+
from functools import cached_property
22

3-
import libcst as cst
4-
from libcst import CSTNode, matchers
5-
from libcst.codemod import CodemodContext, ContextAwareVisitor
6-
from libcst.codemod.visitors import AddImportsVisitor, ImportItem
7-
from libcst.metadata import PositionProvider, ScopeProvider
8-
9-
from codemodder.codemods.base_visitor import UtilsMixin
10-
from codemodder.codemods.libcst_transformer import (
11-
LibcstResultTransformer,
12-
LibcstTransformerPipeline,
13-
)
3+
from codemodder.codemods.import_modifier_codemod import SecurityImportModifierCodemod
4+
from codemodder.codemods.libcst_transformer import LibcstTransformerPipeline
145
from codemodder.codemods.semgrep import SemgrepRuleDetector
15-
from codemodder.codemods.transformations.remove_unused_imports import (
16-
RemoveUnusedImportsCodemod,
17-
)
18-
from codemodder.codemods.utils import ReplaceNodes
19-
from codemodder.codetf import Change
20-
from codemodder.dependency import Security
21-
from codemodder.file_context import FileContext
6+
from codemodder.dependency import Dependency, Security
227
from core_codemods.api import CoreCodemod, Metadata, Reference, ReviewGuidance
238

24-
replacement_import = "safe_requests"
25-
269

27-
class UrlSandboxTransformer(LibcstResultTransformer):
10+
class UrlSandboxTransformer(SecurityImportModifierCodemod):
2811
change_description = "Switch use of requests for security.safe_requests"
29-
METADATA_DEPENDENCIES = (PositionProvider, ScopeProvider)
30-
adds_dependency = True
31-
32-
def transform_module_impl(self, tree: cst.Module) -> cst.Module:
33-
# we first gather all the nodes we want to change together with their replacements
34-
find_requests_visitor = FindRequestCallsAndImports(
35-
self.context,
36-
self.file_context,
37-
self.file_context.results,
38-
)
39-
tree.visit(find_requests_visitor)
40-
if find_requests_visitor.nodes_to_change:
41-
self.file_context.codemod_changes.extend(
42-
find_requests_visitor.changes_in_file
43-
)
44-
new_tree = tree.visit(ReplaceNodes(find_requests_visitor.nodes_to_change))
45-
self.add_dependency(Security)
46-
# if it finds any request.get(...), try to remove the imports
47-
if any(
48-
(
49-
matchers.matches(n, matchers.Call())
50-
for n in find_requests_visitor.nodes_to_change
51-
)
52-
):
53-
new_tree = AddImportsVisitor(
54-
self.context,
55-
[ImportItem(Security.name, replacement_import, None, 0)],
56-
).transform_module(new_tree)
57-
new_tree = RemoveUnusedImportsCodemod(self.context).transform_module(
58-
new_tree
59-
)
60-
return new_tree
61-
return tree
62-
63-
64-
class FindRequestCallsAndImports(ContextAwareVisitor, UtilsMixin):
65-
METADATA_DEPENDENCIES = (ScopeProvider,)
66-
67-
def __init__(
68-
self, codemod_context: CodemodContext, file_context: FileContext, results
69-
):
70-
self.nodes_to_change: dict[
71-
cst.CSTNode, Union[cst.CSTNode, cst.FlattenSentinel, cst.RemovalSentinel]
72-
] = {}
73-
self.changes_in_file: List[Change] = []
74-
self.file_context = file_context
75-
ContextAwareVisitor.__init__(self, codemod_context)
76-
UtilsMixin.__init__(
77-
self,
78-
results=results,
79-
line_include=file_context.line_include,
80-
line_exclude=file_context.line_exclude,
81-
)
8212

83-
def leave_Call(self, original_node: cst.Call):
84-
if not self.node_is_selected(original_node):
85-
return
13+
@cached_property
14+
def mapping(self) -> dict[str, str]:
15+
"""Build a mapping of functions to their safe_requests imports"""
16+
_matching_functions: dict[str, str] = {
17+
"requests.get": "safe_requests",
18+
}
19+
return _matching_functions
8620

87-
line_number = self.node_position(original_node).start.line
88-
match original_node.args[0].value:
89-
case cst.SimpleString():
90-
return
91-
92-
match original_node:
93-
# case get(...)
94-
case cst.Call(func=cst.Name()):
95-
# find if get(...) comes from an from requests import get
96-
match self.find_single_assignment(original_node):
97-
case cst.ImportFrom() as node:
98-
self.nodes_to_change.update(
99-
{
100-
node: cst.ImportFrom(
101-
module=cst.Attribute(
102-
value=cst.Name(Security.name),
103-
attr=cst.Name(replacement_import),
104-
),
105-
names=node.names,
106-
)
107-
}
108-
)
109-
self.changes_in_file.append(
110-
Change(
111-
lineNumber=line_number,
112-
description=UrlSandboxTransformer.change_description,
113-
findings=self.file_context.get_findings_for_location(
114-
line_number
115-
),
116-
)
117-
)
118-
119-
# case req.get(...)
120-
case _:
121-
self.nodes_to_change.update(
122-
{
123-
original_node: cst.Call(
124-
func=cst.parse_expression(replacement_import + ".get"),
125-
args=original_node.args,
126-
)
127-
}
128-
)
129-
self.changes_in_file.append(
130-
Change(
131-
lineNumber=line_number,
132-
description=UrlSandboxTransformer.change_description,
133-
findings=self.file_context.get_findings_for_location(
134-
line_number
135-
),
136-
)
137-
)
138-
139-
def _find_assignments(self, node: CSTNode):
140-
"""
141-
Given a MetadataWrapper and a CSTNode representing an access, find all the possible assignments that it refers.
142-
"""
143-
scope = self.get_metadata(ScopeProvider, node)
144-
return next(iter(scope.accesses[node]))._Access__assignments
145-
146-
def find_single_assignment(self, node: CSTNode) -> Optional[CSTNode]:
147-
"""
148-
Given a MetadataWrapper and a CSTNode representing an access, find if there is a single assignment that it refers to.
149-
"""
150-
assignments = self._find_assignments(node)
151-
if len(assignments) == 1:
152-
return next(iter(assignments)).node
153-
return None
21+
@property
22+
def dependency(self) -> Dependency:
23+
return Security
15424

15525

15626
UrlSandbox = CoreCodemod(
@@ -174,20 +44,20 @@ def find_single_assignment(self, node: CSTNode) -> Optional[CSTNode]:
17444
),
17545
detector=SemgrepRuleDetector(
17646
"""
177-
rules:
178-
- id: url-sandbox
179-
message: Unbounded URL creation
180-
severity: WARNING
181-
languages:
182-
- python
183-
pattern-either:
184-
- patterns:
185-
- pattern: requests.get(...)
186-
- pattern-not: requests.get("...")
187-
- pattern-inside: |
188-
import requests
189-
...
190-
"""
47+
rules:
48+
- id: url-sandbox
49+
message: Unbounded URL creation
50+
severity: WARNING
51+
languages:
52+
- python
53+
pattern-either:
54+
- patterns:
55+
- pattern: requests.get(...)
56+
- pattern-not: requests.get("...")
57+
- pattern-inside: |
58+
import requests
59+
...
60+
"""
19161
),
19262
transformer=LibcstTransformerPipeline(UrlSandboxTransformer),
19363
)

tests/codemods/test_url_sandbox.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,10 @@ def test_from_requests(self, add_dependency, tmpdir):
4040
var = "hello"
4141
"""
4242
expected = """
43-
from security.safe_requests import get
43+
from security import safe_requests
4444
4545
url = input()
46-
get(url)
46+
safe_requests.get(url)
4747
var = "hello"
4848
"""
4949
self.run_and_assert(tmpdir, input_code, expected)
@@ -160,10 +160,10 @@ def test_from_requests_with_alias(self, add_dependency, tmpdir):
160160
var = "hello"
161161
"""
162162
expected = """
163-
from security.safe_requests import get as got
163+
from security import safe_requests
164164
165165
url = input()
166-
got(url)
166+
safe_requests.get(url)
167167
var = "hello"
168168
"""
169169
self.run_and_assert(tmpdir, input_code, expected)
@@ -198,6 +198,24 @@ def test_ignore_hardcoded(self, _, tmpdir):
198198

199199
self.run_and_assert(tmpdir, input_code, expected)
200200

201+
def test_ignore_hardcoded_but_not_all(self, _, tmpdir):
202+
input_code = """
203+
import requests
204+
205+
requests.get("www.google.com")
206+
url = input()
207+
requests.get(url)
208+
"""
209+
expected = """
210+
import requests
211+
from security import safe_requests
212+
213+
requests.get("www.google.com")
214+
url = input()
215+
safe_requests.get(url)
216+
"""
217+
self.run_and_assert(tmpdir, input_code, expected)
218+
201219
def test_ignore_hardcoded_from_global_variable(self, _, tmpdir):
202220
expected = (
203221
input_code
@@ -262,9 +280,6 @@ def foo():
262280

263281
self.run_and_assert(tmpdir, input_code, expected)
264282

265-
@pytest.mark.xfail(
266-
reason="Does not properly handle 'from' imports with multiple names"
267-
)
268283
def test_multiple_imports(self, add_dependency, tmpdir):
269284
input_code = """
270285
from requests import Response, Timeout, get

0 commit comments

Comments
 (0)