Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 14 additions & 21 deletions .githooks/pre-commit
Original file line number Diff line number Diff line change
@@ -1,36 +1,29 @@
#!/usr/bin/env bash

# Pre-commit hook to automatically update PyTorch commit pin when torch_pin.py changes
# Pre-commit hook to automatically update PyTorch commit pin and sync c10 directories when torch_pin.py changes

# Check if torch_pin.py is being committed
if git diff --cached --name-only | grep -q "^torch_pin.py$"; then
echo "🔍 Detected changes to torch_pin.py"
echo "📝 Updating PyTorch commit pin..."
echo "📝 Updating PyTorch commit pin and syncing c10 directories..."

# Run the update script
hook_output=$(python .github/scripts/update_pytorch_pin.py 2>&1)
hook_status=$?
echo "$hook_output"

if [ $hook_status -eq 0 ]; then
# Check if pytorch.txt was modified
# Run the update script (which now also syncs c10 directories)
if python .github/scripts/update_pytorch_pin.py; then
# Stage any modified files (pytorch.txt and grafted c10 files)
if ! git diff --quiet .ci/docker/ci_commit_pins/pytorch.txt; then
echo "✅ PyTorch commit pin updated successfully"
# Stage the updated file
git add .ci/docker/ci_commit_pins/pytorch.txt
echo "📌 Staged .ci/docker/ci_commit_pins/pytorch.txt"
else
echo "ℹ️ PyTorch commit pin unchanged"
fi
else
if echo "$hook_output" | grep -qi "rate limit exceeded"; then
echo "⚠️ PyTorch commit pin not updated due to GitHub API rate limiting."
echo " Please manually update .ci/docker/ci_commit_pins/pytorch.txt if needed."
else
echo "❌ Failed to update PyTorch commit pin"
echo "Please run: python .github/scripts/update_pytorch_pin.py"
exit 1

# Stage any grafted c10 files
if ! git diff --quiet runtime/core/portable_type/c10/; then
git add runtime/core/portable_type/c10/
echo "📌 Staged grafted c10 files"
fi
else
echo "❌ Failed to update PyTorch commit pin"
echo "Please run: python .github/scripts/update_pytorch_pin.py"
exit 1
fi
fi

Expand Down
148 changes: 147 additions & 1 deletion .github/scripts/update_pytorch_pin.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
#!/usr/bin/env python3

import base64
import hashlib
import json
import re
import sys
import urllib.request
from pathlib import Path


def parse_nightly_version(nightly_version):
Expand Down Expand Up @@ -101,6 +104,144 @@ def update_pytorch_pin(commit_hash):
print(f"Updated {pin_file} with commit hash: {commit_hash}")


def should_skip_file(filename):
"""
Check if a file should be skipped during sync (build files).

Args:
filename: Base filename to check

Returns:
True if file should be skipped
"""
skip_files = {"BUCK", "CMakeLists.txt", "TARGETS", "targets.bzl"}
return filename in skip_files


def fetch_file_content(commit_hash, file_path):
"""
Fetch file content from GitHub API.

Args:
commit_hash: Commit hash to fetch from
file_path: File path in the repository

Returns:
File content as bytes
"""
api_url = f"https://api.github.com/repos/pytorch/pytorch/contents/{file_path}?ref={commit_hash}"

req = urllib.request.Request(api_url)
req.add_header("Accept", "application/vnd.github.v3+json")
req.add_header("User-Agent", "ExecuTorch-Bot")

try:
with urllib.request.urlopen(req) as response:
data = json.loads(response.read().decode())
# Content is base64 encoded
content = base64.b64decode(data["content"])
return content
except urllib.request.HTTPError as e:
print(f"Error fetching file {file_path}: {e}", file=sys.stderr)
raise


def sync_directory(et_dir, pt_path, commit_hash):
"""
Sync files from PyTorch to ExecuTorch using GitHub API.
Only syncs files that already exist in ExecuTorch - does not add new files.

Args:
et_dir: ExecuTorch directory path
pt_path: PyTorch directory path in the repository (e.g., "c10")
commit_hash: Commit hash to fetch from

Returns:
Number of files grafted
"""
files_grafted = 0
print(f"Checking {et_dir} vs pytorch/{pt_path}...")

if not et_dir.exists():
print(f"Warning: ExecuTorch directory {et_dir} does not exist, skipping")
return 0

# Loop through files in ExecuTorch directory
for et_file in et_dir.rglob("*"):
if not et_file.is_file():
continue

# Skip build files
if should_skip_file(et_file.name):
continue

# Construct corresponding path in PyTorch
rel_path = et_file.relative_to(et_dir)
pt_file_path = f"{pt_path}/{rel_path}".replace("\\", "/")

# Fetch content from PyTorch and compare
try:
pt_content = fetch_file_content(commit_hash, pt_file_path)
et_content = et_file.read_bytes()

if pt_content != et_content:
print(f"⚠️ Difference detected in {rel_path}")
print(f"📋 Grafting from PyTorch commit {commit_hash}...")

et_file.write_bytes(pt_content)
print(f"✅ Grafted {et_file}")
files_grafted += 1
except urllib.request.HTTPError as e:
if e.code != 404: # It's ok to have more files in ET than pytorch/pytorch.
print(f"Error fetching {rel_path} from PyTorch: {e}")
except Exception as e:
print(f"Error syncing {rel_path}: {e}")
continue

return files_grafted


def sync_c10_directories(commit_hash):
"""
Sync c10 and torch/headeronly directories from PyTorch to ExecuTorch using GitHub API.

Args:
commit_hash: PyTorch commit hash to sync from

Returns:
Total number of files grafted
"""
print("\n🔄 Syncing c10 directories from PyTorch via GitHub API...")

# Get repository root
repo_root = Path.cwd()

# Define directory pairs to sync (from check_c10_sync.sh)
# Format: (executorch_dir, pytorch_path_in_repo)
dir_pairs = [
(
repo_root / "runtime/core/portable_type/c10/c10",
"c10",
),
(
repo_root / "runtime/core/portable_type/c10/torch/headeronly",
"torch/headeronly",
),
]

total_grafted = 0
for et_dir, pt_path in dir_pairs:
files_grafted = sync_directory(et_dir, pt_path, commit_hash)
total_grafted += files_grafted

if total_grafted > 0:
print(f"\n✅ Successfully grafted {total_grafted} file(s) from PyTorch")
else:
print("\n✅ No differences found - c10 is in sync")

return total_grafted


def main():
try:
# Read NIGHTLY_VERSION from torch_pin.py
Expand All @@ -118,7 +259,12 @@ def main():
# Update the pin file
update_pytorch_pin(commit_hash)

print("Successfully updated PyTorch commit pin!")
# Sync c10 directories from PyTorch
sync_c10_directories(commit_hash)

print(
"\n✅ Successfully updated PyTorch commit pin and synced c10 directories!"
)

except Exception as e:
print(f"Error: {e}", file=sys.stderr)
Expand Down
Loading