20
20
21
21
from itertools import groupby
22
22
from typing import List , Set , Dict
23
-
23
+ from tree_sitter import Language , Node , Parser , Query , Tree
24
24
import tree_sitter_java as tsjava
25
25
from tree_sitter import Language , Node , Parser , Query
26
26
@@ -51,10 +51,49 @@ def method_is_not_in_class(self, method_name: str, class_body: str) -> bool:
51
51
bool
52
52
True if the method is not in the class, False otherwise.
53
53
"""
54
- methods_in_class = self .frame_query_and_capture_output ("(method_declaration name: (identifier) @name)" , class_body )
54
+ methods_in_class = self .frame_query_and_capture_output ("(method_declaration name: (identifier) @name)" ,
55
+ class_body )
55
56
56
57
return method_name not in {method .node .text .decode () for method in methods_in_class }
57
58
59
def is_parsable(self, code: str) -> bool:
    """
    Check if the code is parsable.

    The code is encoded as UTF-8 and handed to ``self.parser``. It is
    considered parsable when the parser returns a tree and no node in
    that tree has the type ``"ERROR"``.

    Args:
        code: source code

    Returns:
        True if the code is parsable, False otherwise (including when
        the parser returns no tree at all).
    """
    tree = self.parser.parse(bytes(code, "utf-8"))
    if tree is None:
        return False

    # Walk the tree with an explicit stack instead of recursion: a deeply
    # nested but valid tree must not hit the interpreter's recursion limit
    # and be misreported as a syntax error.
    pending = [tree.root_node]
    while pending:
        node = pending.pop()
        if node.type == "ERROR":
            return False
        pending.extend(node.children)
    return True
85
+
86
def get_raw_ast(self, code: str) -> Tree:
    """
    Parse the given source code and hand back the parser's result unchanged.

    Args:
        code: source code

    Returns:
        Tree: whatever ``self.parser.parse`` returns for the UTF-8 encoded code
    """
    encoded_source = bytes(code, "utf-8")
    raw_tree = self.parser.parse(encoded_source)
    return raw_tree
96
+
58
97
def get_all_imports (self , source_code : str ) -> Set [str ]:
59
98
"""Get a list of all the imports in a class.
60
99
@@ -64,7 +103,8 @@ def get_all_imports(self, source_code: str) -> Set[str]:
64
103
Returns:
65
104
Set[str]: A set of all the imports in the class.
66
105
"""
67
- import_declerations : Captures = self .frame_query_and_capture_output (query = "(import_declaration (scoped_identifier) @name)" , code_to_process = source_code )
106
+ import_declerations : Captures = self .frame_query_and_capture_output (
107
+ query = "(import_declaration (scoped_identifier) @name)" , code_to_process = source_code )
68
108
return {capture .node .text .decode () for capture in import_declerations }
69
109
70
110
def get_pacakge_name (self , source_code : str ) -> str :
@@ -76,7 +116,8 @@ def get_pacakge_name(self, source_code: str) -> str:
76
116
Returns:
77
117
str: The package name.
78
118
"""
79
- package_name : Captures = self .frame_query_and_capture_output (query = "((package_declaration) @name)" , code_to_process = source_code )
119
+ package_name : Captures = self .frame_query_and_capture_output (query = "((package_declaration) @name)" ,
120
+ code_to_process = source_code )
80
121
if package_name :
81
122
return package_name [0 ].node .text .decode ().replace ("package " , "" ).replace (";" , "" )
82
123
return None
@@ -102,7 +143,8 @@ def get_superclass(self, source_code: str) -> str:
102
143
Returns:
103
144
Set[str]: A set of all the superclasses in the class.
104
145
"""
105
- superclass : Captures = self .frame_query_and_capture_output (query = "(class_declaration (superclass (type_identifier) @superclass))" , code_to_process = source_code )
146
+ superclass : Captures = self .frame_query_and_capture_output (
147
+ query = "(class_declaration (superclass (type_identifier) @superclass))" , code_to_process = source_code )
106
148
107
149
if len (superclass ) == 0 :
108
150
return ""
@@ -119,7 +161,9 @@ def get_all_interfaces(self, source_code: str) -> Set[str]:
119
161
Set[str]: A set of all the interfaces implemented by the class.
120
162
"""
121
163
122
- interfaces = self .frame_query_and_capture_output ("(class_declaration (super_interfaces (type_list (type_identifier) @interface)))" , code_to_process = source_code )
164
+ interfaces = self .frame_query_and_capture_output (
165
+ "(class_declaration (super_interfaces (type_list (type_identifier) @interface)))" ,
166
+ code_to_process = source_code )
123
167
return {interface .node .text .decode () for interface in interfaces }
124
168
125
169
def frame_query_and_capture_output (self , query : str , code_to_process : str ) -> Captures :
@@ -138,7 +182,8 @@ def frame_query_and_capture_output(self, query: str, code_to_process: str) -> Ca
138
182
139
183
def get_method_name_from_declaration (self , method_name_string : str ) -> str :
140
184
"""Get the method name from the method signature."""
141
- captures : Captures = self .frame_query_and_capture_output ("(method_declaration name: (identifier) @method_name)" , method_name_string )
185
+ captures : Captures = self .frame_query_and_capture_output ("(method_declaration name: (identifier) @method_name)" ,
186
+ method_name_string )
142
187
143
188
return captures [0 ].node .text .decode ()
144
189
@@ -147,7 +192,8 @@ def get_method_name_from_invocation(self, method_invocation: str) -> str:
147
192
Using the tree-sitter query, extract the method name from the method invocation.
148
193
"""
149
194
150
- captures : Captures = self .frame_query_and_capture_output ("(method_invocation object: (identifier) @class_name name: (identifier) @method_name)" , method_invocation )
195
+ captures : Captures = self .frame_query_and_capture_output (
196
+ "(method_invocation object: (identifier) @class_name name: (identifier) @method_name)" , method_invocation )
151
197
return captures [0 ].node .text .decode ()
152
198
153
199
def safe_ascend (self , node : Node , ascend_count : int ) -> Node :
@@ -352,7 +398,8 @@ def get_method_return_type(self, source_code: str) -> str:
352
398
The return type of the method.
353
399
"""
354
400
355
- type_references : Captures = self .frame_query_and_capture_output ("(method_declaration type: ((type_identifier) @type_id))" , source_code )
401
+ type_references : Captures = self .frame_query_and_capture_output (
402
+ "(method_declaration type: ((type_identifier) @type_id))" , source_code )
356
403
357
404
return type_references [0 ].node .text .decode ()
358
405
@@ -379,9 +426,9 @@ def collect_leaf_token_values(node):
379
426
if len (node .children ) == 0 :
380
427
if filter_by_node_type is not None :
381
428
if node .type in filter_by_node_type :
382
- lexical_tokens .append (code [node .start_byte : node .end_byte ])
429
+ lexical_tokens .append (code [node .start_byte : node .end_byte ])
383
430
else :
384
- lexical_tokens .append (code [node .start_byte : node .end_byte ])
431
+ lexical_tokens .append (code [node .start_byte : node .end_byte ])
385
432
else :
386
433
for child in node .children :
387
434
collect_leaf_token_values (child )
@@ -415,9 +462,11 @@ def remove_all_comments(self, source_code: str) -> str:
415
462
pruned_source_code = self .make_pruned_code_prettier (source_code )
416
463
417
464
# Remove all comment lines: the comment lines start with / (for // and /*) or * (for multiline comments).
418
- comment_blocks : Captures = self .frame_query_and_capture_output (query = "((block_comment) @comment_block)" , code_to_process = source_code )
465
+ comment_blocks : Captures = self .frame_query_and_capture_output (query = "((block_comment) @comment_block)" ,
466
+ code_to_process = source_code )
419
467
420
- comment_lines : Captures = self .frame_query_and_capture_output (query = "((line_comment) @comment_line)" , code_to_process = source_code )
468
+ comment_lines : Captures = self .frame_query_and_capture_output (query = "((line_comment) @comment_line)" ,
469
+ code_to_process = source_code )
421
470
422
471
for capture in comment_blocks :
423
472
pruned_source_code = pruned_source_code .replace (capture .node .text .decode (), "" )
@@ -441,7 +490,8 @@ def make_pruned_code_prettier(self, pruned_code: str) -> str:
441
490
The prettified pruned code.
442
491
"""
443
492
# First remove remaining block comments
444
- block_comments : Captures = self .frame_query_and_capture_output (query = "((block_comment) @comment_block)" , code_to_process = pruned_code )
493
+ block_comments : Captures = self .frame_query_and_capture_output (query = "((block_comment) @comment_block)" ,
494
+ code_to_process = pruned_code )
445
495
446
496
for capture in block_comments :
447
497
pruned_code = pruned_code .replace (capture .node .text .decode (), "" )
0 commit comments