Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 88 additions & 47 deletions actions/update_actions/scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,50 +2,90 @@

import sys
from pathlib import Path
from typing import Literal

from ruamel.yaml import YAML


def find_uses(obj) -> list[str]:
"""Recursively find all 'uses' values in a YAML structure."""
found = []
if isinstance(obj, dict):
if isinstance(obj.get("steps"), list):
for step in obj["steps"]:
if isinstance(step, dict) and isinstance(step.get("uses"), str):
found.append(step["uses"])
for value in obj.values():
found.extend(find_uses(value))
elif isinstance(obj, list):

# Guard: Handle lists
if isinstance(obj, list):
for item in obj:
found.extend(find_uses(item))
return found

# Guard: Only process dicts from here
if not isinstance(obj, dict):
return found

# Process steps if present
if isinstance(obj.get("steps"), list):
for step in obj["steps"]:
if not isinstance(step, dict):
continue
if isinstance(step.get("uses"), str):
found.append(step["uses"])

# Recurse into all dict values
for value in obj.values():
found.extend(find_uses(value))

return found


def get_granularity(version: str) -> Literal["major", "minor", "patch"]:
parts = version.split(".")
if len(parts) == 1:
return "major"

if len(parts) == 2:
return "minor"

if len(parts) >= 3:
return "patch"

return "patch"


def update_uses_in_structure(obj, upgrades: dict[tuple[str, str], str]) -> bool:
"""
Recursively update 'uses' values in a YAML structure.
Returns True if any updates were made.
"""
if not isinstance(obj, (dict, list)):
return False

updated = False
if isinstance(obj, dict):
if isinstance(obj.get("steps"), list):
for step in obj["steps"]:
if isinstance(step, dict) and isinstance(step.get("uses"), str):
use = step["uses"]
if "@" in use:
repo, tag = use.split("@", 1)
new_tag = upgrades.get((repo, tag))
if new_tag:
step["uses"] = f"{repo}@{new_tag}"
updated = True
for value in obj.values():
if update_uses_in_structure(value, upgrades):
updated = True
elif isinstance(obj, list):

if isinstance(obj, list):
for item in obj:
if update_uses_in_structure(item, upgrades):
updated = True
return updated

# obj is a dict
if isinstance(obj.get("steps"), list):
for step in obj["steps"]:
if not isinstance(step, dict) or not isinstance(step.get("uses"), str):
continue

use = step["uses"]
if "@" not in use:
continue

repo, tag = use.split("@", 1)
new_tag = upgrades.get((repo, tag))
if new_tag:
step["uses"] = f"{repo}@{new_tag}"
updated = True

for value in obj.values():
if update_uses_in_structure(value, upgrades):
updated = True

return updated


Expand Down Expand Up @@ -89,31 +129,27 @@ def apply_updates(text: str, upgrades: dict[tuple[str, str], str]) -> str:
lines = text.split("\n")

for i, line in enumerate(lines):
# Look for lines that contain 'uses:' with a value
# Handle both plain keys and list items with dashes
stripped = line.lstrip()

# Check if line has 'uses:' (either "uses:" or "- uses:")
# Guard: Skip if no 'uses:' found
if "uses:" not in stripped:
continue

# Find the position of 'uses:' in the line
# Guard: Find position of 'uses:'
uses_idx = stripped.find("uses:")
if uses_idx == -1:
continue

# Check if everything before 'uses:' is valid YAML (dash followed by spaces, or nothing)
# Guard: Validate prefix is either empty or a dash
prefix = stripped[:uses_idx].strip()
if prefix and prefix != "-":
continue

# Extract the indentation from the original line
# Extract indentation and value parts
indent = line[: len(line) - len(stripped)]
rest = stripped[uses_idx + 5 :].strip()

# Get the part after 'uses:'
rest = stripped[uses_idx + 5 :].strip() # Remove 'uses:' and leading whitespace

# Handle comments - extract value and any trailing comment
# Parse value and comment
comment = ""
value_part = rest
if "#" in rest:
Expand All @@ -124,19 +160,24 @@ def apply_updates(text: str, upgrades: dict[tuple[str, str], str]) -> str:
# Check if this value matches any upgrade
for (repo, current_tag), new_tag in upgrades.items():
old_value = f"{repo}@{current_tag}"
new_value = f"{repo}@{new_tag}"
if value_part == old_value:
# Reconstruct the line, preserving list item syntax if present
if stripped.startswith("- "):
if comment:
lines[i] = f"{indent}- uses: {new_value} {comment}"
else:
lines[i] = f"{indent}- uses: {new_value}"
else:
if comment:
lines[i] = f"{indent}uses: {new_value} {comment}"
else:
lines[i] = f"{indent}uses: {new_value}"
break
if value_part != old_value:
continue

# Granularize new_tag to match current_tag's granularity
granularity = get_granularity(current_tag)
if granularity == "major":
new_tag_granuralized = new_tag.split(".")[0]
elif granularity == "minor":
new_tag_granuralized = ".".join(new_tag.split(".")[:2])
else:
new_tag_granuralized = ".".join(new_tag.split(".")[:3])

new_value = f"{repo}@{new_tag_granuralized}"

# Reconstruct line with proper formatting
prefix_str = "- " if stripped.startswith("- ") else ""
comment_str = f" {comment}" if comment else ""
lines[i] = f"{indent}{prefix_str}uses: {new_value}{comment_str}"
break

return "\n".join(lines)
29 changes: 27 additions & 2 deletions actions/update_actions/tests/test_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,31 @@


class TestScanner(unittest.TestCase):
def test_get_granularity(self):
"""Test granularity detection for various version formats."""
test_cases = [
# (version, expected_granularity)
("1", "major"),
("2", "major"),
("v1", "major"),
("v10", "major"),
("1.2", "minor"),
("2.5", "minor"),
("v1.2", "minor"),
("v3.14", "minor"),
("1.2.3", "patch"),
("2.5.8", "patch"),
("v1.2.3", "patch"),
("v3.14.159", "patch"),
("1.2.3.4", "patch"), # More than 3 parts
("v1.2.3.4", "patch"),
]

for version, expected in test_cases:
with self.subTest(version=version):
result = scanner.get_granularity(version)
self.assertEqual(result, expected)

def test_find_uses_nested(self):
data = {
"jobs": {
Expand Down Expand Up @@ -111,8 +136,8 @@ def test_apply_updates_preserves_non_uses_variables(self):
updated = scanner.apply_updates(text, upgrades)

# Verify that the uses entries were updated
self.assertIn("actions/create-github-app-token@v2.2.1", updated)
self.assertIn("actions/setup-python@v6.2.0", updated)
self.assertIn("actions/create-github-app-token@v2", updated)
self.assertIn("actions/setup-python@v6", updated)

# Verify that non-uses variables are preserved exactly as-is
self.assertIn('PYTHON_VERSION: "3.14"', updated)
Expand Down
121 changes: 90 additions & 31 deletions actions/update_actions/tests/test_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,34 +8,93 @@

class TestUpdater(unittest.TestCase):
def test_update_actions_writes_updates(self):
with tempfile.TemporaryDirectory() as tmpdir:
root = Path(tmpdir)
workflow_dir = root / ".github/workflows"
workflow_dir.mkdir(parents=True)
workflow = workflow_dir / "ci.yml"
workflow.write_text(
"""
jobs:
build:
steps:
- uses: actions/checkout@v3
""",
encoding="utf-8",
)
cases = [
{
"name": "major",
"before": """
jobs:
build:
steps:
- uses: actions/checkout@v3
""",
"after": """
jobs:
build:
steps:
- uses: actions/checkout@v4
""",
},
{
"name": "minor",
"before": """
jobs:
build:
steps:
- uses: actions/checkout@v3.0
""",
"after": """
jobs:
build:
steps:
- uses: actions/checkout@v4.1
""",
},
{
"name": "patch",
"before": """
jobs:
build:
steps:
- uses: actions/checkout@v3.0.1
""",
"after": """
jobs:
build:
steps:
- uses: actions/checkout@v4.1.0
""",
},
{
"name": "multiple",
"before": """
jobs:
build:
steps:
- uses: actions/checkout@v3.0.1
- uses: actions/checkout@v3
""",
"after": """
jobs:
build:
steps:
- uses: actions/checkout@v4.1.0
- uses: actions/checkout@v4
""",
},
]

with mock.patch(
"update_actions.updater.fetch_release_tags",
return_value=["v2", "v4"],
):
updater.update_actions(
root=root,
file_glob=".github/**/*.yml",
prefixes=["actions"],
dry_run=False,
)
for case in cases:
with self.subTest(case=case["name"]):
with tempfile.TemporaryDirectory() as tmpdir:
root = Path(tmpdir)
workflow_dir = root / ".github/workflows"
workflow_dir.mkdir(parents=True)
workflow = workflow_dir / "ci.yml"
workflow.write_text(case["before"], encoding="utf-8")

with mock.patch(
"update_actions.updater.fetch_release_tags",
return_value=["v2", "v4", "v4.1.0", "v4.1"],
):
updater.update_actions(
root=root,
file_glob=".github/**/*.yml",
prefixes=["actions"],
dry_run=False,
)

updated = workflow.read_text(encoding="utf-8")
self.assertIn("actions/checkout@v4", updated)
updated = workflow.read_text(encoding="utf-8")
self.assertEqual(updated, case["after"])

def test_update_actions_dry_run_no_write(self):
with tempfile.TemporaryDirectory() as tmpdir:
Expand All @@ -44,11 +103,11 @@ def test_update_actions_dry_run_no_write(self):
workflow_dir.mkdir(parents=True)
workflow = workflow_dir / "ci.yml"
original = """
jobs:
build:
steps:
- uses: actions/checkout@v3
"""
jobs:
build:
steps:
- uses: actions/checkout@v3
"""
workflow.write_text(original, encoding="utf-8")

with mock.patch(
Expand Down
Loading
Loading