Skip to content

Commit 8f4dc30

Browse files
mikaylagawareckipytorchmergebot
authored andcommitted
Hide all symbols (except stable/headeronly/shim) if TORCH_STABLE_ONLY is defined (pytorch#167496)
Fixes pytorch#161660 This extends the `TORCH_STABLE_ONLY` stopgap added in pytorch#161658 Pull Request resolved: pytorch#167496 Approved by: https://github.com/janeyx99, https://github.com/malfet
1 parent be33b7f commit 8f4dc30

File tree

8 files changed

+388
-94
lines changed

8 files changed

+388
-94
lines changed

.ci/pytorch/smoke_test/check_binary_symbols.py

Lines changed: 338 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,337 @@ def check_lib_statically_linked_libstdc_cxx_abi_symbols(lib: str) -> None:
100100
)
101101

102102

103+
def _compile_and_extract_symbols(
104+
cpp_content: str, compile_flags: list[str], exclude_list: list[str] | None = None
105+
) -> list[str]:
106+
"""
107+
Helper to compile a C++ file and extract all symbols.
108+
109+
Args:
110+
cpp_content: C++ source code to compile
111+
compile_flags: Compilation flags
112+
exclude_list: List of symbol names to exclude. Defaults to ["main"].
113+
114+
Returns:
115+
List of all symbols found in the object file (excluding those in exclude_list).
116+
"""
117+
import subprocess
118+
import tempfile
119+
120+
if exclude_list is None:
121+
exclude_list = ["main"]
122+
123+
with tempfile.TemporaryDirectory() as tmpdir:
124+
tmppath = Path(tmpdir)
125+
cpp_file = tmppath / "test.cpp"
126+
obj_file = tmppath / "test.o"
127+
128+
cpp_file.write_text(cpp_content)
129+
130+
result = subprocess.run(
131+
compile_flags + [str(cpp_file), "-o", str(obj_file)],
132+
capture_output=True,
133+
text=True,
134+
timeout=60,
135+
)
136+
137+
if result.returncode != 0:
138+
raise RuntimeError(f"Compilation failed: {result.stderr}")
139+
140+
symbols = get_symbols(str(obj_file))
141+
142+
# Return all symbol names, excluding those in the exclude list
143+
return [name for _addr, _stype, name in symbols if name not in exclude_list]
144+
145+
146+
def check_stable_only_symbols(install_root: Path) -> None:
147+
"""
148+
Test TORCH_STABLE_ONLY and TORCH_TARGET_VERSION by compiling test code and comparing symbol counts.
149+
150+
This approach tests:
151+
1. WITHOUT macros -> many torch symbols exposed
152+
2. WITH TORCH_STABLE_ONLY -> zero torch symbols (all hidden)
153+
3. WITH TORCH_TARGET_VERSION -> zero torch symbols (all hidden)
154+
4. WITH both macros -> zero torch symbols (all hidden)
155+
"""
156+
include_dir = install_root / "include"
157+
assert include_dir.exists(), f"Expected {include_dir} to be present"
158+
159+
test_cpp_content = """
160+
// Main torch C++ API headers
161+
#include <torch/torch.h>
162+
#include <torch/all.h>
163+
164+
// ATen tensor library
165+
#include <ATen/ATen.h>
166+
167+
// Core c10 headers (commonly used)
168+
#include <c10/core/Device.h>
169+
#include <c10/core/DeviceType.h>
170+
#include <c10/core/ScalarType.h>
171+
#include <c10/core/TensorOptions.h>
172+
#include <c10/util/Optional.h>
173+
174+
int main() { return 0; }
175+
"""
176+
177+
base_compile_flags = [
178+
"g++",
179+
"-std=c++17",
180+
f"-I{include_dir}",
181+
f"-I{include_dir}/torch/csrc/api/include",
182+
"-c", # Compile only, don't link
183+
]
184+
185+
# Compile WITHOUT any macros
186+
symbols_without = _compile_and_extract_symbols(
187+
cpp_content=test_cpp_content,
188+
compile_flags=base_compile_flags,
189+
)
190+
191+
# We expect constexpr symbols, inline functions used by other headers etc.
192+
# to produce symbols
193+
num_symbols_without = len(symbols_without)
194+
print(f"Found {num_symbols_without} symbols without any macros defined")
195+
assert num_symbols_without != 0, (
196+
"Expected a non-zero number of symbols without any macros"
197+
)
198+
199+
# Compile WITH TORCH_STABLE_ONLY (expect 0 symbols)
200+
compile_flags_with_stable_only = base_compile_flags + ["-DTORCH_STABLE_ONLY"]
201+
202+
symbols_with_stable_only = _compile_and_extract_symbols(
203+
cpp_content=test_cpp_content,
204+
compile_flags=compile_flags_with_stable_only,
205+
)
206+
207+
num_symbols_with_stable_only = len(symbols_with_stable_only)
208+
assert num_symbols_with_stable_only == 0, (
209+
f"Expected no symbols with TORCH_STABLE_ONLY macro, but found {num_symbols_with_stable_only}"
210+
)
211+
212+
# Compile WITH TORCH_TARGET_VERSION (expect 0 symbols)
213+
compile_flags_with_target_version = base_compile_flags + [
214+
"-DTORCH_TARGET_VERSION=1"
215+
]
216+
217+
symbols_with_target_version = _compile_and_extract_symbols(
218+
cpp_content=test_cpp_content,
219+
compile_flags=compile_flags_with_target_version,
220+
)
221+
222+
num_symbols_with_target_version = len(symbols_with_target_version)
223+
assert num_symbols_with_target_version == 0, (
224+
f"Expected no symbols with TORCH_TARGET_VERSION macro, but found {num_symbols_with_target_version}"
225+
)
226+
227+
# Compile WITH both macros (expect 0 symbols)
228+
compile_flags_with_both = base_compile_flags + [
229+
"-DTORCH_STABLE_ONLY",
230+
"-DTORCH_TARGET_VERSION=1",
231+
]
232+
233+
symbols_with_both = _compile_and_extract_symbols(
234+
cpp_content=test_cpp_content,
235+
compile_flags=compile_flags_with_both,
236+
)
237+
238+
num_symbols_with_both = len(symbols_with_both)
239+
assert num_symbols_with_both == 0, (
240+
f"Expected no symbols with both macros, but found {num_symbols_with_both}"
241+
)
242+
243+
244+
def check_stable_api_symbols(install_root: Path) -> None:
245+
"""
246+
Test that stable API headers still expose symbols with TORCH_STABLE_ONLY.
247+
The torch/csrc/stable/c/shim.h header is tested in check_stable_c_shim_symbols
248+
"""
249+
include_dir = install_root / "include"
250+
assert include_dir.exists(), f"Expected {include_dir} to be present"
251+
252+
stable_dir = include_dir / "torch" / "csrc" / "stable"
253+
assert stable_dir.exists(), f"Expected {stable_dir} to be present"
254+
255+
stable_headers = list(stable_dir.rglob("*.h"))
256+
if not stable_headers:
257+
raise RuntimeError("Could not find any stable headers")
258+
259+
includes = []
260+
for header in stable_headers:
261+
rel_path = header.relative_to(include_dir)
262+
includes.append(f"#include <{rel_path.as_posix()}>")
263+
264+
includes_str = "\n".join(includes)
265+
test_stable_content = f"""
266+
{includes_str}
267+
int main() {{ return 0; }}
268+
"""
269+
270+
compile_flags = [
271+
"g++",
272+
"-std=c++17",
273+
f"-I{include_dir}",
274+
f"-I{include_dir}/torch/csrc/api/include",
275+
"-c",
276+
"-DTORCH_STABLE_ONLY",
277+
]
278+
279+
symbols_stable = _compile_and_extract_symbols(
280+
cpp_content=test_stable_content,
281+
compile_flags=compile_flags,
282+
)
283+
num_symbols_stable = len(symbols_stable)
284+
print(f"Found {num_symbols_stable} symbols in torch/csrc/stable")
285+
assert num_symbols_stable > 0, (
286+
f"Expected stable headers to expose symbols with TORCH_STABLE_ONLY, "
287+
f"but found {num_symbols_stable} symbols"
288+
)
289+
290+
291+
def check_headeronly_symbols(install_root: Path) -> None:
292+
"""
293+
Test that header-only utility headers still expose symbols with TORCH_STABLE_ONLY.
294+
"""
295+
include_dir = install_root / "include"
296+
assert include_dir.exists(), f"Expected {include_dir} to be present"
297+
298+
# Find all headers in torch/headeronly
299+
headeronly_dir = include_dir / "torch" / "headeronly"
300+
assert headeronly_dir.exists(), f"Expected {headeronly_dir} to be present"
301+
headeronly_headers = list(headeronly_dir.rglob("*.h"))
302+
if not headeronly_headers:
303+
raise RuntimeError("Could not find any headeronly headers")
304+
305+
# Filter out platform-specific headers that may not compile everywhere
306+
platform_specific_keywords = [
307+
"cpu/vec",
308+
]
309+
310+
filtered_headers = []
311+
for header in headeronly_headers:
312+
rel_path = header.relative_to(include_dir).as_posix()
313+
if not any(
314+
keyword in rel_path.lower() for keyword in platform_specific_keywords
315+
):
316+
filtered_headers.append(header)
317+
318+
includes = []
319+
for header in filtered_headers:
320+
rel_path = header.relative_to(include_dir)
321+
includes.append(f"#include <{rel_path.as_posix()}>")
322+
323+
includes_str = "\n".join(includes)
324+
test_headeronly_content = f"""
325+
{includes_str}
326+
int main() {{ return 0; }}
327+
"""
328+
329+
compile_flags = [
330+
"g++",
331+
"-std=c++17",
332+
f"-I{include_dir}",
333+
f"-I{include_dir}/torch/csrc/api/include",
334+
"-c",
335+
"-DTORCH_STABLE_ONLY",
336+
]
337+
338+
symbols_headeronly = _compile_and_extract_symbols(
339+
cpp_content=test_headeronly_content,
340+
compile_flags=compile_flags,
341+
)
342+
num_symbols_headeronly = len(symbols_headeronly)
343+
print(f"Found {num_symbols_headeronly} symbols in torch/headeronly")
344+
assert num_symbols_headeronly > 0, (
345+
f"Expected headeronly headers to expose symbols with TORCH_STABLE_ONLY, "
346+
f"but found {num_symbols_headeronly} symbols"
347+
)
348+
349+
350+
def check_aoti_shim_symbols(install_root: Path) -> None:
351+
"""
352+
Test that AOTI shim headers still expose symbols with TORCH_STABLE_ONLY.
353+
"""
354+
include_dir = install_root / "include"
355+
assert include_dir.exists(), f"Expected {include_dir} to be present"
356+
357+
# There are no constexpr symbols etc., so we need to actually use functions
358+
# so that some symbols are found.
359+
test_shim_content = """
360+
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
361+
int main() {
362+
int32_t (*fp1)() = &aoti_torch_device_type_cpu;
363+
int32_t (*fp2)() = &aoti_torch_dtype_float32;
364+
(void)fp1; (void)fp2;
365+
return 0;
366+
}
367+
"""
368+
369+
compile_flags = [
370+
"g++",
371+
"-std=c++17",
372+
f"-I{include_dir}",
373+
f"-I{include_dir}/torch/csrc/api/include",
374+
"-c",
375+
"-DTORCH_STABLE_ONLY",
376+
]
377+
378+
symbols_shim = _compile_and_extract_symbols(
379+
cpp_content=test_shim_content,
380+
compile_flags=compile_flags,
381+
)
382+
num_symbols_shim = len(symbols_shim)
383+
assert num_symbols_shim > 0, (
384+
f"Expected shim headers to expose symbols with TORCH_STABLE_ONLY, "
385+
f"but found {num_symbols_shim} symbols"
386+
)
387+
388+
389+
def check_stable_c_shim_symbols(install_root: Path) -> None:
390+
"""
391+
Test that stable C shim headers still expose symbols with TORCH_STABLE_ONLY.
392+
"""
393+
include_dir = install_root / "include"
394+
assert include_dir.exists(), f"Expected {include_dir} to be present"
395+
396+
# Check if the stable C shim exists
397+
stable_shim = include_dir / "torch" / "csrc" / "stable" / "c" / "shim.h"
398+
if not stable_shim.exists():
399+
raise RuntimeError("Could not find stable c shim")
400+
401+
# There are no constexpr symbols etc., so we need to actually use functions
402+
# so that some symbols are found.
403+
test_stable_shim_content = """
404+
#include <torch/csrc/stable/c/shim.h>
405+
int main() {
406+
// Reference stable C API functions to create undefined symbols
407+
AOTITorchError (*fp1)(const char*, uint32_t*, int32_t*) = &torch_parse_device_string;
408+
AOTITorchError (*fp2)(uint32_t*) = &torch_get_num_threads;
409+
(void)fp1; (void)fp2;
410+
return 0;
411+
}
412+
"""
413+
414+
compile_flags = [
415+
"g++",
416+
"-std=c++17",
417+
f"-I{include_dir}",
418+
f"-I{include_dir}/torch/csrc/api/include",
419+
"-c",
420+
"-DTORCH_STABLE_ONLY",
421+
]
422+
423+
symbols_stable_shim = _compile_and_extract_symbols(
424+
cpp_content=test_stable_shim_content,
425+
compile_flags=compile_flags,
426+
)
427+
num_symbols_stable_shim = len(symbols_stable_shim)
428+
assert num_symbols_stable_shim > 0, (
429+
f"Expected stable C shim headers to expose symbols with TORCH_STABLE_ONLY, "
430+
f"but found {num_symbols_stable_shim} symbols"
431+
)
432+
433+
103434
def check_lib_symbols_for_abi_correctness(lib: str) -> None:
104435
print(f"lib: {lib}")
105436
cxx11_symbols = grep_symbols(lib, LIBTORCH_CXX11_PATTERNS)
@@ -129,6 +460,13 @@ def main() -> None:
129460
check_lib_symbols_for_abi_correctness(libtorch_cpu_path)
130461
check_lib_statically_linked_libstdc_cxx_abi_symbols(libtorch_cpu_path)
131462

463+
# Check symbols when TORCH_STABLE_ONLY is defined
464+
check_stable_only_symbols(install_root)
465+
check_stable_api_symbols(install_root)
466+
check_headeronly_symbols(install_root)
467+
check_aoti_shim_symbols(install_root)
468+
check_stable_c_shim_symbols(install_root)
469+
132470

133471
if __name__ == "__main__":
134472
main()

0 commit comments

Comments
 (0)