@@ -100,6 +100,337 @@ def check_lib_statically_linked_libstdc_cxx_abi_symbols(lib: str) -> None:
100100 )
101101
102102
103+ def _compile_and_extract_symbols (
104+ cpp_content : str , compile_flags : list [str ], exclude_list : list [str ] | None = None
105+ ) -> list [str ]:
106+ """
107+ Helper to compile a C++ file and extract all symbols.
108+
109+ Args:
110+ cpp_content: C++ source code to compile
111+ compile_flags: Compilation flags
112+ exclude_list: List of symbol names to exclude. Defaults to ["main"].
113+
114+ Returns:
115+ List of all symbols found in the object file (excluding those in exclude_list).
116+ """
117+ import subprocess
118+ import tempfile
119+
120+ if exclude_list is None :
121+ exclude_list = ["main" ]
122+
123+ with tempfile .TemporaryDirectory () as tmpdir :
124+ tmppath = Path (tmpdir )
125+ cpp_file = tmppath / "test.cpp"
126+ obj_file = tmppath / "test.o"
127+
128+ cpp_file .write_text (cpp_content )
129+
130+ result = subprocess .run (
131+ compile_flags + [str (cpp_file ), "-o" , str (obj_file )],
132+ capture_output = True ,
133+ text = True ,
134+ timeout = 60 ,
135+ )
136+
137+ if result .returncode != 0 :
138+ raise RuntimeError (f"Compilation failed: { result .stderr } " )
139+
140+ symbols = get_symbols (str (obj_file ))
141+
142+ # Return all symbol names, excluding those in the exclude list
143+ return [name for _addr , _stype , name in symbols if name not in exclude_list ]
144+
145+
146+ def check_stable_only_symbols (install_root : Path ) -> None :
147+ """
148+ Test TORCH_STABLE_ONLY and TORCH_TARGET_VERSION by compiling test code and comparing symbol counts.
149+
150+ This approach tests:
151+ 1. WITHOUT macros -> many torch symbols exposed
152+ 2. WITH TORCH_STABLE_ONLY -> zero torch symbols (all hidden)
153+ 3. WITH TORCH_TARGET_VERSION -> zero torch symbols (all hidden)
154+ 4. WITH both macros -> zero torch symbols (all hidden)
155+ """
156+ include_dir = install_root / "include"
157+ assert include_dir .exists (), f"Expected { include_dir } to be present"
158+
159+ test_cpp_content = """
160+ // Main torch C++ API headers
161+ #include <torch/torch.h>
162+ #include <torch/all.h>
163+
164+ // ATen tensor library
165+ #include <ATen/ATen.h>
166+
167+ // Core c10 headers (commonly used)
168+ #include <c10/core/Device.h>
169+ #include <c10/core/DeviceType.h>
170+ #include <c10/core/ScalarType.h>
171+ #include <c10/core/TensorOptions.h>
172+ #include <c10/util/Optional.h>
173+
174+ int main() { return 0; }
175+ """
176+
177+ base_compile_flags = [
178+ "g++" ,
179+ "-std=c++17" ,
180+ f"-I{ include_dir } " ,
181+ f"-I{ include_dir } /torch/csrc/api/include" ,
182+ "-c" , # Compile only, don't link
183+ ]
184+
185+ # Compile WITHOUT any macros
186+ symbols_without = _compile_and_extract_symbols (
187+ cpp_content = test_cpp_content ,
188+ compile_flags = base_compile_flags ,
189+ )
190+
191+ # We expect constexpr symbols, inline functions used by other headers etc.
192+ # to produce symbols
193+ num_symbols_without = len (symbols_without )
194+ print (f"Found { num_symbols_without } symbols without any macros defined" )
195+ assert num_symbols_without != 0 , (
196+ "Expected a non-zero number of symbols without any macros"
197+ )
198+
199+ # Compile WITH TORCH_STABLE_ONLY (expect 0 symbols)
200+ compile_flags_with_stable_only = base_compile_flags + ["-DTORCH_STABLE_ONLY" ]
201+
202+ symbols_with_stable_only = _compile_and_extract_symbols (
203+ cpp_content = test_cpp_content ,
204+ compile_flags = compile_flags_with_stable_only ,
205+ )
206+
207+ num_symbols_with_stable_only = len (symbols_with_stable_only )
208+ assert num_symbols_with_stable_only == 0 , (
209+ f"Expected no symbols with TORCH_STABLE_ONLY macro, but found { num_symbols_with_stable_only } "
210+ )
211+
212+ # Compile WITH TORCH_TARGET_VERSION (expect 0 symbols)
213+ compile_flags_with_target_version = base_compile_flags + [
214+ "-DTORCH_TARGET_VERSION=1"
215+ ]
216+
217+ symbols_with_target_version = _compile_and_extract_symbols (
218+ cpp_content = test_cpp_content ,
219+ compile_flags = compile_flags_with_target_version ,
220+ )
221+
222+ num_symbols_with_target_version = len (symbols_with_target_version )
223+ assert num_symbols_with_target_version == 0 , (
224+ f"Expected no symbols with TORCH_TARGET_VERSION macro, but found { num_symbols_with_target_version } "
225+ )
226+
227+ # Compile WITH both macros (expect 0 symbols)
228+ compile_flags_with_both = base_compile_flags + [
229+ "-DTORCH_STABLE_ONLY" ,
230+ "-DTORCH_TARGET_VERSION=1" ,
231+ ]
232+
233+ symbols_with_both = _compile_and_extract_symbols (
234+ cpp_content = test_cpp_content ,
235+ compile_flags = compile_flags_with_both ,
236+ )
237+
238+ num_symbols_with_both = len (symbols_with_both )
239+ assert num_symbols_with_both == 0 , (
240+ f"Expected no symbols with both macros, but found { num_symbols_with_both } "
241+ )
242+
243+
244+ def check_stable_api_symbols (install_root : Path ) -> None :
245+ """
246+ Test that stable API headers still expose symbols with TORCH_STABLE_ONLY.
247+ The torch/csrc/stable/c/shim.h header is tested in check_stable_c_shim_symbols
248+ """
249+ include_dir = install_root / "include"
250+ assert include_dir .exists (), f"Expected { include_dir } to be present"
251+
252+ stable_dir = include_dir / "torch" / "csrc" / "stable"
253+ assert stable_dir .exists (), f"Expected { stable_dir } to be present"
254+
255+ stable_headers = list (stable_dir .rglob ("*.h" ))
256+ if not stable_headers :
257+ raise RuntimeError ("Could not find any stable headers" )
258+
259+ includes = []
260+ for header in stable_headers :
261+ rel_path = header .relative_to (include_dir )
262+ includes .append (f"#include <{ rel_path .as_posix ()} >" )
263+
264+ includes_str = "\n " .join (includes )
265+ test_stable_content = f"""
266+ { includes_str }
267+ int main() {{ return 0; }}
268+ """
269+
270+ compile_flags = [
271+ "g++" ,
272+ "-std=c++17" ,
273+ f"-I{ include_dir } " ,
274+ f"-I{ include_dir } /torch/csrc/api/include" ,
275+ "-c" ,
276+ "-DTORCH_STABLE_ONLY" ,
277+ ]
278+
279+ symbols_stable = _compile_and_extract_symbols (
280+ cpp_content = test_stable_content ,
281+ compile_flags = compile_flags ,
282+ )
283+ num_symbols_stable = len (symbols_stable )
284+ print (f"Found { num_symbols_stable } symbols in torch/csrc/stable" )
285+ assert num_symbols_stable > 0 , (
286+ f"Expected stable headers to expose symbols with TORCH_STABLE_ONLY, "
287+ f"but found { num_symbols_stable } symbols"
288+ )
289+
290+
291+ def check_headeronly_symbols (install_root : Path ) -> None :
292+ """
293+ Test that header-only utility headers still expose symbols with TORCH_STABLE_ONLY.
294+ """
295+ include_dir = install_root / "include"
296+ assert include_dir .exists (), f"Expected { include_dir } to be present"
297+
298+ # Find all headers in torch/headeronly
299+ headeronly_dir = include_dir / "torch" / "headeronly"
300+ assert headeronly_dir .exists (), f"Expected { headeronly_dir } to be present"
301+ headeronly_headers = list (headeronly_dir .rglob ("*.h" ))
302+ if not headeronly_headers :
303+ raise RuntimeError ("Could not find any headeronly headers" )
304+
305+ # Filter out platform-specific headers that may not compile everywhere
306+ platform_specific_keywords = [
307+ "cpu/vec" ,
308+ ]
309+
310+ filtered_headers = []
311+ for header in headeronly_headers :
312+ rel_path = header .relative_to (include_dir ).as_posix ()
313+ if not any (
314+ keyword in rel_path .lower () for keyword in platform_specific_keywords
315+ ):
316+ filtered_headers .append (header )
317+
318+ includes = []
319+ for header in filtered_headers :
320+ rel_path = header .relative_to (include_dir )
321+ includes .append (f"#include <{ rel_path .as_posix ()} >" )
322+
323+ includes_str = "\n " .join (includes )
324+ test_headeronly_content = f"""
325+ { includes_str }
326+ int main() {{ return 0; }}
327+ """
328+
329+ compile_flags = [
330+ "g++" ,
331+ "-std=c++17" ,
332+ f"-I{ include_dir } " ,
333+ f"-I{ include_dir } /torch/csrc/api/include" ,
334+ "-c" ,
335+ "-DTORCH_STABLE_ONLY" ,
336+ ]
337+
338+ symbols_headeronly = _compile_and_extract_symbols (
339+ cpp_content = test_headeronly_content ,
340+ compile_flags = compile_flags ,
341+ )
342+ num_symbols_headeronly = len (symbols_headeronly )
343+ print (f"Found { num_symbols_headeronly } symbols in torch/headeronly" )
344+ assert num_symbols_headeronly > 0 , (
345+ f"Expected headeronly headers to expose symbols with TORCH_STABLE_ONLY, "
346+ f"but found { num_symbols_headeronly } symbols"
347+ )
348+
349+
350+ def check_aoti_shim_symbols (install_root : Path ) -> None :
351+ """
352+ Test that AOTI shim headers still expose symbols with TORCH_STABLE_ONLY.
353+ """
354+ include_dir = install_root / "include"
355+ assert include_dir .exists (), f"Expected { include_dir } to be present"
356+
357+ # There are no constexpr symbols etc., so we need to actually use functions
358+ # so that some symbols are found.
359+ test_shim_content = """
360+ #include <torch/csrc/inductor/aoti_torch/c/shim.h>
361+ int main() {
362+ int32_t (*fp1)() = &aoti_torch_device_type_cpu;
363+ int32_t (*fp2)() = &aoti_torch_dtype_float32;
364+ (void)fp1; (void)fp2;
365+ return 0;
366+ }
367+ """
368+
369+ compile_flags = [
370+ "g++" ,
371+ "-std=c++17" ,
372+ f"-I{ include_dir } " ,
373+ f"-I{ include_dir } /torch/csrc/api/include" ,
374+ "-c" ,
375+ "-DTORCH_STABLE_ONLY" ,
376+ ]
377+
378+ symbols_shim = _compile_and_extract_symbols (
379+ cpp_content = test_shim_content ,
380+ compile_flags = compile_flags ,
381+ )
382+ num_symbols_shim = len (symbols_shim )
383+ assert num_symbols_shim > 0 , (
384+ f"Expected shim headers to expose symbols with TORCH_STABLE_ONLY, "
385+ f"but found { num_symbols_shim } symbols"
386+ )
387+
388+
389+ def check_stable_c_shim_symbols (install_root : Path ) -> None :
390+ """
391+ Test that stable C shim headers still expose symbols with TORCH_STABLE_ONLY.
392+ """
393+ include_dir = install_root / "include"
394+ assert include_dir .exists (), f"Expected { include_dir } to be present"
395+
396+ # Check if the stable C shim exists
397+ stable_shim = include_dir / "torch" / "csrc" / "stable" / "c" / "shim.h"
398+ if not stable_shim .exists ():
399+ raise RuntimeError ("Could not find stable c shim" )
400+
401+ # There are no constexpr symbols etc., so we need to actually use functions
402+ # so that some symbols are found.
403+ test_stable_shim_content = """
404+ #include <torch/csrc/stable/c/shim.h>
405+ int main() {
406+ // Reference stable C API functions to create undefined symbols
407+ AOTITorchError (*fp1)(const char*, uint32_t*, int32_t*) = &torch_parse_device_string;
408+ AOTITorchError (*fp2)(uint32_t*) = &torch_get_num_threads;
409+ (void)fp1; (void)fp2;
410+ return 0;
411+ }
412+ """
413+
414+ compile_flags = [
415+ "g++" ,
416+ "-std=c++17" ,
417+ f"-I{ include_dir } " ,
418+ f"-I{ include_dir } /torch/csrc/api/include" ,
419+ "-c" ,
420+ "-DTORCH_STABLE_ONLY" ,
421+ ]
422+
423+ symbols_stable_shim = _compile_and_extract_symbols (
424+ cpp_content = test_stable_shim_content ,
425+ compile_flags = compile_flags ,
426+ )
427+ num_symbols_stable_shim = len (symbols_stable_shim )
428+ assert num_symbols_stable_shim > 0 , (
429+ f"Expected stable C shim headers to expose symbols with TORCH_STABLE_ONLY, "
430+ f"but found { num_symbols_stable_shim } symbols"
431+ )
432+
433+
103434def check_lib_symbols_for_abi_correctness (lib : str ) -> None :
104435 print (f"lib: { lib } " )
105436 cxx11_symbols = grep_symbols (lib , LIBTORCH_CXX11_PATTERNS )
@@ -129,6 +460,13 @@ def main() -> None:
129460 check_lib_symbols_for_abi_correctness (libtorch_cpu_path )
130461 check_lib_statically_linked_libstdc_cxx_abi_symbols (libtorch_cpu_path )
131462
463+ # Check symbols when TORCH_STABLE_ONLY is defined
464+ check_stable_only_symbols (install_root )
465+ check_stable_api_symbols (install_root )
466+ check_headeronly_symbols (install_root )
467+ check_aoti_shim_symbols (install_root )
468+ check_stable_c_shim_symbols (install_root )
469+
132470
133471if __name__ == "__main__" :
134472 main ()
0 commit comments