Skip to content

Commit b163347

Browse files
committed
refactor: use gre_ast to determine language
1 parent 34774cf commit b163347

File tree

7 files changed

+117
-144
lines changed

7 files changed

+117
-144
lines changed

src/mutahunter/core/analyzer.py

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
import shlex
1+
import os
22
import subprocess
33
import xml.etree.ElementTree as ET
44
from importlib import resources
5+
from shlex import split
56
from typing import Any, Dict, List
67

8+
from grep_ast import filename_to_lang
79
from tree_sitter_languages import get_language, get_parser
810

911

@@ -17,7 +19,6 @@ def __init__(self, config: Dict[str, Any]) -> None:
1719
"""
1820
super().__init__()
1921
self.config = config
20-
self.tree_sitter_parser = get_parser(self.config["language"])
2122

2223
if self.config["coverage_type"] == "cobertura":
2324
self.file_lines_executed = self.parse_coverage_report_cobertura()
@@ -47,9 +48,9 @@ def parse_coverage_report_lcov(self) -> Dict[str, List[int]]:
4748
result[current_file] = []
4849
elif line.startswith("DA:") and current_file:
4950
parts = line.strip().split(":")[1].split(",")
50-
line_number = int(parts[0])
5151
hits = int(parts[1])
5252
if hits > 0:
53+
line_number = int(parts[0])
5354
result[current_file].append(line_number)
5455
elif line.startswith("end_of_record"):
5556
current_file = None
@@ -117,28 +118,28 @@ def dry_run(self) -> None:
117118
Exception: If any tests fail during the dry run.
118119
"""
119120
test_command = self.config["test_command"]
120-
result = subprocess.run(shlex.split(test_command))
121+
result = subprocess.run(split(test_command), cwd=os.getcwd())
121122
if result.returncode != 0:
122123
raise Exception(
123124
"Tests failed. Please ensure all tests pass before running mutation testing."
124125
)
125126

126127
def get_covered_function_blocks(
127-
self, executed_lines: List[int], filename: str
128+
self, executed_lines: List[int], source_file_path: str
128129
) -> List[Any]:
129130
"""
130-
Retrieves covered function blocks based on executed lines and filename.
131+
Retrieves covered function blocks based on executed lines and source_file_path.
131132
132133
Args:
133134
executed_lines (List[int]): List of executed line numbers.
134-
filename (str): The name of the file being analyzed.
135+
source_file_path (str): The name of the file being analyzed.
135136
136137
Returns:
137138
List[Any]: A list of covered function blocks.
138139
"""
139140
covered_function_blocks = []
140141
covered_function_block_executed_lines = []
141-
function_blocks = self.get_function_blocks(filename=filename)
142+
function_blocks = self.get_function_blocks(source_file_path=source_file_path)
142143
for function_block in function_blocks:
143144
start_point = function_block.start_point
144145
end_point = function_block.end_point
@@ -152,16 +153,15 @@ def get_covered_function_blocks(
152153
if any(
153154
line in executed_lines for line in range(start_line + 1, end_line + 1)
154155
): # start_line + 1 to exclude the function definition line
155-
function_executed_lines = []
156-
for line in range(start_line, end_line + 1):
157-
function_executed_lines.append(line - start_line + 1)
158-
156+
function_executed_lines = [
157+
line - start_line + 1 for line in range(start_line, end_line + 1)
158+
]
159159
covered_function_blocks.append(function_block)
160160
covered_function_block_executed_lines.append(function_executed_lines)
161161

162162
return covered_function_blocks, covered_function_block_executed_lines
163163

164-
def get_function_blocks(self, filename: str) -> List[Any]:
164+
def get_function_blocks(self, source_file_path: str) -> List[Any]:
165165
"""
166166
Retrieves function blocks from a given file.
167167
@@ -171,12 +171,13 @@ def get_function_blocks(self, filename: str) -> List[Any]:
171171
Returns:
172172
List[Any]: A list of function block nodes.
173173
"""
174-
with open(filename, "rb") as f:
174+
with open(source_file_path, "rb") as f:
175175
source_code = f.read()
176-
function_blocks = self.find_function_blocks_nodes(source_code=source_code)
177-
return function_blocks
176+
return self.find_function_blocks_nodes(
177+
source_file_path=source_file_path, source_code=source_code
178+
)
178179

179-
def check_syntax(self, source_code: str) -> bool:
180+
def check_syntax(self, source_file_path: str, source_code: str) -> bool:
180181
"""
181182
Checks the syntax of the provided source code.
182183
@@ -186,11 +187,16 @@ def check_syntax(self, source_code: str) -> bool:
186187
Returns:
187188
bool: True if the syntax is correct, False otherwise.
188189
"""
189-
tree = self.tree_sitter_parser.parse(bytes(source_code, "utf8"))
190+
191+
lang = filename_to_lang(source_file_path)
192+
parser = get_parser(lang)
193+
tree = parser.parse(bytes(source_code, "utf8"))
190194
root_node = tree.root_node
191195
return not root_node.has_error
192196

193-
def find_function_blocks_nodes(self, source_code: bytes) -> List[Any]:
197+
def find_function_blocks_nodes(
198+
self, source_file_path: str, source_code: bytes
199+
) -> List[Any]:
194200
"""
195201
Finds function block nodes in the provided source code.
196202
@@ -200,10 +206,11 @@ def find_function_blocks_nodes(self, source_code: bytes) -> List[Any]:
200206
Returns:
201207
List[Any]: A list of function block nodes.
202208
"""
203-
tree = self.tree_sitter_parser.parse(source_code)
204-
function_blocks = []
205-
lang = self.config["language"]
209+
lang = filename_to_lang(source_file_path)
210+
parser = get_parser(lang)
206211
language = get_language(lang)
212+
213+
tree = parser.parse(source_code)
207214
# Load the tags queries
208215
try:
209216
scm_fname = resources.files(__package__).joinpath(
@@ -219,12 +226,8 @@ def find_function_blocks_nodes(self, source_code: bytes) -> List[Any]:
219226
captures = query.captures(tree.root_node)
220227

221228
captures = list(captures)
222-
for node, tag in captures:
223-
if tag == "definition.function" or tag == "definition.method":
224-
function_blocks.append(node)
225-
# start_byte = node.start_byte
226-
# end_byte = node.end_byte
227-
# func_code = source_code[start_byte:end_byte].decode("utf-8")
228-
# print(f"Function block code: {func_code}")
229-
230-
return function_blocks
229+
return [
230+
node
231+
for node, tag in captures
232+
if tag in ["definition.function", "definition.method"]
233+
]

src/mutahunter/core/hunter.py

Lines changed: 26 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -28,46 +28,11 @@ def __init__(self, config: Dict[str, Any]) -> None:
2828
- only_mutate_file_paths (List[str]): List of specific files to mutate.
2929
"""
3030
self.config: Dict[str, Any] = config
31-
self.config["language"] = self.determine_language(config["test_file_path"])
3231
self.mutants: List[Mutant] = []
3332
self.mutant_report = MutantReport(config=self.config)
3433
self.analyzer = Analyzer(self.config)
3534
self.test_runner = TestRunner()
3635

37-
def determine_language(self, filename: str) -> str:
38-
"""
39-
Determines the programming language based on the file extension. For Tree-Sitter language detection.
40-
41-
Args:
42-
filename (str): The filename to determine the language from.
43-
44-
Returns:
45-
str: The programming language corresponding to the file extension.
46-
47-
Raises:
48-
ValueError: If the file extension is not supported.
49-
"""
50-
ext = filename.split(".")[-1]
51-
language_mappings = {
52-
"py": "python",
53-
"java": "java",
54-
"js": "javascript",
55-
"ts": "typescript",
56-
"c": "c",
57-
"cpp": "cpp",
58-
"rs": "rust",
59-
"go": "go",
60-
"php": "php",
61-
"rb": "ruby",
62-
"swift": "swift",
63-
"kt": "kotlin",
64-
"tsx": "tsx",
65-
"ts": "typescript",
66-
}
67-
if ext not in language_mappings:
68-
raise ValueError(f"Unsupported file extension: {ext}")
69-
return language_mappings[ext]
70-
7136
def run(self) -> None:
7237
"""
7338
Executes the mutation testing process from start to finish.
@@ -109,19 +74,14 @@ def should_skip_file(self, filename: str) -> bool:
10974
logger.error(f"File {file_path} does not exist.")
11075
raise FileNotFoundError(f"File {file_path} does not exist.")
11176
# NOTE: Only mutate the files specified in the config.
112-
for file_path in self.config["only_mutate_file_paths"]:
113-
if file_path == filename:
114-
return False
77+
return all(
78+
file_path != filename
79+
for file_path in self.config["only_mutate_file_paths"]
80+
)
81+
if filename in self.config["exclude_files"]:
11582
return True
116-
else:
117-
if filename in self.config["exclude_files"]:
118-
return True
119-
if any(
120-
keyword in filename
121-
for keyword in ["test/", "tests/", "test_", "_test", ".test"]
122-
):
123-
return True
124-
return False
83+
test_keywords = ["test/", "tests/", "test_", "_test", ".test"]
84+
return any(keyword in filename for keyword in test_keywords)
12585

12686
def generate_mutations(self) -> Generator[Dict[str, Any], None, None]:
12787
"""
@@ -131,47 +91,41 @@ def generate_mutations(self) -> Generator[Dict[str, Any], None, None]:
13191
Generator[Dict[str, Any], None, None]: Dictionary containing mutation details.
13292
"""
13393
all_covered_files = self.analyzer.file_lines_executed.keys()
134-
for filename in tqdm(all_covered_files):
135-
if self.should_skip_file(filename):
94+
for covered_file_path in tqdm(all_covered_files):
95+
if self.should_skip_file(covered_file_path):
13696
continue
13797
covered_function_blocks, covered_function_block_executed_lines = (
13898
self.analyzer.get_covered_function_blocks(
139-
executed_lines=self.analyzer.file_lines_executed[filename],
140-
filename=filename,
99+
executed_lines=self.analyzer.file_lines_executed[covered_file_path],
100+
source_file_path=covered_file_path,
141101
)
142102
)
143103
logger.info(
144-
f"Total function blocks to mutate: {len(covered_function_blocks)}"
104+
f"Total function blocks to mutate for {covered_file_path}: {len(covered_function_blocks)}"
145105
)
146106
if not covered_function_blocks:
147107
continue
148108

149-
with open(filename, "rb") as f:
150-
source_code = f.read()
151-
152109
for function_block, executed_lines in zip(
153110
covered_function_blocks,
154111
covered_function_block_executed_lines,
155112
):
156113
start_byte = function_block.start_byte
157114
end_byte = function_block.end_byte
158-
function_block_source_code = source_code[start_byte:end_byte].decode(
159-
"utf-8"
160-
)
161115

162116
mutant_generator = MutantGenerator(
163117
config=self.config,
164118
executed_lines=executed_lines,
165119
cov_files=list(all_covered_files),
166120
test_file_path=self.config["test_file_path"],
167-
filename=filename,
168-
function_block_source_code=function_block_source_code,
169-
language=self.config["language"],
121+
source_file_path=covered_file_path,
122+
start_byte=start_byte,
123+
end_byte=end_byte,
170124
)
171125

172-
for path, hunk, content in mutant_generator.generate():
126+
for _, hunk, content in mutant_generator.generate():
173127
yield {
174-
"source_path": filename,
128+
"source_path": covered_file_path,
175129
"start_byte": start_byte,
176130
"end_byte": end_byte,
177131
"hunk": hunk,
@@ -188,7 +142,7 @@ def run_mutation_testing(self) -> None:
188142
mutant_id = str(len(self.mutants) + 1)
189143
mutant_path = self.prepare_mutant_file(
190144
mutant_id=mutant_id,
191-
source_path=mutant_data["source_path"],
145+
source_file_path=mutant_data["source_path"],
192146
start_byte=mutant_data["start_byte"],
193147
end_byte=mutant_data["end_byte"],
194148
mutant_code=mutant_data["mutant_code_snippet"],
@@ -215,7 +169,7 @@ def run_mutation_testing(self) -> None:
215169
def prepare_mutant_file(
216170
self,
217171
mutant_id: str,
218-
source_path: str,
172+
source_file_path: str,
219173
start_byte: int,
220174
end_byte: int,
221175
mutant_code: str,
@@ -236,12 +190,12 @@ def prepare_mutant_file(
236190
Raises:
237191
Exception: If the mutant code has syntax errors.
238192
"""
239-
mutant_file_name = f"{mutant_id}_{os.path.basename(source_path)}"
193+
mutant_file_name = f"{mutant_id}_{os.path.basename(source_file_path)}"
240194
mutant_path = os.path.join(
241195
os.getcwd(), f"logs/_latest/mutants/{mutant_file_name}"
242196
)
243197

244-
with open(source_path, "rb") as f:
198+
with open(source_file_path, "rb") as f:
245199
source_code = f.read()
246200

247201
modified_byte_code = (
@@ -250,12 +204,15 @@ def prepare_mutant_file(
250204
+ source_code[end_byte:]
251205
)
252206

253-
if self.analyzer.check_syntax(modified_byte_code.decode("utf-8")):
207+
if self.analyzer.check_syntax(
208+
source_file_path=source_file_path,
209+
source_code=modified_byte_code.decode("utf-8"),
210+
):
254211
with open(mutant_path, "wb") as f:
255212
f.write(modified_byte_code)
256213
return mutant_path
257214
else:
258-
raise Exception("Mutant code has syntax errors.")
215+
raise SyntaxError("Mutant code has syntax errors.")
259216

260217
def run_test(self, params: Dict[str, str]) -> Any:
261218
"""

0 commit comments

Comments
 (0)