diff --git a/.githooks/pre-commit b/.githooks/pre-commit
index b3d66f3f69a..f29342da67e 100755
--- a/.githooks/pre-commit
+++ b/.githooks/pre-commit
@@ -1,36 +1,29 @@
 #!/usr/bin/env bash
 
-# Pre-commit hook to automatically update PyTorch commit pin when torch_pin.py changes
+# Pre-commit hook to automatically update PyTorch commit pin and sync c10 directories when torch_pin.py changes
 
 # Check if torch_pin.py is being committed
 if git diff --cached --name-only | grep -q "^torch_pin.py$"; then
     echo "🔍 Detected changes to torch_pin.py"
-    echo "📝 Updating PyTorch commit pin..."
+    echo "📝 Updating PyTorch commit pin and syncing c10 directories..."
 
-    # Run the update script
-    hook_output=$(python .github/scripts/update_pytorch_pin.py 2>&1)
-    hook_status=$?
-    echo "$hook_output"
-
-    if [ $hook_status -eq 0 ]; then
-        # Check if pytorch.txt was modified
+    # Run the update script (which now also syncs c10 directories)
+    if python .github/scripts/update_pytorch_pin.py; then
+        # Stage any modified files (pytorch.txt and grafted c10 files)
         if ! git diff --quiet .ci/docker/ci_commit_pins/pytorch.txt; then
-            echo "✅ PyTorch commit pin updated successfully"
-            # Stage the updated file
             git add .ci/docker/ci_commit_pins/pytorch.txt
             echo "📌 Staged .ci/docker/ci_commit_pins/pytorch.txt"
-        else
-            echo "ℹ️ PyTorch commit pin unchanged"
         fi
-    else
-        if echo "$hook_output" | grep -qi "rate limit exceeded"; then
-            echo "⚠️ PyTorch commit pin not updated due to GitHub API rate limiting."
-            echo " Please manually update .ci/docker/ci_commit_pins/pytorch.txt if needed."
-        else
-            echo "❌ Failed to update PyTorch commit pin"
-            echo "Please run: python .github/scripts/update_pytorch_pin.py"
-            exit 1
+
+        # Stage any grafted c10 files
+        if ! git diff --quiet runtime/core/portable_type/c10/; then
+            git add runtime/core/portable_type/c10/
+            echo "📌 Staged grafted c10 files"
         fi
+    else
+        echo "❌ Failed to update PyTorch commit pin"
+        echo "Please run: python .github/scripts/update_pytorch_pin.py"
+        exit 1
     fi
 fi
 
diff --git a/.github/scripts/update_pytorch_pin.py b/.github/scripts/update_pytorch_pin.py
index 2df0eb8d5a1..dbc48552d9b 100644
--- a/.github/scripts/update_pytorch_pin.py
+++ b/.github/scripts/update_pytorch_pin.py
@@ -1,9 +1,12 @@
 #!/usr/bin/env python3
 
+import base64
+import hashlib
 import json
 import re
 import sys
 import urllib.request
+from pathlib import Path
 
 
 def parse_nightly_version(nightly_version):
@@ -101,6 +104,173 @@ def update_pytorch_pin(commit_hash):
     print(f"Updated {pin_file} with commit hash: {commit_hash}")
 
 
+def should_skip_file(filename):
+    """
+    Check if a file should be skipped during sync (build files).
+
+    Args:
+        filename: Base filename to check
+
+    Returns:
+        True if file should be skipped
+    """
+    skip_files = {"BUCK", "CMakeLists.txt", "TARGETS", "targets.bzl"}
+    return filename in skip_files
+
+
+def fetch_file_content(commit_hash, file_path):
+    """
+    Fetch file content from GitHub API.
+
+    Args:
+        commit_hash: Commit hash to fetch from
+        file_path: File path in the repository
+
+    Returns:
+        File content as bytes
+
+    Raises:
+        urllib.request.HTTPError: If the request fails. Nothing is logged
+            here; the caller decides whether a 404 is an error.
+        ValueError: If the response carries no base64-encoded file content.
+    """
+    api_url = f"https://api.github.com/repos/pytorch/pytorch/contents/{file_path}?ref={commit_hash}"
+
+    req = urllib.request.Request(api_url)
+    req.add_header("Accept", "application/vnd.github.v3+json")
+    req.add_header("User-Agent", "ExecuTorch-Bot")
+
+    with urllib.request.urlopen(req) as response:
+        data = json.loads(response.read().decode())
+
+    # Only decode when GitHub says the payload is base64. Any other shape
+    # (e.g. a directory listing, or a file too large to be inlined) must
+    # not be grafted, or the local file could be overwritten with garbage.
+    encoding = data.get("encoding") if isinstance(data, dict) else None
+    if encoding != "base64":
+        raise ValueError(f"Unexpected content encoding for {file_path}: {encoding!r}")
+    return base64.b64decode(data["content"])
+
+
+def _sync_directory(et_dir, pt_path, commit_hash):
+    """
+    Sync one directory pair and count both grafted and failed files.
+
+    Args:
+        et_dir: ExecuTorch directory path
+        pt_path: PyTorch directory path in the repository (e.g., "c10")
+        commit_hash: Commit hash to fetch from
+
+    Returns:
+        Tuple of (number of files grafted, number of files that failed)
+    """
+    files_grafted = 0
+    files_failed = 0
+    print(f"Checking {et_dir} vs pytorch/{pt_path}...")
+
+    if not et_dir.exists():
+        print(f"Warning: ExecuTorch directory {et_dir} does not exist, skipping")
+        return 0, 0
+
+    # Loop through files in ExecuTorch directory
+    for et_file in et_dir.rglob("*"):
+        # Skip directories and build files
+        if not et_file.is_file() or should_skip_file(et_file.name):
+            continue
+
+        # Construct corresponding path in PyTorch (always "/"-separated)
+        rel_path = et_file.relative_to(et_dir)
+        pt_file_path = f"{pt_path}/{rel_path.as_posix()}"
+
+        # Fetch content from PyTorch and compare
+        try:
+            pt_content = fetch_file_content(commit_hash, pt_file_path)
+            if pt_content != et_file.read_bytes():
+                print(f"⚠️ Difference detected in {rel_path}")
+                print(f"📋 Grafting from PyTorch commit {commit_hash}...")
+
+                et_file.write_bytes(pt_content)
+                print(f"✅ Grafted {et_file}")
+                files_grafted += 1
+        except urllib.request.HTTPError as e:
+            if e.code != 404:  # It's ok to have more files in ET than pytorch/pytorch.
+                files_failed += 1
+                print(f"Error fetching {rel_path} from PyTorch: {e}", file=sys.stderr)
+        except Exception as e:
+            # Best effort: one bad file must not abort the rest of the sync.
+            files_failed += 1
+            print(f"Error syncing {rel_path}: {e}", file=sys.stderr)
+
+    return files_grafted, files_failed
+
+
+def sync_directory(et_dir, pt_path, commit_hash):
+    """
+    Sync files from PyTorch to ExecuTorch using GitHub API.
+    Only syncs files that already exist in ExecuTorch - does not add new files.
+
+    Args:
+        et_dir: ExecuTorch directory path
+        pt_path: PyTorch directory path in the repository (e.g., "c10")
+        commit_hash: Commit hash to fetch from
+
+    Returns:
+        Number of files grafted
+    """
+    files_grafted, _ = _sync_directory(et_dir, pt_path, commit_hash)
+    return files_grafted
+
+
+def sync_c10_directories(commit_hash):
+    """
+    Sync c10 and torch/headeronly directories from PyTorch to ExecuTorch using GitHub API.
+
+    Args:
+        commit_hash: PyTorch commit hash to sync from
+
+    Returns:
+        Total number of files grafted
+    """
+    print("\n🔄 Syncing c10 directories from PyTorch via GitHub API...")
+
+    # Get repository root (assumes the script is run from there)
+    repo_root = Path.cwd()
+
+    # Define directory pairs to sync (from check_c10_sync.sh)
+    # Format: (executorch_dir, pytorch_path_in_repo)
+    dir_pairs = [
+        (
+            repo_root / "runtime/core/portable_type/c10/c10",
+            "c10",
+        ),
+        (
+            repo_root / "runtime/core/portable_type/c10/torch/headeronly",
+            "torch/headeronly",
+        ),
+    ]
+
+    total_grafted = 0
+    total_failed = 0
+    for et_dir, pt_path in dir_pairs:
+        files_grafted, files_failed = _sync_directory(et_dir, pt_path, commit_hash)
+        total_grafted += files_grafted
+        total_failed += files_failed
+
+    if total_grafted > 0:
+        print(f"\n✅ Successfully grafted {total_grafted} file(s) from PyTorch")
+    elif total_failed == 0:
+        print("\n✅ No differences found - c10 is in sync")
+
+    if total_failed > 0:
+        # Never claim "in sync" when some files could not be compared.
+        print(
+            f"\n⚠️ {total_failed} file(s) could not be checked; c10 sync is incomplete",
+            file=sys.stderr,
+        )
+
+    return total_grafted
+
+
 def main():
     try:
         # Read NIGHTLY_VERSION from torch_pin.py
@@ -118,7 +288,12 @@ def main():
         # Update the pin file
         update_pytorch_pin(commit_hash)
 
-        print("Successfully updated PyTorch commit pin!")
+        # Sync c10 directories from PyTorch
+        sync_c10_directories(commit_hash)
+
+        print(
+            "\n✅ Successfully updated PyTorch commit pin and synced c10 directories!"
+        )
 
     except Exception as e:
         print(f"Error: {e}", file=sys.stderr)