Skip to content

Commit e183935

Browse files
srikanthccvocelotl
andauthored
update public_symbols_checker.py to detect removed symbols (#3159)
* update public_symbols_checker.py to detect removed symbols * Update scripts/public_symbols_checker.py --------- Co-authored-by: Diego Hurtado <[email protected]>
1 parent 34d11d5 commit e183935

File tree

1 file changed

+57
-6
lines changed

1 file changed

+57
-6
lines changed

scripts/public_symbols_checker.py

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from collections import defaultdict
1516
from difflib import unified_diff
1617
from pathlib import Path
1718
from re import match
@@ -23,10 +24,17 @@
2324
repo = Repo(__file__, odbt=GitDB, search_parent_directories=True)
2425

2526

26-
file_path_symbols = {}
27+
added_symbols = defaultdict(list)
28+
removed_symbols = defaultdict(list)
2729

2830

2931
def get_symbols(change_type, diff_lines_getter, prefix):
32+
33+
if change_type == "D" or prefix == r"\-":
34+
file_path_symbols = removed_symbols
35+
else:
36+
file_path_symbols = added_symbols
37+
3038
for diff_lines in (
3139
repo.commit("main")
3240
.diff(repo.head.commit)
@@ -60,9 +68,6 @@ def get_symbols(change_type, diff_lines_getter, prefix):
6068
)
6169

6270
if matching_line is not None:
63-
if b_file_path not in file_path_symbols.keys():
64-
file_path_symbols[b_file_path] = []
65-
6671
file_path_symbols[b_file_path].append(
6772
next(filter(bool, matching_line.groups()))
6873
)
@@ -71,6 +76,8 @@ def get_symbols(change_type, diff_lines_getter, prefix):
7176
def a_diff_lines_getter(diff_lines):
7277
return diff_lines.b_blob.data_stream.read().decode("utf-8").split("\n")
7378

79+
def d_diff_lines_getter(diff_lines):
80+
return diff_lines.a_blob.data_stream.read().decode("utf-8").split("\n")
7481

7582
def m_diff_lines_getter(diff_lines):
7683
return unified_diff(
@@ -80,12 +87,42 @@ def m_diff_lines_getter(diff_lines):
8087

8188

8289
get_symbols("A", a_diff_lines_getter, r"")
90+
get_symbols("D", d_diff_lines_getter, r"")
8391
get_symbols("M", m_diff_lines_getter, r"\+")
92+
get_symbols("M", m_diff_lines_getter, r"\-")
93+
94+
def remove_common_symbols():
95+
# For each file, we remove the symbols that are added and removed in the
96+
# same commit.
97+
common_symbols = defaultdict(list)
98+
for file_path, symbols in added_symbols.items():
99+
for symbol in symbols:
100+
if symbol in removed_symbols[file_path]:
101+
common_symbols[file_path].append(symbol)
102+
103+
for file_path, symbols in common_symbols.items():
104+
for symbol in symbols:
105+
added_symbols[file_path].remove(symbol)
106+
removed_symbols[file_path].remove(symbol)
107+
108+
# If a file has no added or removed symbols, we remove it from the
109+
# dictionaries.
110+
for file_path in list(added_symbols.keys()):
111+
if not added_symbols[file_path]:
112+
del added_symbols[file_path]
113+
114+
for file_path in list(removed_symbols.keys()):
115+
if not removed_symbols[file_path]:
116+
del removed_symbols[file_path]
117+
118+
if added_symbols or removed_symbols:
84119

85-
if file_path_symbols:
120+
# If a symbol is added and removed in the same commit, we consider it
121+
# as not added or removed.
122+
remove_common_symbols()
86123
print("The code in this branch adds the following public symbols:")
87124
print()
88-
for file_path, symbols in file_path_symbols.items():
125+
for file_path, symbols in added_symbols.items():
89126
print(f"- {file_path}")
90127
for symbol in symbols:
91128
print(f"\t{symbol}")
@@ -97,6 +134,20 @@ def m_diff_lines_getter(diff_lines):
97134
'private. After that, please label this PR with "Skip Public API '
98135
'check".'
99136
)
137+
print()
138+
print("The code in this branch removes the following public symbols:")
139+
print()
140+
for file_path, symbols in removed_symbols.items():
141+
print(f"- {file_path}")
142+
for symbol in symbols:
143+
print(f"\t{symbol}")
144+
print()
145+
146+
print(
147+
"Please make sure no public symbols are removed, if so, please "
148+
"consider deprecating them instead. After that, please label this "
149+
'PR with "Skip Public API check".'
150+
)
100151
exit(1)
101152
else:
102153
print("The code in this branch will not add any public symbols")

0 commit comments

Comments
 (0)