12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
+ from collections import defaultdict
15
16
from difflib import unified_diff
16
17
from pathlib import Path
17
18
from re import match
23
24
# Open the Git repository for this script's location. Parent directories are
# searched for the repository (search_parent_directories=True) and GitDB is
# passed as the object database type (odbt).
repo = Repo (__file__ , odbt = GitDB , search_parent_directories = True )
24
25
25
26
26
- file_path_symbols = {}
27
+ added_symbols = defaultdict (list )
28
+ removed_symbols = defaultdict (list )
27
29
28
30
29
31
def get_symbols (change_type , diff_lines_getter , prefix ):
32
+
33
+ if change_type == "D" or prefix == r"\-" :
34
+ file_path_symbols = removed_symbols
35
+ else :
36
+ file_path_symbols = added_symbols
37
+
30
38
for diff_lines in (
31
39
repo .commit ("main" )
32
40
.diff (repo .head .commit )
@@ -60,9 +68,6 @@ def get_symbols(change_type, diff_lines_getter, prefix):
60
68
)
61
69
62
70
if matching_line is not None :
63
- if b_file_path not in file_path_symbols .keys ():
64
- file_path_symbols [b_file_path ] = []
65
-
66
71
file_path_symbols [b_file_path ].append (
67
72
next (filter (bool , matching_line .groups ()))
68
73
)
@@ -71,6 +76,8 @@ def get_symbols(change_type, diff_lines_getter, prefix):
71
76
def a_diff_lines_getter(diff_lines):
    """Return the text lines of the "b" blob of a diff entry.

    Args:
        diff_lines: A diff entry whose ``b_blob.data_stream`` can be read to
            obtain the file content as UTF-8 encoded bytes.

    Returns:
        The decoded content split on newline characters. Content that ends
        with a newline yields a final empty string.
    """
    content = diff_lines.b_blob.data_stream.read().decode("utf-8")
    # Split on the bare newline. A newline-plus-space separator would only
    # break before indented lines, leaving unindented (top-level) definitions
    # glued to the line that precedes them.
    return content.split("\n")
73
78
79
def d_diff_lines_getter(diff_lines):
    """Return the text lines of the "a" blob of a diff entry.

    Args:
        diff_lines: A diff entry whose ``a_blob.data_stream`` can be read to
            obtain the file content as UTF-8 encoded bytes.

    Returns:
        The decoded content split on newline characters. Content that ends
        with a newline yields a final empty string.
    """
    content = diff_lines.a_blob.data_stream.read().decode("utf-8")
    # Split on the bare newline. A newline-plus-space separator would only
    # break before indented lines, leaving unindented (top-level) definitions
    # glued to the line that precedes them.
    return content.split("\n")
74
81
75
82
def m_diff_lines_getter (diff_lines ):
76
83
return unified_diff (
@@ -80,12 +87,42 @@ def m_diff_lines_getter(diff_lines):
80
87
81
88
82
89
# Run the symbol scan once per (change type, line getter, line prefix)
# combination. Calls made with change type "D" or with the "\-" prefix are
# recorded as removed symbols by get_symbols; the others as added symbols.
for _call_args in (
    ("A", a_diff_lines_getter, r""),
    ("D", d_diff_lines_getter, r""),
    ("M", m_diff_lines_getter, r"\+"),
    ("M", m_diff_lines_getter, r"\-"),
):
    get_symbols(*_call_args)
93
+
94
def remove_common_symbols():
    """Drop symbols that are listed as both added and removed for a file.

    A symbol present in both ``added_symbols`` and ``removed_symbols`` under
    the same file path is treated as neither added nor removed, so every
    occurrence of it is discarded from both mappings. File paths left with
    no symbols are then deleted, so an empty mapping means there is nothing
    to report for that direction.

    Both module-level mappings are modified in place; nothing is returned.
    """
    # Only files present in both mappings can share symbols. The membership
    # test (instead of indexing) avoids creating empty defaultdict entries.
    shared_paths = [path for path in added_symbols if path in removed_symbols]

    for file_path in shared_paths:
        common = set(added_symbols[file_path]) & set(removed_symbols[file_path])
        if not common:
            continue
        # Filtering, rather than one list.remove() per match on each side,
        # stays correct when a symbol is listed several times on one side;
        # paired remove() calls raise ValueError in that situation.
        for symbols_by_file in (added_symbols, removed_symbols):
            symbols_by_file[file_path][:] = [
                symbol
                for symbol in symbols_by_file[file_path]
                if symbol not in common
            ]

    # Prune files whose symbol lists ended up empty. The paths are collected
    # first so the mapping is not resized while it is being iterated.
    for symbols_by_file in (added_symbols, removed_symbols):
        empty_paths = [
            path for path, symbols in symbols_by_file.items() if not symbols
        ]
        for file_path in empty_paths:
            del symbols_by_file[file_path]
117
+
118
+ if added_symbols or removed_symbols :
84
119
85
- if file_path_symbols :
120
+ # If a symbol is added and removed in the same commit, we consider it
121
+ # as not added or removed.
122
+ remove_common_symbols ()
86
123
print ("The code in this branch adds the following public symbols:" )
87
124
print ()
88
- for file_path , symbols in file_path_symbols .items ():
125
+ for file_path , symbols in added_symbols .items ():
89
126
print (f"- { file_path } " )
90
127
for symbol in symbols :
91
128
print (f"\t { symbol } " )
@@ -97,6 +134,20 @@ def m_diff_lines_getter(diff_lines):
97
134
'private. After that, please label this PR with "Skip Public API '
98
135
'check".'
99
136
)
137
+ print ()
138
+ print ("The code in this branch removes the following public symbols:" )
139
+ print ()
140
+ for file_path , symbols in removed_symbols .items ():
141
+ print (f"- { file_path } " )
142
+ for symbol in symbols :
143
+ print (f"\t { symbol } " )
144
+ print ()
145
+
146
+ print (
147
+ "Please make sure no public symbols are removed, if so, please "
148
+ "consider deprecating them instead. After that, please label this "
149
+ 'PR with "Skip Public API check".'
150
+ )
100
151
exit (1 )
101
152
else :
102
153
print ("The code in this branch will not add any public symbols" )
0 commit comments