55
66import mock
77
8+ from codemodder import registry
9+ from codemodder .codemods .api import BaseCodemod
810from codemodder .context import CodemodExecutionContext
911from codemodder .diff import create_diff
1012from codemodder .providers import load_providers
11- from codemodder .registry import CodemodCollection , CodemodRegistry
12- from codemodder .semgrep import run as semgrep_run
13+
14+
15+ def validate_codemod_registration (codemod_id : str ) -> BaseCodemod :
16+ codemod_registry = registry .load_registered_codemods ()
17+ try :
18+ return codemod_registry .match_codemods (codemod_include = [codemod_id ])[0 ]
19+ except IndexError as exc :
20+ raise IndexError (
21+ "You must register the codemod to a CodemodCollection."
22+ ) from exc
1323
1424
1525class DiffError (Exception ):
@@ -32,9 +42,14 @@ class BaseCodemodTest:
3242 def file_extension (self ) -> str :
3343 return "py"
3444
45+ @classmethod
46+ def setup_class (cls ):
47+ codemod_id = (
48+ cls .codemod ().id if isinstance (cls .codemod , type ) else cls .codemod .id
49+ )
50+ cls .codemod = validate_codemod_registration (codemod_id )
51+
3552 def setup_method (self ):
36- if isinstance (self .codemod , type ):
37- self .codemod = self .codemod ()
3853 self .changeset = []
3954
4055 def run_and_assert (
@@ -126,25 +141,6 @@ def run_and_assert_filepath(
126141 )
127142
128143
129- class BaseSemgrepCodemodTest (BaseCodemodTest ):
130- @classmethod
131- def setup_class (cls ):
132- collection = CodemodCollection (
133- origin = "pixee" ,
134- codemods = [cls .codemod ],
135- )
136- cls .registry = CodemodRegistry ()
137- cls .registry .add_codemod_collection (collection )
138-
139- def results_by_id_filepath (self , input_code , file_path ):
140- with open (file_path , "w" , encoding = "utf-8" ) as tmp_file :
141- tmp_file .write (dedent (input_code ))
142-
143- name = self .codemod .name
144- results = self .registry .match_codemods (codemod_include = [name ])
145- return semgrep_run (self .execution_context , results [0 ].yaml_files )
146-
147-
148144class BaseDjangoCodemodTest (BaseCodemodTest ):
149145 def create_dir_structure (self , tmpdir ):
150146 django_root = Path (tmpdir ) / "mysite"
0 commit comments