Skip to content

Commit a0ccd3e

Browse files
mikaylagawareckipytorchmergebot
authored andcommitted
Error when non stable/headeronly/shim headers are included by stable extension (pytorch#167855)
Address Nikita's offline comment on pytorch#167496 Pull Request resolved: pytorch#167855 Approved by: https://github.com/janeyx99 ghstack dependencies: pytorch#167496
1 parent 8f4dc30 commit a0ccd3e

File tree

2 files changed

+111
-85
lines changed

2 files changed

+111
-85
lines changed

.ci/pytorch/smoke_test/check_binary_symbols.py

Lines changed: 48 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -145,14 +145,17 @@ def _compile_and_extract_symbols(
145145

146146
def check_stable_only_symbols(install_root: Path) -> None:
147147
"""
148-
Test TORCH_STABLE_ONLY and TORCH_TARGET_VERSION by compiling test code and comparing symbol counts.
148+
Test TORCH_STABLE_ONLY and TORCH_TARGET_VERSION by compiling test code.
149149
150150
This approach tests:
151-
1. WITHOUT macros -> many torch symbols exposed
152-
2. WITH TORCH_STABLE_ONLY -> zero torch symbols (all hidden)
153-
3. WITH TORCH_TARGET_VERSION -> zero torch symbols (all hidden)
154-
4. WITH both macros -> zero torch symbols (all hidden)
151+
1. WITHOUT macros -> many torch symbols exposed (compilation succeeds)
152+
2. WITH TORCH_STABLE_ONLY -> compilation fails with #error directive
153+
3. WITH TORCH_TARGET_VERSION -> compilation fails with #error directive
154+
4. WITH both macros -> compilation fails with #error directive
155155
"""
156+
import subprocess
157+
import tempfile
158+
156159
include_dir = install_root / "include"
157160
assert include_dir.exists(), f"Expected {include_dir} to be present"
158161

@@ -182,7 +185,7 @@ def check_stable_only_symbols(install_root: Path) -> None:
182185
"-c", # Compile only, don't link
183186
]
184187

185-
# Compile WITHOUT any macros
188+
# Compile WITHOUT any macros - should succeed
186189
symbols_without = _compile_and_extract_symbols(
187190
cpp_content=test_cpp_content,
188191
compile_flags=base_compile_flags,
@@ -196,49 +199,56 @@ def check_stable_only_symbols(install_root: Path) -> None:
196199
"Expected a non-zero number of symbols without any macros"
197200
)
198201

199-
# Compile WITH TORCH_STABLE_ONLY (expect 0 symbols)
200-
compile_flags_with_stable_only = base_compile_flags + ["-DTORCH_STABLE_ONLY"]
202+
# Helper to verify compilation fails with expected error
203+
def _expect_compilation_failure(compile_flags: list[str], macro_name: str) -> None:
204+
with tempfile.TemporaryDirectory() as tmpdir:
205+
tmppath = Path(tmpdir)
206+
cpp_file = tmppath / "test.cpp"
207+
obj_file = tmppath / "test.o"
208+
209+
cpp_file.write_text(test_cpp_content)
210+
211+
result = subprocess.run(
212+
compile_flags + [str(cpp_file), "-o", str(obj_file)],
213+
capture_output=True,
214+
text=True,
215+
timeout=60,
216+
)
217+
218+
if result.returncode == 0:
219+
raise RuntimeError(
220+
f"Expected compilation to fail with {macro_name} defined, but it succeeded"
221+
)
222+
223+
stderr = result.stderr
224+
expected_error_msg = (
225+
"This file should not be included when either TORCH_STABLE_ONLY "
226+
"or TORCH_TARGET_VERSION is defined."
227+
)
228+
229+
if expected_error_msg not in stderr:
230+
raise RuntimeError(
231+
f"Expected error message to contain:\n '{expected_error_msg}'\n"
232+
f"but got:\n{stderr[:1000]}"
233+
)
234+
235+
print(f"Compilation correctly failed with {macro_name} defined")
201236

202-
symbols_with_stable_only = _compile_and_extract_symbols(
203-
cpp_content=test_cpp_content,
204-
compile_flags=compile_flags_with_stable_only,
205-
)
206-
207-
num_symbols_with_stable_only = len(symbols_with_stable_only)
208-
assert num_symbols_with_stable_only == 0, (
209-
f"Expected no symbols with TORCH_STABLE_ONLY macro, but found {num_symbols_with_stable_only}"
210-
)
237+
compile_flags_with_stable_only = base_compile_flags + ["-DTORCH_STABLE_ONLY"]
238+
_expect_compilation_failure(compile_flags_with_stable_only, "TORCH_STABLE_ONLY")
211239

212-
# Compile WITH TORCH_TARGET_VERSION (expect 0 symbols)
213240
compile_flags_with_target_version = base_compile_flags + [
214241
"-DTORCH_TARGET_VERSION=1"
215242
]
216-
217-
symbols_with_target_version = _compile_and_extract_symbols(
218-
cpp_content=test_cpp_content,
219-
compile_flags=compile_flags_with_target_version,
243+
_expect_compilation_failure(
244+
compile_flags_with_target_version, "TORCH_TARGET_VERSION"
220245
)
221246

222-
num_symbols_with_target_version = len(symbols_with_target_version)
223-
assert num_symbols_with_target_version == 0, (
224-
f"Expected no symbols with TORCH_TARGET_VERSION macro, but found {num_symbols_with_target_version}"
225-
)
226-
227-
# Compile WITH both macros (expect 0 symbols)
228247
compile_flags_with_both = base_compile_flags + [
229248
"-DTORCH_STABLE_ONLY",
230249
"-DTORCH_TARGET_VERSION=1",
231250
]
232-
233-
symbols_with_both = _compile_and_extract_symbols(
234-
cpp_content=test_cpp_content,
235-
compile_flags=compile_flags_with_both,
236-
)
237-
238-
num_symbols_with_both = len(symbols_with_both)
239-
assert num_symbols_with_both == 0, (
240-
f"Expected no symbols with both macros, but found {num_symbols_with_both}"
241-
)
251+
_expect_compilation_failure(compile_flags_with_both, "both macros")
242252

243253

244254
def check_stable_api_symbols(install_root: Path) -> None:

setup.py

Lines changed: 63 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,6 +1089,60 @@ def check_pydep(importname: str, module: str) -> None:
10891089

10901090

10911091
class build_ext(setuptools.command.build_ext.build_ext):
1092+
def _wrap_headers_with_macro(self, include_dir: Path) -> None:
1093+
"""Wrap all header files with #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION).
1094+
1095+
Excludes:
1096+
- torch/headeronly/*
1097+
- torch/csrc/stable/*
1098+
- torch/csrc/inductor/aoti_torch/c/ (only shim headers)
1099+
- torch/csrc/inductor/aoti_torch/generated/
1100+
1101+
This method is idempotent - it will not wrap headers that are already wrapped.
1102+
"""
1103+
header_extensions = (".h", ".hpp", ".cuh")
1104+
header_files = [
1105+
f for ext in header_extensions for f in include_dir.rglob(f"*{ext}")
1106+
]
1107+
1108+
# Paths to exclude from wrapping (relative to include_dir)
1109+
exclude_dir_patterns = [
1110+
"torch/headeronly/",
1111+
"torch/csrc/stable/",
1112+
"torch/csrc/inductor/aoti_torch/c/",
1113+
"torch/csrc/inductor/aoti_torch/generated/",
1114+
]
1115+
1116+
# Marker to detect if a header is already wrapped
1117+
wrap_start_marker = (
1118+
"#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
1119+
)
1120+
1121+
for header_file in header_files:
1122+
rel_path = header_file.relative_to(include_dir).as_posix()
1123+
1124+
if any(rel_path.startswith(pattern) for pattern in exclude_dir_patterns):
1125+
report(f"Skipping header: {rel_path}")
1126+
continue
1127+
1128+
original_content = header_file.read_text(encoding="utf-8")
1129+
1130+
# Check if already wrapped (idempotency check)
1131+
if original_content.startswith(wrap_start_marker):
1132+
report(f"Already wrapped, skipping: {rel_path}")
1133+
continue
1134+
1135+
wrapped_content = (
1136+
wrap_start_marker
1137+
+ f"{original_content}"
1138+
+ "\n#else\n"
1139+
+ '#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."\n'
1140+
+ "#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
1141+
)
1142+
1143+
header_file.write_text(wrapped_content, encoding="utf-8")
1144+
report(f"Wrapped header: {rel_path}")
1145+
10921146
def _embed_libomp(self) -> None:
10931147
# Copy libiomp5.dylib/libomp.dylib inside the wheel package on MacOS
10941148
build_lib = Path(self.build_lib)
@@ -1256,6 +1310,15 @@ def run(self) -> None:
12561310

12571311
super().run()
12581312

1313+
# Wrap headers with TORCH_STABLE_ONLY and TORCH_TARGET_VERSION guards
1314+
build_lib = Path(self.build_lib)
1315+
build_torch_include_dir = build_lib / "torch" / "include"
1316+
if build_torch_include_dir.exists():
1317+
report(
1318+
"-- Wrapping header files with if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)"
1319+
)
1320+
self._wrap_headers_with_macro(build_torch_include_dir)
1321+
12591322
if IS_DARWIN:
12601323
self._embed_libomp()
12611324

@@ -1358,45 +1421,6 @@ def __exit__(self, *exc_info: object) -> None:
13581421

13591422
# Need to create the proper LICENSE.txt for the wheel
13601423
class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel):
1361-
def _wrap_headers_with_macro(self, bdist_dir: Path) -> None:
1362-
"""Wrap all header files with #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION).
1363-
1364-
Excludes:
1365-
- torch/include/torch/headeronly/*
1366-
- torch/include/torch/csrc/stable/*
1367-
- torch/include/torch/csrc/inductor/aoti_torch/c/ (only shim headers)
1368-
- torch/include/torch/csrc/inductor/aoti_torch/generated/
1369-
"""
1370-
header_extensions = (".h", ".hpp", ".cuh")
1371-
header_files = [
1372-
f for ext in header_extensions for f in bdist_dir.rglob(f"*{ext}")
1373-
]
1374-
1375-
# Paths to exclude from wrapping
1376-
exclude_dir_patterns = [
1377-
"torch/include/torch/headeronly/",
1378-
"torch/include/torch/csrc/stable/",
1379-
"torch/include/torch/csrc/inductor/aoti_torch/c/",
1380-
"torch/include/torch/csrc/inductor/aoti_torch/generated/",
1381-
]
1382-
1383-
for header_file in header_files:
1384-
rel_path = header_file.relative_to(bdist_dir).as_posix()
1385-
1386-
if any(rel_path.startswith(pattern) for pattern in exclude_dir_patterns):
1387-
report(f"Skipping header: {rel_path}")
1388-
continue
1389-
1390-
original_content = header_file.read_text(encoding="utf-8")
1391-
wrapped_content = (
1392-
"#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
1393-
f"{original_content}"
1394-
"\n#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
1395-
)
1396-
1397-
header_file.write_text(wrapped_content, encoding="utf-8")
1398-
report(f"Wrapped header: {rel_path}")
1399-
14001424
def run(self) -> None:
14011425
with concat_license_files(include_files=True):
14021426
super().run()
@@ -1419,14 +1443,6 @@ def write_wheelfile(self, *args: Any, **kwargs: Any) -> None:
14191443
# need an __init__.py file otherwise we wouldn't have a package
14201444
(bdist_dir / "torch" / "__init__.py").touch()
14211445

1422-
# Wrap all header files with TORCH_STABLE_ONLY macro
1423-
assert self.bdist_dir is not None, "bdist_dir should be set during wheel build"
1424-
bdist_dir = Path(self.bdist_dir)
1425-
report(
1426-
"-- Wrapping header files with if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)"
1427-
)
1428-
self._wrap_headers_with_macro(bdist_dir)
1429-
14301446

14311447
class clean(Command):
14321448
user_options: ClassVar[list[tuple[str, str | None, str]]] = []

0 commit comments

Comments
 (0)