Skip to content

Commit 3115185

Browse files
authored
Enable customizable call modifier logic (#736)
1 parent 3a20d20 commit 3115185

File tree

3 files changed

+8
-20
lines changed

3 files changed

+8
-20
lines changed

src/codemodder/codemods/import_modifier_codemod.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABCMeta, abstractmethod
2-
from typing import Callable, Mapping
2+
from typing import Mapping
33

44
import libcst as cst
55
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor
@@ -42,7 +42,7 @@ def update_simple_name(self, true_name, original_node, updated_node, new_args):
4242

4343

4444
class ImportModifierCodemod(LibcstResultTransformer, metaclass=ABCMeta):
45-
result_filter: Callable[[cst.CSTNode], bool] | None = None
45+
call_modifier: type[MappingImportedCallModifier] = MappingImportedCallModifier
4646

4747
@property
4848
def dependency(self) -> Dependency | None:
@@ -54,13 +54,12 @@ def mapping(self) -> Mapping[str, str]:
5454
pass
5555

5656
def transform_module_impl(self, tree: cst.Module) -> cst.Module:
57-
visitor = MappingImportedCallModifier(
57+
visitor = self.call_modifier(
5858
self.context,
5959
self.file_context,
6060
self.mapping,
6161
self.change_description,
6262
self.results,
63-
self.result_filter,
6463
)
6564
result_tree = visitor.transform_module(tree)
6665
self.file_context.codemod_changes.extend(visitor.changes_in_file)

src/codemodder/codemods/imported_call_modifier.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import abc
2-
from typing import Callable, Generic, Mapping, Sequence, Set, TypeVar, Union
2+
from typing import Generic, Mapping, Sequence, Set, TypeVar, Union
33

44
import libcst as cst
55
from libcst import matchers
66
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
7-
from libcst.metadata import PositionProvider
8-
from typing_extensions import override
7+
from libcst.metadata import ParentNodeProvider, PositionProvider
98

109
from codemodder.codemods.base_visitor import UtilsMixin
1110
from codemodder.codemods.utils_mixin import NameResolutionMixin
@@ -24,7 +23,7 @@ class ImportedCallModifier(
2423
UtilsMixin,
2524
metaclass=abc.ABCMeta,
2625
):
27-
METADATA_DEPENDENCIES = (PositionProvider,)
26+
METADATA_DEPENDENCIES = (ParentNodeProvider, PositionProvider)
2827

2928
def __init__(
3029
self,
@@ -33,7 +32,6 @@ def __init__(
3332
matching_functions: FunctionMatchType,
3433
change_description: str,
3534
results: list[Result] | None = None,
36-
result_filter: Callable[[cst.CSTNode], bool] | None = None,
3735
):
3836
VisitorBasedCodemodCommand.__init__(self, codemod_context)
3937
self.line_exclude = file_context.line_exclude
@@ -43,15 +41,6 @@ def __init__(
4341
self.changes_in_file: list[Change] = []
4442
self.results = results
4543
self.file_context = file_context
46-
self.result_filter = result_filter
47-
48-
@override
49-
def filter_by_result(self, node: cst.CSTNode) -> bool:
50-
return (
51-
self.result_filter(node)
52-
if self.result_filter
53-
else super().filter_by_result(node)
54-
)
5544

5645
def updated_args(self, original_args: Sequence[cst.Arg]):
5746
return original_args
@@ -82,7 +71,7 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call):
8271
if self.filter_by_path_includes_or_excludes(pos_to_match):
8372
true_name = self.find_base_name(original_node.func)
8473
if (
85-
self._is_direct_call_from_imported_module(original_node)
74+
self.is_direct_call_from_imported_module(original_node)
8675
and true_name
8776
and true_name in self.matching_functions
8877
):

src/codemodder/codemods/utils_mixin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def base_name_for_import(self, import_node, import_alias):
100100
# it is a from import
101101
return _get_name(import_node) + "." + get_full_name_for_node(import_alias.name)
102102

103-
def _is_direct_call_from_imported_module(
103+
def is_direct_call_from_imported_module(
104104
self, call: cst.Call
105105
) -> Optional[tuple[Union[cst.Import, cst.ImportFrom], cst.ImportAlias]]:
106106
for nodo in iterate_left_expressions(call):

0 commit comments

Comments
 (0)