Skip to content

Commit fecbea3

Browse files
committed
Allow pre-commit hook to sync c10 headers between ET and pytorch
1 parent bf7d755 commit fecbea3

File tree

2 files changed

+161
-22
lines changed

2 files changed

+161
-22
lines changed

.githooks/pre-commit

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,29 @@
11
#!/usr/bin/env bash
22

3-
# Pre-commit hook to automatically update PyTorch commit pin when torch_pin.py changes
3+
# Pre-commit hook to automatically update PyTorch commit pin and sync c10 directories
44

55
# Check if torch_pin.py is being committed
66
if git diff --cached --name-only | grep -q "^torch_pin.py$"; then
77
echo "🔍 Detected changes to torch_pin.py"
8-
echo "📝 Updating PyTorch commit pin..."
8+
echo "📝 Updating PyTorch commit pin and syncing c10 directories..."
99

10-
# Run the update script
11-
hook_output=$(python .github/scripts/update_pytorch_pin.py 2>&1)
12-
hook_status=$?
13-
echo "$hook_output"
14-
15-
if [ $hook_status -eq 0 ]; then
16-
# Check if pytorch.txt was modified
10+
# Run the update script (which now also syncs c10 directories)
11+
if python .github/scripts/update_pytorch_pin.py; then
12+
# Stage any modified files (pytorch.txt and grafted c10 files)
1713
if ! git diff --quiet .ci/docker/ci_commit_pins/pytorch.txt; then
18-
echo "✅ PyTorch commit pin updated successfully"
19-
# Stage the updated file
2014
git add .ci/docker/ci_commit_pins/pytorch.txt
2115
echo "📌 Staged .ci/docker/ci_commit_pins/pytorch.txt"
22-
else
23-
echo "ℹ️ PyTorch commit pin unchanged"
2416
fi
25-
else
26-
if echo "$hook_output" | grep -qi "rate limit exceeded"; then
27-
echo "⚠️ PyTorch commit pin not updated due to GitHub API rate limiting."
28-
echo " Please manually update .ci/docker/ci_commit_pins/pytorch.txt if needed."
29-
else
30-
echo "❌ Failed to update PyTorch commit pin"
31-
echo "Please run: python .github/scripts/update_pytorch_pin.py"
32-
exit 1
17+
18+
# Stage any grafted c10 files
19+
if ! git diff --quiet runtime/core/portable_type/c10/; then
20+
git add runtime/core/portable_type/c10/
21+
echo "📌 Staged grafted c10 files"
3322
fi
23+
else
24+
echo "❌ Failed to update PyTorch commit pin"
25+
echo "Please run: python .github/scripts/update_pytorch_pin.py"
26+
exit 1
3427
fi
3528
fi
3629

.github/scripts/update_pytorch_pin.py

Lines changed: 147 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
#!/usr/bin/env python3
22

3+
import base64
4+
import hashlib
35
import json
46
import re
57
import sys
68
import urllib.request
9+
from pathlib import Path
710

811

912
def parse_nightly_version(nightly_version):
@@ -101,6 +104,144 @@ def update_pytorch_pin(commit_hash):
101104
print(f"Updated {pin_file} with commit hash: {commit_hash}")
102105

103106

107+
def should_skip_file(filename):
    """
    Tell whether a file is a build-system file that the sync leaves alone.

    Args:
        filename: Base filename (no directory components) to test

    Returns:
        True when the name is one of the known build files, False otherwise
    """
    # Exact, case-sensitive names of the build files excluded from syncing
    build_file_names = frozenset(
        ("BUCK", "CMakeLists.txt", "TARGETS", "targets.bzl")
    )
    is_build_file = filename in build_file_names
    return is_build_file
119+
120+
121+
def fetch_file_content(commit_hash, file_path, timeout=30):
    """
    Fetch file content from GitHub API.

    Args:
        commit_hash: Commit hash to fetch from
        file_path: File path in the repository
        timeout: Seconds to wait on the network before giving up

    Returns:
        File content as bytes

    Raises:
        urllib.request.HTTPError: The API answered with an error status
        ValueError: The response is not a base64-encoded file payload
    """
    # Function-scope import keeps this block independent of the module header.
    import urllib.parse

    # Percent-encode both pieces so unusual characters cannot corrupt the URL;
    # "/" stays literal in the path because it separates path segments.
    quoted_path = urllib.parse.quote(str(file_path), safe="/")
    quoted_ref = urllib.parse.quote(str(commit_hash), safe="")
    api_url = (
        "https://api.github.com/repos/pytorch/pytorch/contents/"
        f"{quoted_path}?ref={quoted_ref}"
    )

    req = urllib.request.Request(api_url)
    req.add_header("Accept", "application/vnd.github.v3+json")
    req.add_header("User-Agent", "ExecuTorch-Bot")

    try:
        # Without a timeout a stalled connection would block the caller forever.
        with urllib.request.urlopen(req, timeout=timeout) as response:
            data = json.loads(response.read().decode())
    except urllib.request.HTTPError as e:
        # A 404 only means the file is absent upstream; the caller decides
        # whether that matters, so it is not reported as an error here.
        if e.code != 404:
            print(f"Error fetching file {file_path}: {e}", file=sys.stderr)
        raise

    # Refuse anything that is not a base64 file object (for example a
    # directory listing, or a payload without inline content) instead of
    # failing obscurely or handing back empty bytes.
    if (
        not isinstance(data, dict)
        or data.get("encoding") != "base64"
        or "content" not in data
    ):
        raise ValueError(
            f"Unexpected GitHub API response for {file_path}: "
            "not a base64-encoded file"
        )

    # Content is base64 encoded
    return base64.b64decode(data["content"])
147+
148+
149+
def sync_directory(et_dir, pt_path, commit_hash):
    """
    Sync files from PyTorch to ExecuTorch using GitHub API.
    Only syncs files that already exist in ExecuTorch - does not add new files.

    Failures on individual files are reported on stderr and do not abort the
    run; such files are simply not counted as grafted.

    Args:
        et_dir: ExecuTorch directory path (a Path or a path string)
        pt_path: PyTorch directory path in the repository (e.g., "c10")
        commit_hash: Commit hash to fetch from

    Returns:
        Number of files grafted
    """
    et_dir = Path(et_dir)
    files_grafted = 0
    print(f"Checking {et_dir} vs pytorch/{pt_path}...")

    if not et_dir.exists():
        print(f"Warning: ExecuTorch directory {et_dir} does not exist, skipping")
        return 0

    # Sorted so files are visited (and reported) in the same order every run.
    for et_file in sorted(et_dir.rglob("*")):
        # Directories and build files are never synced.
        if not et_file.is_file() or should_skip_file(et_file.name):
            continue

        # Construct corresponding path in PyTorch. as_posix() yields forward
        # slashes on every platform without rewriting characters that merely
        # happen to be backslashes in a file name.
        rel_path = et_file.relative_to(et_dir)
        pt_file_path = f"{pt_path}/{rel_path.as_posix()}"

        # Fetch content from PyTorch and compare
        try:
            pt_content = fetch_file_content(commit_hash, pt_file_path)
            et_content = et_file.read_bytes()

            if pt_content != et_content:
                print(f"⚠️ Difference detected in {rel_path}")
                print(f"📋 Grafting from PyTorch commit {commit_hash}...")

                et_file.write_bytes(pt_content)
                print(f"✅ Grafted {et_file}")
                files_grafted += 1
        except urllib.request.HTTPError as e:
            if e.code != 404:  # It's ok to have more files in ET than pytorch/pytorch.
                print(
                    f"Error fetching {rel_path} from PyTorch: {e}",
                    file=sys.stderr,
                )
        except Exception as e:
            # Deliberately best-effort: one bad file must not stop the others.
            print(f"Error syncing {rel_path}: {e}", file=sys.stderr)

    return files_grafted
202+
203+
204+
def sync_c10_directories(commit_hash):
    """
    Bring the vendored c10 and torch/headeronly trees in line with PyTorch.

    Each mirrored directory is handed to sync_directory in turn and the
    per-directory counts are added up; a summary line is printed at the end.

    Args:
        commit_hash: PyTorch commit hash to sync from

    Returns:
        Total number of files grafted
    """
    print("\n🔄 Syncing c10 directories from PyTorch via GitHub API...")

    # The current working directory is used as the repository root.
    vendored_root = Path.cwd() / "runtime/core/portable_type/c10"

    # Directory pairs to sync (from check_c10_sync.sh). Every mirrored tree
    # sits under the vendored root at the same relative path it has in
    # PyTorch, so one list of PyTorch paths describes both sides.
    upstream_paths = ("c10", "torch/headeronly")
    mirrored_pairs = [(vendored_root / path, path) for path in upstream_paths]

    total_grafted = sum(
        sync_directory(local_dir, upstream_path, commit_hash)
        for local_dir, upstream_path in mirrored_pairs
    )

    if total_grafted > 0:
        summary = f"\n✅ Successfully grafted {total_grafted} file(s) from PyTorch"
    else:
        summary = "\n✅ No differences found - c10 is in sync"
    print(summary)

    return total_grafted
243+
244+
104245
def main():
105246
try:
106247
# Read NIGHTLY_VERSION from torch_pin.py
@@ -118,7 +259,12 @@ def main():
118259
# Update the pin file
119260
update_pytorch_pin(commit_hash)
120261

121-
print("Successfully updated PyTorch commit pin!")
262+
# Sync c10 directories from PyTorch
263+
sync_c10_directories(commit_hash)
264+
265+
print(
266+
"\n✅ Successfully updated PyTorch commit pin and synced c10 directories!"
267+
)
122268

123269
except Exception as e:
124270
print(f"Error: {e}", file=sys.stderr)

0 commit comments

Comments
 (0)