From 4e2ff5c7addfee0c6e9bc28297ce9de10c0df8e6 Mon Sep 17 00:00:00 2001 From: Stephen Allen Date: Tue, 20 Jan 2026 14:48:06 -0600 Subject: [PATCH 1/2] feat: add upgrade command for updating projects to newer ASP versions --- agent_starter_pack/cli/commands/upgrade.py | 592 ++++++++++++++++++ agent_starter_pack/cli/main.py | 2 + agent_starter_pack/cli/utils/__init__.py | 2 + .../cli/utils/generation_metadata.py | 50 ++ agent_starter_pack/cli/utils/upgrade.py | 501 +++++++++++++++ tests/cli/commands/test_upgrade.py | 377 +++++++++++ tests/cli/utils/test_generation_metadata.py | 46 +- tests/cli/utils/test_upgrade_utils.py | 490 +++++++++++++++ 8 files changed, 2016 insertions(+), 44 deletions(-) create mode 100644 agent_starter_pack/cli/commands/upgrade.py create mode 100644 agent_starter_pack/cli/utils/generation_metadata.py create mode 100644 agent_starter_pack/cli/utils/upgrade.py create mode 100644 tests/cli/commands/test_upgrade.py create mode 100644 tests/cli/utils/test_upgrade_utils.py diff --git a/agent_starter_pack/cli/commands/upgrade.py b/agent_starter_pack/cli/commands/upgrade.py new file mode 100644 index 00000000..a3968809 --- /dev/null +++ b/agent_starter_pack/cli/commands/upgrade.py @@ -0,0 +1,592 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Upgrade command for upgrading existing projects to newer ASP versions.""" + +import difflib +import logging +import pathlib +import re +import shutil +import subprocess +import tempfile + +import click +from rich.console import Console +from rich.prompt import Prompt + +from ..utils.generation_metadata import metadata_to_cli_args +from ..utils.logging import handle_cli_error +from ..utils.upgrade import ( + DependencyChange, + FileCompareResult, + compare_all_files, + group_results_by_action, + merge_pyproject_dependencies, + write_merged_dependencies, +) +from ..utils.version import get_current_version +from .enhance import get_project_asp_config + +console = Console() + + +def _ensure_uvx_available() -> bool: + """Check if uvx is available.""" + try: + subprocess.run(["uvx", "--version"], capture_output=True, check=True) + return True + except (subprocess.CalledProcessError, FileNotFoundError): + return False + + +def _run_create_command( + args: list[str], + output_dir: pathlib.Path, + project_name: str, + version: str | None = None, +) -> bool: + """Run the create command to generate a template. + + Args: + args: CLI arguments for create command + output_dir: Directory to output the template + project_name: Name for the project + version: Optional ASP version to use (uses uvx if specified) + + Returns: + True if successful, False otherwise + """ + # Build the command + if version: + cmd = ["uvx", f"agent-starter-pack@{version}", "create"] + else: + cmd = ["agent-starter-pack", "create"] + + cmd.extend([project_name]) + cmd.extend(["--output-dir", str(output_dir)]) + cmd.extend(["--auto-approve", "--skip-deps", "--skip-checks"]) + cmd.extend(args) + + logging.debug(f"Running command: {' '.join(cmd)}") + + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + encoding="utf-8", + timeout=300, # 5 minute timeout + ) + + if result.returncode != 0: + logging.error(f"Command failed: {result.stderr}") + return False + + return True + except subprocess.TimeoutExpired: + logging.error("Command timed out") + return False + except Exception as e: + logging.error(f"Error running command: {e}") + return False + + +def _display_version_header(old_version: str, new_version: str) -> None: + """Display the upgrade version header.""" + console.print() + console.print(f"[bold blue]📦 Upgrading {old_version} → {new_version}[/bold blue]") + console.print() + + +def _display_results( + groups: dict[str, list[FileCompareResult]], + dep_changes: list[DependencyChange], + dry_run: bool = False, +) -> None: + """Display the upgrade results grouped by action.""" + if groups["auto_update"]: + console.print("[bold green]Auto-updating (unchanged by you):[/bold green]") + for result in groups["auto_update"]: + console.print(f" [green]✓[/green] {result.path}") + console.print() + + preserved_user_modified = [ + r for r in groups["preserve"] if r.preserve_type == "asp_unchanged" + ] + if preserved_user_modified: + console.print( + "[bold cyan]Preserving (you modified, ASP unchanged):[/bold cyan]" + ) + for result in preserved_user_modified: + console.print(f" [cyan]✓[/cyan] {result.path}") + console.print() + + skipped = [ + r for r in groups["skip"] if r.category in ("agent_code", "config_files") + ] + if skipped: + console.print("[dim]Skipping (your code):[/dim]") + for result in skipped: + console.print(f" [dim]-[/dim] {result.path}") + console.print() + + if groups["new"]: + console.print("[bold yellow]New files in ASP:[/bold yellow]") + for result in groups["new"]: + console.print(f" [yellow]+[/yellow] {result.path}") + console.print() + + if groups["removed"]: + console.print("[bold yellow]Removed in ASP:[/bold yellow]") + for result in groups["removed"]: + console.print(f" [yellow]-[/yellow] {result.path}") + console.print() + + if groups["conflict"]: + console.print("[bold red]Conflicts (both changed):[/bold red]") + for result in groups["conflict"]: + console.print(f" [red]⚠[/red] {result.path}") + if not dry_run: + console.print("[dim] You'll be prompted to resolve each conflict.[/dim]") + console.print() + + if dep_changes: + console.print("[bold]Dependencies:[/bold]") + for change in dep_changes: + if change.change_type == "updated": + console.print( + f" [green]✓[/green] Updated: {change.name} " + f"{change.old_version} → {change.new_version}" + ) + elif change.change_type == "added": + console.print( + f" [green]+[/green] Added: {change.name}{change.new_version}" + ) + elif change.change_type == "kept": + console.print( + f" [cyan]✓[/cyan] Kept: {change.name}{change.old_version}" + ) + elif change.change_type == "removed": + console.print( + f" [yellow]-[/yellow] Removed: {change.name}{change.old_version}" + ) + console.print() + + +def _handle_conflict( + result: FileCompareResult, + project_dir: pathlib.Path, + new_template_dir: pathlib.Path, + auto_approve: bool, +) -> str: + """Handle a file conflict interactively. + + Args: + result: The conflict result + project_dir: Path to current project + new_template_dir: Path to new template + auto_approve: If True, keep user's version + + Returns: + Action taken: "kept", "updated", or "skipped" + """ + if auto_approve: + console.print(f" [dim]Keeping your version: {result.path}[/dim]") + return "kept" + + console.print(f"\n[bold yellow]Conflict: {result.path}[/bold yellow]") + console.print(f" Reason: {result.reason}") + + choice = Prompt.ask( + " (v)iew diff, (k)eep yours, (u)se new, (s)kip", + choices=["v", "k", "u", "s"], + default="k", + ) + + if choice == "v": + # Show diff using Python's difflib (cross-platform) + current_file = project_dir / result.path + new_file = new_template_dir / result.path + + try: + current_lines = current_file.read_text(encoding="utf-8").splitlines( + keepends=True + ) + new_lines = new_file.read_text(encoding="utf-8").splitlines(keepends=True) + + diff_lines = list( + difflib.unified_diff( + current_lines, + new_lines, + fromfile=f"Your version: {result.path}", + tofile=f"New ASP version: {result.path}", + ) + ) + diff_output = "".join(diff_lines) + + console.print() + if diff_output: + # Limit output to ~2000 characters + if len(diff_output) > 2000: + console.print(diff_output[:2000]) + console.print("[dim]... (truncated)[/dim]") + else: + console.print(diff_output) + else: + console.print("[dim]No differences found[/dim]") + except Exception as e: + console.print(f"[red]Could not show diff: {e}[/red]") + + # Ask again after viewing + choice = Prompt.ask( + " (k)eep yours, (u)se new, (s)kip", + choices=["k", "u", "s"], + default="k", + ) + + if choice == "k": + console.print(" [cyan]Keeping your version[/cyan]") + return "kept" + elif choice == "u": + return "updated" + else: + return "skipped" + + +def _copy_file(src: pathlib.Path, dst: pathlib.Path) -> bool: + """Copy a file, creating parent directories as needed.""" + if not src.exists(): + return False + dst.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(src, dst) + return True + + +def _apply_changes( + groups: dict[str, list[FileCompareResult]], + project_dir: pathlib.Path, + new_template_dir: pathlib.Path, + auto_approve: bool, + dry_run: bool, +) -> dict[str, int]: + """Apply the upgrade changes to the project.""" + counts = { + "updated": 0, + "added": 0, + "removed": 0, + "skipped": 0, + "conflicts_kept": 0, + "conflicts_updated": 0, + } + + if dry_run: + console.print("[bold yellow]Dry run - no changes made[/bold yellow]") + return counts + + for result in groups["auto_update"]: + if _copy_file(new_template_dir / result.path, project_dir / result.path): + counts["updated"] += 1 + + for result in groups["new"]: + should_add = ( + auto_approve + or Prompt.ask( + f" Add new file {result.path}?", choices=["y", "n"], default="y" + ) + == "y" + ) + if should_add: + if _copy_file(new_template_dir / result.path, project_dir / result.path): + counts["added"] += 1 + else: + counts["skipped"] += 1 + + for result in groups["removed"]: + file_path = project_dir / result.path + should_remove = ( + auto_approve + or Prompt.ask( + f" Remove file {result.path}?", choices=["y", "n"], default="y" + ) + == "y" + ) + if should_remove and file_path.exists(): + file_path.unlink() + counts["removed"] += 1 + elif not should_remove: + counts["skipped"] += 1 + + if groups["conflict"]: + console.print() + console.print("[bold]Resolving conflicts:[/bold]") + + for result in groups["conflict"]: + action = _handle_conflict(result, project_dir, new_template_dir, auto_approve) + if action == "updated": + if _copy_file(new_template_dir / result.path, project_dir / result.path): + counts["conflicts_updated"] += 1 + elif action == "kept": + counts["conflicts_kept"] += 1 + else: + counts["skipped"] += 1 + + return counts + + +def _update_pyproject_metadata( + project_dir: pathlib.Path, + new_version: str, +) -> None: + """Update the asp_version in pyproject.toml. + + Args: + project_dir: Path to project directory + new_version: New ASP version + """ + pyproject_path = project_dir / "pyproject.toml" + if not pyproject_path.exists(): + return + + try: + content = pyproject_path.read_text(encoding="utf-8") + content = re.sub( + r'(asp_version\s*=\s*")[^"]*(")', + rf"\g<1>{new_version}\g<2>", + content, + ) + pyproject_path.write_text(content, encoding="utf-8") + except Exception as e: + logging.warning(f"Could not update asp_version: {e}") + + +@click.command() +@click.argument( + "project_path", + type=click.Path(exists=True, path_type=pathlib.Path), + default=".", + required=False, +) +@click.option( + "--dry-run", + is_flag=True, + help="Preview changes without applying them", +) +@click.option( + "--auto-approve", + "-y", + is_flag=True, + help="Auto-apply non-conflicting changes without prompts", +) +@click.option( + "--debug", + is_flag=True, + help="Enable debug logging", +) +@handle_cli_error +def upgrade( + project_path: pathlib.Path, + dry_run: bool, + auto_approve: bool, + debug: bool, +) -> None: + """Upgrade project to newer agent-starter-pack version. + + Compares your project against both old and new template versions to + intelligently apply updates while preserving your customizations. + + Uses 3-way comparison: + - If you didn't modify a file, it's auto-updated + - If ASP didn't change a file, your modifications are preserved + - If both changed, you're prompted to resolve the conflict + """ + if debug: + logging.basicConfig(level=logging.DEBUG, force=True) + console.print("[dim]Debug mode enabled[/dim]") + + # Resolve project path + project_dir = project_path.resolve() + + metadata = get_project_asp_config(project_dir) + if not metadata: + console.print( + "[bold red]Error:[/bold red] No agent-starter-pack metadata found." + ) + console.print("Ensure pyproject.toml has \\[tool.agent-starter-pack] section.") + raise SystemExit(1) + + old_version = metadata.get("asp_version") + if not old_version: + console.print( + "[bold red]Error:[/bold red] No asp_version found in project metadata." + ) + console.print( + "The project metadata is missing the version. " + "Please ensure pyproject.toml has asp_version in \\[tool.agent-starter-pack]." + ) + raise SystemExit(1) + + new_version = get_current_version() + + # Check if upgrade is needed + if old_version == new_version: + console.print( + f"[bold green]✅[/bold green] Project is already at version {new_version}" + ) + return + + # Check if uvx is available for re-templating old version + if not _ensure_uvx_available(): + console.print( + "[bold red]Error:[/bold red] 'uvx' is required for upgrade but not installed." + ) + console.print( + "[dim]Install uv to enable upgrade: curl -LsSf https://astral.sh/uv/install.sh | sh[/dim]" + ) + raise SystemExit(1) + + _display_version_header(old_version, new_version) + + # Get project name and CLI args from metadata + project_name = metadata.get("name", project_dir.name) + agent_directory = metadata.get("agent_directory", "app") + cli_args = metadata_to_cli_args(metadata) + + # Create temp directories for re-templating + temp_base = pathlib.Path(tempfile.mkdtemp(prefix="asp_upgrade_")) + old_template_dir = temp_base / "old" + new_template_dir = temp_base / "new" + + try: + console.print("[dim]Generating template versions for comparison...[/dim]") + + # Re-template old version + console.print(f"[dim] - Old template (v{old_version})...[/dim]") + if not _run_create_command( + cli_args, old_template_dir, project_name, old_version + ): + console.print( + f"[bold red]Error:[/bold red] Failed to generate old template (v{old_version})" + ) + console.print( + "[dim]This version may not be available. Try upgrading from a more recent version.[/dim]" + ) + raise SystemExit(1) + + # Re-template new version + console.print(f"[dim] - New template (v{new_version})...[/dim]") + if not _run_create_command(cli_args, new_template_dir, project_name): + console.print( + f"[bold red]Error:[/bold red] Failed to generate new template (v{new_version})" + ) + raise SystemExit(1) + + # The templates are created in subdirectories named after the project + old_template_project = old_template_dir / project_name + new_template_project = new_template_dir / project_name + + console.print() + + # Compare all files + console.print("[dim]Comparing files...[/dim]") + results = compare_all_files( + project_dir, + old_template_project, + new_template_project, + agent_directory, + ) + + # Group by action + groups = group_results_by_action(results) + + # Handle dependency merging + dep_result = merge_pyproject_dependencies( + project_dir / "pyproject.toml", + old_template_project / "pyproject.toml", + new_template_project / "pyproject.toml", + ) + + console.print() + + # Display results + _display_results(groups, dep_result.changes, dry_run) + + # Check if there's anything to do + total_changes = ( + len(groups["auto_update"]) + + len(groups["new"]) + + len(groups["removed"]) + + len(groups["conflict"]) + ) + + if total_changes == 0 and not dep_result.changes: + console.print("[bold green]✅[/bold green] No changes needed!") + return + + # Confirm before applying + if not auto_approve and not dry_run: + prompt_text = "\nProceed with upgrade?" + if groups["conflict"]: + prompt_text = "\nProceed? (you'll resolve conflicts next)" + proceed = Prompt.ask( + prompt_text, + choices=["y", "n"], + default="y", + ) + if proceed != "y": + console.print("[yellow]Upgrade cancelled.[/yellow]") + return + + # Apply changes + counts = _apply_changes( + groups, + project_dir, + new_template_project, + auto_approve, + dry_run, + ) + + # Apply dependency changes + if not dry_run and dep_result.changes: + write_merged_dependencies( + project_dir / "pyproject.toml", + dep_result.merged_deps, + ) + + # Update metadata version + if not dry_run: + _update_pyproject_metadata(project_dir, new_version) + + # Summary + console.print() + if dry_run: + console.print( + "[bold yellow]Dry run complete.[/bold yellow] " + "Run without --dry-run to apply changes." + ) + else: + console.print(f" Updated: {counts['updated']} files") + console.print(f" Added: {counts['added']} files") + console.print(f" Removed: {counts['removed']} files") + if counts["conflicts_kept"] or counts["conflicts_updated"]: + console.print( + f" Conflicts: {counts['conflicts_updated']} updated, " + f"{counts['conflicts_kept']} kept yours" + ) + console.print() + console.print("[bold green]✅ Upgrade complete![/bold green]") + + finally: + # Cleanup temp directories + shutil.rmtree(temp_base, ignore_errors=True) diff --git a/agent_starter_pack/cli/main.py b/agent_starter_pack/cli/main.py index 42910984..f0c4040d 100644 --- a/agent_starter_pack/cli/main.py +++ b/agent_starter_pack/cli/main.py @@ -24,6 +24,7 @@ from .commands.list import list_agents from .commands.register_gemini_enterprise import register_gemini_enterprise from .commands.setup_cicd import setup_cicd +from .commands.upgrade import upgrade from .utils import display_update_message console = Console() @@ -62,6 +63,7 @@ def cli() -> None: cli.add_command(extract) cli.add_command(register_gemini_enterprise) cli.add_command(setup_cicd) +cli.add_command(upgrade) cli.add_command(list_agents, name="list") diff --git a/agent_starter_pack/cli/utils/__init__.py b/agent_starter_pack/cli/utils/__init__.py index 6bbfc8ec..d2040ce7 100644 --- a/agent_starter_pack/cli/utils/__init__.py +++ b/agent_starter_pack/cli/utils/__init__.py @@ -14,6 +14,7 @@ from .datastores import DATASTORE_TYPES, get_datastore_info from .gcp import verify_credentials_and_vertex +from .generation_metadata import metadata_to_cli_args from .logging import handle_cli_error from .template import ( get_available_agents, @@ -35,6 +36,7 @@ "get_template_path", "handle_cli_error", "load_template_config", + "metadata_to_cli_args", "process_template", "prompt_datastore_selection", "prompt_deployment_target", diff --git a/agent_starter_pack/cli/utils/generation_metadata.py b/agent_starter_pack/cli/utils/generation_metadata.py new file mode 100644 index 00000000..1c800814 --- /dev/null +++ b/agent_starter_pack/cli/utils/generation_metadata.py @@ -0,0 +1,50 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for converting project metadata to CLI arguments.""" + +from typing import Any + + +def metadata_to_cli_args(metadata: dict[str, Any]) -> list[str]: + """Convert metadata to CLI arguments for re-creating a project. + + Maps [tool.agent-starter-pack] metadata back to CLI arguments. + Used by upgrade command to re-template old/new versions. + """ + args: list[str] = [] + + if "base_template" in metadata: + args.extend(["--agent", metadata["base_template"]]) + + if "agent_directory" in metadata and metadata["agent_directory"] != "app": + args.extend(["--agent-directory", metadata["agent_directory"]]) + + create_params = metadata.get("create_params", {}) + for key, value in create_params.items(): + if ( + value is None + or value is False + or str(value).lower() == "none" + or value == "" + ): + continue + + arg_name = f"--{key.replace('_', '-')}" + if value is True: + args.append(arg_name) + else: + args.extend([arg_name, str(value)]) + + return args diff --git a/agent_starter_pack/cli/utils/upgrade.py b/agent_starter_pack/cli/utils/upgrade.py new file mode 100644 index 00000000..63d1742d --- /dev/null +++ b/agent_starter_pack/cli/utils/upgrade.py @@ -0,0 +1,501 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""3-way file comparison and dependency merging for upgrade command.""" + +import fnmatch +import hashlib +import logging +import pathlib +import re +import sys +from dataclasses import dataclass, field +from typing import Literal + +if sys.version_info >= (3, 11): + import tomllib +else: + import tomli as tomllib + +# Patterns use {agent_directory} placeholder replaced at runtime +FILE_CATEGORIES = { + "agent_code": [ # Never modified + "{agent_directory}/agent.py", + "{agent_directory}/tools/**/*.py", + "{agent_directory}/prompts/**/*.py", + ], + "config_files": [ # Never overwritten + "deployment/vars/*.tfvars", + ".env", + "*.env", + ], + "dependencies": [ # Special merge handling + "pyproject.toml", + ], + # Everything else is "scaffolding" (3-way compare) +} + + +# Preserve type literals for type-safe reason matching +PreserveType = Literal["asp_unchanged", "already_current", "unchanged_both", None] + + +@dataclass +class FileCompareResult: + """Result of comparing a file across three versions.""" + + path: str + category: str + action: Literal["auto_update", "preserve", "skip", "conflict", "new", "removed"] + reason: str + # For preserve actions, indicates why preserved + preserve_type: PreserveType = None + # For conflicts, store the content hashes + current_hash: str | None = None + old_template_hash: str | None = None + new_template_hash: str | None = None + + +@dataclass +class DependencyChange: + """A single dependency change.""" + + name: str + change_type: Literal["updated", "added", "removed", "kept"] + old_version: str | None = None + new_version: str | None = None + + +@dataclass +class DependencyMergeResult: + """Result of merging dependencies.""" + + changes: list[DependencyChange] = field(default_factory=list) + merged_deps: list[str] = field(default_factory=list) + has_conflicts: bool = False + + +def _expand_patterns(patterns: list[str], agent_directory: str) -> list[str]: + """Expand {agent_directory} placeholder in patterns.""" + return [p.replace("{agent_directory}", agent_directory) for p in patterns] + + +def _matches_any_pattern(path: str, patterns: list[str]) -> bool: + """Check if path matches any glob pattern, including ** recursive patterns.""" + path = path.replace("\\", "/") + + for pattern in patterns: + pattern = pattern.replace("\\", "/") + + if fnmatch.fnmatch(path, pattern): + return True + + if "**" in pattern: + regex = re.escape(pattern) + regex = regex.replace(r"\*\*/", "(?:.*/)?") # **/ = zero or more dirs + regex = regex.replace(r"\*\*", ".*") + regex = regex.replace(r"\*", "[^/]*") + if re.match(f"^{regex}$", path): + return True + + return False + + +def categorize_file(path: str, agent_directory: str = "app") -> str: + """Return category: agent_code, config_files, dependencies, or scaffolding.""" + for category, patterns in FILE_CATEGORIES.items(): + expanded = _expand_patterns(patterns, agent_directory) + if _matches_any_pattern(path, expanded): + return category + return "scaffolding" + + +def _file_hash(file_path: pathlib.Path) -> str | None: + """Calculate SHA256 hash of a file's contents.""" + if not file_path.exists(): + return None + try: + content = file_path.read_bytes() + return hashlib.sha256(content).hexdigest() + except Exception: + return None + + +def three_way_compare( + relative_path: str, + project_dir: pathlib.Path, + old_template_dir: pathlib.Path, + new_template_dir: pathlib.Path, + agent_directory: str = "app", +) -> FileCompareResult: + """Compare file across current, old template, and new template. + + Returns action based on: + - current == old -> auto-update (user didn't modify) + - old == new -> preserve (ASP didn't change) + - all differ -> conflict + """ + category = categorize_file(relative_path, agent_directory) + + if category == "agent_code": + return FileCompareResult( + path=relative_path, + category=category, + action="skip", + reason="Agent code (never modified by upgrade)", + ) + + if category == "config_files": + return FileCompareResult( + path=relative_path, + category=category, + action="skip", + reason="Config file (user's environment settings)", + ) + + if category == "dependencies": + return FileCompareResult( + path=relative_path, + category=category, + action="preserve", + reason="Dependencies (requires merge handling)", + ) + + current_file = project_dir / relative_path + old_template_file = old_template_dir / relative_path + new_template_file = new_template_dir / relative_path + + current_hash = _file_hash(current_file) + old_hash = _file_hash(old_template_file) + new_hash = _file_hash(new_template_file) + + # New file in ASP + if current_hash is None and old_hash is None and new_hash is not None: + return FileCompareResult( + path=relative_path, + category=category, + action="new", + reason="New file in ASP", + new_template_hash=new_hash, + ) + + # File removed in new template + if current_hash is not None and old_hash is not None and new_hash is None: + if current_hash == old_hash: + return FileCompareResult( + path=relative_path, + category=category, + action="removed", + reason="File removed in ASP (you didn't modify it)", + current_hash=current_hash, + old_template_hash=old_hash, + ) + return FileCompareResult( + path=relative_path, + category=category, + action="conflict", + reason="File removed in ASP but you modified it", + current_hash=current_hash, + old_template_hash=old_hash, + ) + + # File doesn't exist anywhere relevant + if current_hash is None and new_hash is None: + return FileCompareResult( + path=relative_path, + category=category, + action="skip", + reason="File not present", + ) + + # User didn't modify (current == old) + if current_hash == old_hash and new_hash is not None: + if old_hash == new_hash: + return FileCompareResult( + path=relative_path, + category=category, + action="preserve", + reason="Unchanged in both project and ASP", + preserve_type="unchanged_both", + current_hash=current_hash, + old_template_hash=old_hash, + new_template_hash=new_hash, + ) + return FileCompareResult( + path=relative_path, + category=category, + action="auto_update", + reason="You didn't modify this file", + current_hash=current_hash, + old_template_hash=old_hash, + new_template_hash=new_hash, + ) + + # ASP didn't change (old == new) + if old_hash == new_hash and current_hash is not None: + return FileCompareResult( + path=relative_path, + category=category, + action="preserve", + reason="ASP didn't change this file", + preserve_type="asp_unchanged", + current_hash=current_hash, + old_template_hash=old_hash, + new_template_hash=new_hash, + ) + + # Already up to date (current == new) + if current_hash == new_hash: + return FileCompareResult( + path=relative_path, + category=category, + action="preserve", + reason="Already up to date", + preserve_type="already_current", + current_hash=current_hash, + old_template_hash=old_hash, + new_template_hash=new_hash, + ) + + # All three differ -> conflict + return FileCompareResult( + path=relative_path, + category=category, + action="conflict", + reason="Both you and ASP modified this file", + current_hash=current_hash, + old_template_hash=old_hash, + new_template_hash=new_hash, + ) + + +def collect_all_files( + project_dir: pathlib.Path, + old_template_dir: pathlib.Path, + new_template_dir: pathlib.Path, + exclude_patterns: list[str] | None = None, +) -> set[str]: + """Collect all unique relative file paths from all three directories.""" + if exclude_patterns is None: + exclude_patterns = [ + ".git/**", + ".venv/**", + "venv/**", + "__pycache__/**", + "*.pyc", + ".DS_Store", + "*.egg-info/**", + "uv.lock", + ".uv/**", + ] + + all_files: set[str] = set() + + for base_dir in [project_dir, old_template_dir, new_template_dir]: + if not base_dir.exists(): + continue + for file_path in base_dir.rglob("*"): + if file_path.is_file(): + relative = str(file_path.relative_to(base_dir)) + # Check exclusions using _matches_any_pattern for ** support + if not _matches_any_pattern(relative, exclude_patterns): + all_files.add(relative) + + return all_files + + +def _parse_dependency(dep_str: str) -> tuple[str, str]: + """Parse a dependency string into (name, version_spec). + + Examples: + "google-adk>=0.2.0" -> ("google-adk", ">=0.2.0") + "requests==2.31.0" -> ("requests", "==2.31.0") + "pytest" -> ("pytest", "") + """ + # Match package name followed by optional version spec + match = re.match(r"^([a-zA-Z0-9_-]+(?:\[[^\]]+\])?)(.*)", dep_str.strip()) + if match: + return match.group(1).lower(), match.group(2).strip() + return dep_str.lower(), "" + + +def _load_dependencies_from_pyproject( + pyproject_path: pathlib.Path, +) -> dict[str, str]: + """Load dependencies as {name: version_spec} dict.""" + if not pyproject_path.exists(): + return {} + + try: + with open(pyproject_path, "rb") as f: + data = tomllib.load(f) + + deps = data.get("project", {}).get("dependencies", []) + result = {} + for dep in deps: + name, version = _parse_dependency(dep) + result[name] = version + return result + except Exception as e: + logging.warning(f"Error loading dependencies from {pyproject_path}: {e}") + return {} + + +def merge_pyproject_dependencies( + current_pyproject: pathlib.Path, + old_template_pyproject: pathlib.Path, + new_template_pyproject: pathlib.Path, +) -> DependencyMergeResult: + """Merge deps: new_template + user_added, where user_added = current - old.""" + current_deps = _load_dependencies_from_pyproject(current_pyproject) + old_deps = _load_dependencies_from_pyproject(old_template_pyproject) + new_deps = _load_dependencies_from_pyproject(new_template_pyproject) + + changes: list[DependencyChange] = [] + merged: dict[str, str] = {} + user_added = set(current_deps.keys()) - set(old_deps.keys()) + asp_managed = set(old_deps.keys()) + + for name, new_version in new_deps.items(): + merged[name] = new_version + + if name in old_deps: + old_version = old_deps[name] + if old_version != new_version: + changes.append( + DependencyChange( + name=name, + change_type="updated", + old_version=old_version, + new_version=new_version, + ) + ) + else: + changes.append( + DependencyChange( + name=name, + change_type="added", + new_version=new_version, + ) + ) + + for name in user_added: + user_version = current_deps[name] + merged[name] = user_version + changes.append( + DependencyChange( + name=name, + change_type="kept", + old_version=user_version, + new_version=user_version, + ) + ) + + for name in asp_managed: + if name not in new_deps and name not in user_added: + changes.append( + DependencyChange( + name=name, + change_type="removed", + old_version=old_deps[name], + ) + ) + + merged_list = [f"{name}{version}" for name, version in sorted(merged.items())] + + return DependencyMergeResult( + changes=changes, + merged_deps=merged_list, + has_conflicts=False, + ) + + +def write_merged_dependencies( + pyproject_path: pathlib.Path, + merged_deps: list[str], +) -> bool: + """Write merged dependencies back to pyproject.toml. + + Args: + pyproject_path: Path to pyproject.toml + merged_deps: List of dependency strings to write + + Returns: + True if successful, False otherwise + """ + if not pyproject_path.exists(): + return False + + try: + content = pyproject_path.read_text(encoding="utf-8") + + # Format dependencies as a TOML array + if merged_deps: + deps_formatted = ",\n ".join(f'"{dep}"' for dep in merged_deps) + new_deps_section = f"dependencies = [\n {deps_formatted},\n]" + else: + new_deps_section = "dependencies = []" + + # Replace the dependencies array using regex + # Match: dependencies = [...] (potentially multiline) + pattern = r"dependencies\s*=\s*\[[^\]]*\]" + content = re.sub(pattern, new_deps_section, content, flags=re.DOTALL) + + pyproject_path.write_text(content, encoding="utf-8") + return True + except Exception as e: + logging.warning(f"Could not write dependencies to {pyproject_path}: {e}") + return False + + +def compare_all_files( + project_dir: pathlib.Path, + old_template_dir: pathlib.Path, + new_template_dir: pathlib.Path, + agent_directory: str = "app", +) -> list[FileCompareResult]: + """Compare all files using 3-way comparison.""" + all_files = collect_all_files(project_dir, old_template_dir, new_template_dir) + + results = [] + for relative_path in sorted(all_files): + result = three_way_compare( + relative_path, + project_dir, + old_template_dir, + new_template_dir, + agent_directory, + ) + results.append(result) + + return results + + +def group_results_by_action( + results: list[FileCompareResult], +) -> dict[str, list[FileCompareResult]]: + """Group results by action type.""" + groups: dict[str, list[FileCompareResult]] = { + "auto_update": [], + "preserve": [], + "skip": [], + "conflict": [], + "new": [], + "removed": [], + } + + for result in results: + groups[result.action].append(result) + + return groups diff --git a/tests/cli/commands/test_upgrade.py b/tests/cli/commands/test_upgrade.py new file mode 100644 index 00000000..85ec1f74 --- /dev/null +++ b/tests/cli/commands/test_upgrade.py @@ -0,0 +1,377 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for upgrade command.""" + +import pathlib +import re +from unittest.mock import patch + +import pytest +from click.testing import CliRunner + +from agent_starter_pack.cli.commands.upgrade import upgrade + + +def strip_ansi(text: str) -> str: + """Remove ANSI escape codes from text.""" + ansi_pattern = re.compile(r"\x1b\[[0-9;]*m") + return ansi_pattern.sub("", text) + + +class TestUpgradeErrorCases: + """Test error handling in upgrade command.""" + + def test_missing_asp_metadata(self, tmp_path: pathlib.Path) -> None: + """Test error when pyproject.toml has no ASP metadata.""" + # Create pyproject.toml without [tool.agent-starter-pack] section + pyproject = tmp_path / "pyproject.toml" + pyproject.write_text( + """ +[project] +name = "test-project" +version = "0.1.0" +""" + ) + + runner = CliRunner() + result = runner.invoke(upgrade, [str(tmp_path)]) + output = strip_ansi(result.output) + + assert result.exit_code == 1 + assert "No agent-starter-pack metadata found" in output + assert "[tool.agent-starter-pack]" in output + + def test_missing_asp_version(self, tmp_path: pathlib.Path) -> None: + """Test error when metadata exists but asp_version is missing.""" + # Create pyproject.toml with ASP metadata but no asp_version + pyproject = tmp_path / "pyproject.toml" + pyproject.write_text( + """ +[project] +name = "test-project" +version = "0.1.0" + +[tool.agent-starter-pack] +name = "test-project" +base_template = "adk" +""" + ) + + runner = CliRunner() + result = runner.invoke(upgrade, [str(tmp_path)]) + output = strip_ansi(result.output) + + assert result.exit_code == 1 + assert "No asp_version found" in output + + @patch("agent_starter_pack.cli.commands.upgrade.get_current_version") + def test_already_at_latest_version( + self, mock_version, tmp_path: pathlib.Path + ) -> None: + """Test message when project is already at latest version.""" + mock_version.return_value = "0.31.0" + + # Create pyproject.toml with matching version + pyproject = tmp_path / "pyproject.toml" + pyproject.write_text( + """ +[project] +name = "test-project" +version = "0.1.0" + +[tool.agent-starter-pack] +name = "test-project" +base_template = "adk" +asp_version = "0.31.0" +""" + ) + + runner = CliRunner() + result = runner.invoke(upgrade, [str(tmp_path)]) + output = strip_ansi(result.output) + + assert result.exit_code == 0 + assert "already at version 0.31.0" in output + + @patch("agent_starter_pack.cli.commands.upgrade._ensure_uvx_available") + @patch("agent_starter_pack.cli.commands.upgrade.get_current_version") + def test_uvx_not_available( + self, mock_version, mock_uvx, tmp_path: pathlib.Path + ) -> None: + """Test error when uvx is not installed.""" + mock_version.return_value = "0.31.0" + mock_uvx.return_value = False + + # Create pyproject.toml with older version + pyproject = tmp_path / "pyproject.toml" + pyproject.write_text( + """ +[project] +name = "test-project" +version = "0.1.0" + +[tool.agent-starter-pack] +name = "test-project" +base_template = "adk" +asp_version = "0.30.0" +""" + ) + + runner = CliRunner() + result = runner.invoke(upgrade, [str(tmp_path)]) + output = strip_ansi(result.output) + + assert result.exit_code == 1 + assert "uvx" in output + assert "required" in output.lower() + + +class TestUpgradeDryRun: + """Test dry-run mode.""" + + @patch("agent_starter_pack.cli.commands.upgrade._ensure_uvx_available") + @patch("agent_starter_pack.cli.commands.upgrade._run_create_command") + @patch("agent_starter_pack.cli.commands.upgrade.get_current_version") + def test_dry_run_no_changes_applied( + self, mock_version, mock_create, mock_uvx, tmp_path: pathlib.Path + ) -> None: + """Test that dry-run doesn't modify files.""" + mock_version.return_value = "0.31.0" + mock_uvx.return_value = True + + # Mock template creation to create minimal template structure + def create_template(_args, output_dir, project_name, _version=None): + del _args, _version # Unused + template_dir = output_dir / project_name + template_dir.mkdir(parents=True) + (template_dir / "pyproject.toml").write_text( + """ +[project] +name = "test-project" +dependencies = [] +""" + ) + (template_dir / "README.md").write_text("# Test") + return True + + mock_create.side_effect = create_template + + # Create project with older version + pyproject = tmp_path / "pyproject.toml" + original_content = """ +[project] +name = "test-project" +version = "0.1.0" +dependencies = [] + +[tool.agent-starter-pack] +name = "test-project" +base_template = "adk" +asp_version = "0.30.0" +""" + pyproject.write_text(original_content) + + runner = CliRunner() + result = runner.invoke(upgrade, [str(tmp_path), "--dry-run"]) + output = strip_ansi(result.output) + + assert result.exit_code == 0 + assert "Dry run complete" in output + # Verify file wasn't modified + assert pyproject.read_text() == original_content + + +class TestUpgradeE2E: + """End-to-end tests for upgrade command.""" + + @patch("agent_starter_pack.cli.commands.upgrade._ensure_uvx_available") + @patch("agent_starter_pack.cli.commands.upgrade._run_create_command") + @patch("agent_starter_pack.cli.commands.upgrade.get_current_version") + def test_auto_update_unchanged_files( + self, mock_version, mock_create, mock_uvx, tmp_path: pathlib.Path + ) -> None: + """Test that unchanged files are auto-updated.""" + mock_version.return_value = "0.31.0" + mock_uvx.return_value = True + + def create_template(_args, output_dir, project_name, version=None): + del _args # Unused + template_dir = output_dir / project_name + template_dir.mkdir(parents=True) + (template_dir / "pyproject.toml").write_text( + '[project]\nname = "test"\ndependencies = []' + ) + if version == "0.30.0": + # Old template + (template_dir / "Makefile").write_text("# Old Makefile") + else: + # New template + (template_dir / "Makefile").write_text("# New Makefile with updates") + return True + + mock_create.side_effect = create_template + + # Create project with file matching old template (user didn't modify) + pyproject = tmp_path / "pyproject.toml" + pyproject.write_text( + '[project]\nname = "test"\ndependencies = []\n\n' + '[tool.agent-starter-pack]\nname = "test"\n' + 'base_template = "adk"\nasp_version = "0.30.0"' + ) + makefile = tmp_path / "Makefile" + makefile.write_text("# Old Makefile") # Same as old template + + runner = CliRunner() + result = runner.invoke(upgrade, [str(tmp_path), "--auto-approve"]) + output = strip_ansi(result.output) + + assert result.exit_code == 0 + assert "Upgrade complete" in output + # Verify file was updated + assert "New Makefile with updates" in makefile.read_text() + + @patch("agent_starter_pack.cli.commands.upgrade._ensure_uvx_available") + @patch("agent_starter_pack.cli.commands.upgrade._run_create_command") + @patch("agent_starter_pack.cli.commands.upgrade.get_current_version") + def test_preserve_user_modified_files_when_asp_unchanged( + self, mock_version, mock_create, mock_uvx, tmp_path: pathlib.Path + ) -> None: + """Test that user-modified files are preserved when ASP didn't change them.""" + mock_version.return_value = "0.31.0" + mock_uvx.return_value = True + + def create_template(_args, output_dir, project_name, _version=None): + del _args, _version # Unused + template_dir = output_dir / project_name + template_dir.mkdir(parents=True) + (template_dir / "pyproject.toml").write_text( + '[project]\nname = "test"\ndependencies = []' + ) + # Same content in old and new template + (template_dir / "Makefile").write_text("# Template Makefile") + return True + + mock_create.side_effect = create_template + + # Create project with user-modified file + pyproject = tmp_path / "pyproject.toml" + pyproject.write_text( + '[project]\nname = "test"\ndependencies = []\n\n' + '[tool.agent-starter-pack]\nname = "test"\n' + 'base_template = "adk"\nasp_version = "0.30.0"' + ) + makefile = tmp_path / "Makefile" + makefile.write_text("# My custom Makefile") # User modified + + runner = CliRunner() + result = runner.invoke(upgrade, [str(tmp_path), "--auto-approve"]) + output = strip_ansi(result.output) + + assert result.exit_code == 0 + # Verify user's file was preserved + assert "My custom Makefile" in makefile.read_text() + assert "Preserving" in output or "preserve" in output.lower() + + @patch("agent_starter_pack.cli.commands.upgrade._ensure_uvx_available") + @patch("agent_starter_pack.cli.commands.upgrade._run_create_command") + @patch("agent_starter_pack.cli.commands.upgrade.get_current_version") + def test_detects_conflict_when_both_changed( + self, mock_version, mock_create, mock_uvx, tmp_path: pathlib.Path + ) -> None: + """Test that conflicts are detected when both user and ASP changed a file.""" + mock_version.return_value = "0.31.0" + mock_uvx.return_value = True + + def create_template(_args, output_dir, project_name, version=None): + del _args # Unused + template_dir = output_dir / project_name + template_dir.mkdir(parents=True) + (template_dir / "pyproject.toml").write_text( + '[project]\nname = "test"\ndependencies = []' + ) + if version == "0.30.0": + (template_dir / "Makefile").write_text("# Old template") + else: + (template_dir / "Makefile").write_text("# New template") + return True + + mock_create.side_effect = create_template + + # Create project with user-modified file (different from both templates) + pyproject = tmp_path / "pyproject.toml" + pyproject.write_text( + '[project]\nname = "test"\ndependencies = []\n\n' + '[tool.agent-starter-pack]\nname = "test"\n' + 'base_template = "adk"\nasp_version = "0.30.0"' + ) + makefile = tmp_path / "Makefile" + makefile.write_text("# User modified") # Different from both templates + + runner = CliRunner() + result = runner.invoke(upgrade, [str(tmp_path), "--auto-approve"]) + output = strip_ansi(result.output) + + assert result.exit_code == 0 + assert "Conflict" in output + # With auto-approve, user's version is kept + assert "User modified" in makefile.read_text() + + @patch("agent_starter_pack.cli.commands.upgrade._ensure_uvx_available") + @patch("agent_starter_pack.cli.commands.upgrade._run_create_command") + @patch("agent_starter_pack.cli.commands.upgrade.get_current_version") + def test_skips_agent_code( + self, mock_version, mock_create, mock_uvx, tmp_path: pathlib.Path + ) -> None: + """Test that agent code files are never modified.""" + mock_version.return_value = "0.31.0" + mock_uvx.return_value = True + + def create_template(_args, output_dir, project_name, _version=None): + del _args, _version # Unused + template_dir = output_dir / project_name + template_dir.mkdir(parents=True) + (template_dir / "pyproject.toml").write_text( + '[project]\nname = "test"\ndependencies = []' + ) + (template_dir / "app").mkdir() + (template_dir / "app/agent.py").write_text("# Template agent") + return True + + mock_create.side_effect = create_template + + # Create project with agent code + pyproject = tmp_path / "pyproject.toml" + pyproject.write_text( + '[project]\nname = "test"\ndependencies = []\n\n' + '[tool.agent-starter-pack]\nname = "test"\n' + 'base_template = "adk"\nasp_version = "0.30.0"' + ) + app_dir = tmp_path / "app" + app_dir.mkdir() + agent_file = app_dir / "agent.py" + agent_file.write_text("# My custom agent code") + + runner = CliRunner() + result = runner.invoke(upgrade, [str(tmp_path), "--auto-approve"]) + output = strip_ansi(result.output) + + assert result.exit_code == 0 + assert "Skipping" in output + # Verify agent code was NOT modified + assert "My custom agent code" in agent_file.read_text() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/cli/utils/test_generation_metadata.py b/tests/cli/utils/test_generation_metadata.py index 13db6d8a..27ce0de6 100644 --- a/tests/cli/utils/test_generation_metadata.py +++ b/tests/cli/utils/test_generation_metadata.py @@ -34,6 +34,8 @@ else: import tomli as tomllib +from agent_starter_pack.cli.utils.generation_metadata import metadata_to_cli_args + def load_asp_metadata(pyproject_path: pathlib.Path) -> dict[str, Any]: """Load agent-starter-pack metadata from pyproject.toml. @@ -585,49 +587,5 @@ def _compare_pyproject_toml( return differences -def metadata_to_cli_args(metadata: dict[str, Any]) -> list[str]: - """Convert metadata dictionary to CLI arguments. - - This function maps the pyproject.toml metadata back to CLI arguments - that could be used to recreate the project. - - Args: - metadata: Dictionary from [tool.agent-starter-pack] section - - Returns: - List of CLI arguments - """ - args: list[str] = [] - - # Required mappings from metadata - if "base_template" in metadata: - args.extend(["--agent", metadata["base_template"]]) - - if "agent_directory" in metadata and metadata["agent_directory"] != "app": - args.extend(["--agent-directory", metadata["agent_directory"]]) - - # Get create_params for the rest - create_params = metadata.get("create_params", {}) - - # Add all create_params dynamically - for key, value in create_params.items(): - # Skip None, "none", "None", False, and empty values - if ( - value is None - or value is False - or str(value).lower() == "none" - or value == "" - ): - continue - - arg_name = f"--{key.replace('_', '-')}" - if value is True: - args.append(arg_name) - else: - args.extend([arg_name, str(value)]) - - return args - - if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/cli/utils/test_upgrade_utils.py b/tests/cli/utils/test_upgrade_utils.py new file mode 100644 index 00000000..c5391947 --- /dev/null +++ b/tests/cli/utils/test_upgrade_utils.py @@ -0,0 +1,490 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for upgrade utilities.""" + +import pathlib +import tempfile + +import pytest + +from agent_starter_pack.cli.utils.upgrade import ( + FileCompareResult, + categorize_file, + collect_all_files, + group_results_by_action, + merge_pyproject_dependencies, + three_way_compare, + write_merged_dependencies, +) + + +class TestCategorizeFile: + """Tests for file categorization.""" + + def test_agent_code_patterns(self) -> None: + """Test that agent code files are correctly categorized.""" + assert categorize_file("app/agent.py") == "agent_code" + assert categorize_file("app/tools/search.py") == "agent_code" + assert categorize_file("app/prompts/main.py") == "agent_code" + + def test_config_files(self) -> None: + """Test that config files are correctly categorized.""" + assert categorize_file("deployment/vars/dev.tfvars") == "config_files" + assert categorize_file(".env") == "config_files" + + def test_dependencies(self) -> None: + """Test that pyproject.toml is categorized as dependencies.""" + assert categorize_file("pyproject.toml") == "dependencies" + + def test_scaffolding_files(self) -> None: + """Test that scaffolding files are correctly categorized.""" + assert categorize_file("deployment/terraform/main.tf") == "scaffolding" + assert categorize_file(".github/workflows/deploy.yaml") == "scaffolding" + assert categorize_file("Makefile") == "scaffolding" + assert categorize_file("tests/conftest.py") == "scaffolding" + + def test_custom_agent_directory(self) -> None: + """Test categorization with custom agent directory.""" + assert categorize_file("my_agent/agent.py", "my_agent") == "agent_code" + assert categorize_file("my_agent/tools/custom.py", "my_agent") == "agent_code" + # Default app directory should not match + assert categorize_file("app/agent.py", "my_agent") == "scaffolding" + + +class TestThreeWayCompare: + """Tests for 3-way file comparison.""" + + @pytest.fixture + def temp_dirs(self) -> tuple[pathlib.Path, pathlib.Path, pathlib.Path]: + """Create temporary directories for testing.""" + with tempfile.TemporaryDirectory() as temp: + project = pathlib.Path(temp) / "project" + old_template = pathlib.Path(temp) / "old" + new_template = pathlib.Path(temp) / "new" + + project.mkdir() + old_template.mkdir() + new_template.mkdir() + + yield project, old_template, new_template + + def test_auto_update_unchanged_by_user( + self, temp_dirs: tuple[pathlib.Path, pathlib.Path, pathlib.Path] + ) -> None: + """Test auto-update when user didn't modify file.""" + project, old_template, new_template = temp_dirs + + # Create same file in project and old template + (project / "test.txt").write_text("old content") + (old_template / "test.txt").write_text("old content") + # New template has updated content + (new_template / "test.txt").write_text("new content") + + result = three_way_compare("test.txt", project, old_template, new_template) + + assert result.action == "auto_update" + assert "didn't modify" in result.reason.lower() + + def test_preserve_asp_unchanged( + self, temp_dirs: tuple[pathlib.Path, pathlib.Path, pathlib.Path] + ) -> None: + """Test preserve when ASP didn't change file.""" + project, old_template, new_template = temp_dirs + + # User modified file + (project / "test.txt").write_text("user modified") + # Old and new template have same content + (old_template / "test.txt").write_text("template content") + (new_template / "test.txt").write_text("template content") + + result = three_way_compare("test.txt", project, old_template, new_template) + + assert result.action == "preserve" + assert "asp didn't change" in result.reason.lower() + + def test_conflict_both_changed( + self, temp_dirs: tuple[pathlib.Path, pathlib.Path, pathlib.Path] + ) -> None: + """Test conflict when both user and ASP changed file.""" + project, old_template, new_template = temp_dirs + + # All three have different content + (project / "test.txt").write_text("user content") + (old_template / "test.txt").write_text("old content") + (new_template / "test.txt").write_text("new content") + + result = three_way_compare("test.txt", project, old_template, new_template) + + assert result.action == "conflict" + assert "both" in result.reason.lower() + + def test_new_file_in_asp( + self, temp_dirs: tuple[pathlib.Path, pathlib.Path, pathlib.Path] + ) -> None: + """Test detection of new file in ASP.""" + project, old_template, new_template = temp_dirs + + # Only exists in new template + (new_template / "new_file.txt").write_text("new content") + + result = three_way_compare("new_file.txt", project, old_template, new_template) + + assert result.action == "new" + assert "new file" in result.reason.lower() + + def test_removed_file_in_asp_user_unchanged( + self, temp_dirs: tuple[pathlib.Path, pathlib.Path, pathlib.Path] + ) -> None: + """Test detection of removed file when user didn't modify.""" + project, old_template, new_template = temp_dirs + + # File exists in project and old template, not in new + (project / "old_file.txt").write_text("same content") + (old_template / "old_file.txt").write_text("same content") + + result = three_way_compare("old_file.txt", project, old_template, new_template) + + assert result.action == "removed" + + def test_removed_file_in_asp_user_modified( + self, temp_dirs: tuple[pathlib.Path, pathlib.Path, pathlib.Path] + ) -> None: + """Test conflict when removed file was modified by user.""" + project, old_template, new_template = temp_dirs + + # User modified a file that was removed in new template + (project / "old_file.txt").write_text("user modified") + (old_template / "old_file.txt").write_text("original content") + + result = three_way_compare("old_file.txt", project, old_template, new_template) + + assert result.action == "conflict" + + def test_skip_agent_code( + self, temp_dirs: tuple[pathlib.Path, pathlib.Path, pathlib.Path] + ) -> None: + """Test that agent code is always skipped.""" + project, old_template, new_template = temp_dirs + + (project / "app").mkdir() + (project / "app/agent.py").write_text("user agent") + (old_template / "app").mkdir() + (old_template / "app/agent.py").write_text("old agent") + (new_template / "app").mkdir() + (new_template / "app/agent.py").write_text("new agent") + + result = three_way_compare("app/agent.py", project, old_template, new_template) + + assert result.action == "skip" + assert result.category == "agent_code" + + def test_skip_config_files( + self, temp_dirs: tuple[pathlib.Path, pathlib.Path, pathlib.Path] + ) -> None: + """Test that config files are always skipped.""" + project, old_template, new_template = temp_dirs + + (project / ".env").write_text("SECRET=user") + (old_template / ".env").write_text("SECRET=old") + (new_template / ".env").write_text("SECRET=new") + + result = three_way_compare(".env", project, old_template, new_template) + + assert result.action == "skip" + assert result.category == "config_files" + + +class TestCollectAllFiles: + """Tests for collecting files from directories.""" + + def test_collects_from_all_dirs(self) -> None: + """Test that files are collected from all three directories.""" + with tempfile.TemporaryDirectory() as temp: + project = pathlib.Path(temp) / "project" + old_template = pathlib.Path(temp) / "old" + new_template = pathlib.Path(temp) / "new" + + project.mkdir() + old_template.mkdir() + new_template.mkdir() + + (project / "project_file.txt").write_text("content") + (old_template / "old_file.txt").write_text("content") + (new_template / "new_file.txt").write_text("content") + + files = collect_all_files(project, old_template, new_template) + + assert "project_file.txt" in files + assert "old_file.txt" in files + assert "new_file.txt" in files + + def test_excludes_patterns(self) -> None: + """Test that excluded patterns are not collected.""" + with tempfile.TemporaryDirectory() as temp: + project = pathlib.Path(temp) / "project" + project.mkdir() + + (project / ".git").mkdir() + (project / ".git/config").write_text("content") + (project / "real_file.txt").write_text("content") + + files = collect_all_files( + project, project, project, exclude_patterns=[".git/**"] + ) + + assert ".git/config" not in files + assert "real_file.txt" in files + + +class TestGroupResultsByAction: + """Tests for grouping results by action.""" + + def test_groups_correctly(self) -> None: + """Test that results are grouped by action.""" + results = [ + FileCompareResult("file1.txt", "scaffolding", "auto_update", "reason"), + FileCompareResult("file2.txt", "scaffolding", "preserve", "reason"), + FileCompareResult("file3.txt", "scaffolding", "conflict", "reason"), + FileCompareResult("file4.txt", "scaffolding", "auto_update", "reason"), + ] + + groups = group_results_by_action(results) + + assert len(groups["auto_update"]) == 2 + assert len(groups["preserve"]) == 1 + assert len(groups["conflict"]) == 1 + assert len(groups["skip"]) == 0 + + +class TestMergePyprojectDependencies: + """Tests for dependency merging.""" + + @pytest.fixture + def pyproject_dirs(self) -> tuple[pathlib.Path, pathlib.Path, pathlib.Path]: + """Create temporary directories with pyproject.toml files.""" + with tempfile.TemporaryDirectory() as temp: + current = pathlib.Path(temp) / "current" + old_template = pathlib.Path(temp) / "old" + new_template = pathlib.Path(temp) / "new" + + current.mkdir() + old_template.mkdir() + new_template.mkdir() + + yield current, old_template, new_template + + def _write_pyproject(self, path: pathlib.Path, deps: list[str]) -> None: + """Write a pyproject.toml with the given dependencies.""" + deps_str = ", ".join(f'"{d}"' for d in deps) + content = f""" +[project] +name = "test-project" +dependencies = [{deps_str}] +""" + (path / "pyproject.toml").write_text(content) + + def test_updated_dependency( + self, pyproject_dirs: tuple[pathlib.Path, pathlib.Path, pathlib.Path] + ) -> None: + """Test detection of updated dependency.""" + current, old_template, new_template = pyproject_dirs + + self._write_pyproject(current, ["google-adk>=0.2.0"]) + self._write_pyproject(old_template, ["google-adk>=0.2.0"]) + self._write_pyproject(new_template, ["google-adk>=0.3.0"]) + + result = merge_pyproject_dependencies( + current / "pyproject.toml", + old_template / "pyproject.toml", + new_template / "pyproject.toml", + ) + + updated = [c for c in result.changes if c.change_type == "updated"] + assert len(updated) == 1 + assert updated[0].name == "google-adk" + + def test_user_added_dependency_kept( + self, pyproject_dirs: tuple[pathlib.Path, pathlib.Path, pathlib.Path] + ) -> None: + """Test that user-added dependencies are preserved.""" + current, old_template, new_template = pyproject_dirs + + self._write_pyproject(current, ["google-adk>=0.2.0", "my-custom-lib>=1.0.0"]) + self._write_pyproject(old_template, ["google-adk>=0.2.0"]) + self._write_pyproject(new_template, ["google-adk>=0.3.0"]) + + result = merge_pyproject_dependencies( + current / "pyproject.toml", + old_template / "pyproject.toml", + new_template / "pyproject.toml", + ) + + kept = [c for c in result.changes if c.change_type == "kept"] + assert len(kept) == 1 + assert kept[0].name == "my-custom-lib" + + # Check merged deps contains both + merged_names = [d.split(">")[0].split("=")[0] for d in result.merged_deps] + assert "google-adk" in merged_names + assert "my-custom-lib" in merged_names + + def test_new_asp_dependency( + self, pyproject_dirs: tuple[pathlib.Path, pathlib.Path, pathlib.Path] + ) -> None: + """Test detection of new ASP dependency.""" + current, old_template, new_template = pyproject_dirs + + self._write_pyproject(current, ["google-adk>=0.2.0"]) + self._write_pyproject(old_template, ["google-adk>=0.2.0"]) + self._write_pyproject(new_template, ["google-adk>=0.3.0", "new-dep>=1.0.0"]) + + result = merge_pyproject_dependencies( + current / "pyproject.toml", + old_template / "pyproject.toml", + new_template / "pyproject.toml", + ) + + added = [c for c in result.changes if c.change_type == "added"] + assert len(added) == 1 + assert added[0].name == "new-dep" + + def test_removed_asp_dependency( + self, pyproject_dirs: tuple[pathlib.Path, pathlib.Path, pathlib.Path] + ) -> None: + """Test detection of removed ASP dependency.""" + current, old_template, new_template = pyproject_dirs + + self._write_pyproject(current, ["google-adk>=0.2.0", "old-dep>=1.0.0"]) + self._write_pyproject(old_template, ["google-adk>=0.2.0", "old-dep>=1.0.0"]) + self._write_pyproject(new_template, ["google-adk>=0.3.0"]) + + result = merge_pyproject_dependencies( + current / "pyproject.toml", + old_template / "pyproject.toml", + new_template / "pyproject.toml", + ) + + removed = [c for c in result.changes if c.change_type == "removed"] + assert len(removed) == 1 + assert removed[0].name == "old-dep" + + +class TestWriteMergedDependencies: + """Tests for writing merged dependencies back to pyproject.toml.""" + + def test_writes_dependencies_correctly(self, tmp_path: pathlib.Path) -> None: + """Test that dependencies are written correctly to pyproject.toml.""" + pyproject = tmp_path / "pyproject.toml" + pyproject.write_text( + """ +[project] +name = "test-project" +dependencies = [ + "old-dep>=1.0.0", +] +""" + ) + + merged_deps = ["google-adk>=0.3.0", "my-custom-lib>=1.0.0"] + result = write_merged_dependencies(pyproject, merged_deps) + + assert result is True + content = pyproject.read_text() + assert '"google-adk>=0.3.0"' in content + assert '"my-custom-lib>=1.0.0"' in content + assert "old-dep" not in content + + def test_handles_empty_dependencies(self, tmp_path: pathlib.Path) -> None: + """Test writing empty dependencies list.""" + pyproject = tmp_path / "pyproject.toml" + pyproject.write_text( + """ +[project] +name = "test-project" +dependencies = ["some-dep>=1.0.0"] +""" + ) + + result = write_merged_dependencies(pyproject, []) + + assert result is True + content = pyproject.read_text() + assert "dependencies = []" in content + + def test_returns_false_for_missing_file(self, tmp_path: pathlib.Path) -> None: + """Test that missing file returns False.""" + pyproject = tmp_path / "nonexistent.toml" + + result = write_merged_dependencies(pyproject, ["dep>=1.0.0"]) + + assert result is False + + +class TestCollectAllFilesDeepExclusion: + """Tests for ** glob pattern exclusions in collect_all_files.""" + + def test_excludes_deeply_nested_git_files(self) -> None: + """Test that .git/** excludes deeply nested files.""" + with tempfile.TemporaryDirectory() as temp: + project = pathlib.Path(temp) / "project" + project.mkdir() + + # Create deeply nested .git structure + git_dir = project / ".git" / "objects" / "pack" + git_dir.mkdir(parents=True) + (git_dir / "pack-123.idx").write_text("content") + (project / ".git" / "config").write_text("content") + (project / ".git" / "HEAD").write_text("ref: refs/heads/main") + + # Create a real file + (project / "README.md").write_text("content") + + files = collect_all_files( + project, project, project, exclude_patterns=[".git/**"] + ) + + assert ".git/config" not in files + assert ".git/HEAD" not in files + assert ".git/objects/pack/pack-123.idx" not in files + assert "README.md" in files + + def test_excludes_venv_deeply_nested(self) -> None: + """Test that .venv/** excludes deeply nested virtual env files.""" + with tempfile.TemporaryDirectory() as temp: + project = pathlib.Path(temp) / "project" + project.mkdir() + + # Create nested venv structure + venv_lib = project / ".venv" / "lib" / "python3.12" / "site-packages" + venv_lib.mkdir(parents=True) + (venv_lib / "some_package" / "__init__.py").parent.mkdir() + (venv_lib / "some_package" / "__init__.py").write_text("content") + + # Create a real file + (project / "main.py").write_text("content") + + files = collect_all_files( + project, project, project, exclude_patterns=[".venv/**"] + ) + + assert ( + ".venv/lib/python3.12/site-packages/some_package/__init__.py" + not in files + ) + assert "main.py" in files + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From d693caee2b134a4da9a6006792e24ccf8f6c908c Mon Sep 17 00:00:00 2001 From: Stephen Allen Date: Tue, 20 Jan 2026 14:55:59 -0600 Subject: [PATCH 2/2] fix: address gca comments --- agent_starter_pack/cli/commands/upgrade.py | 12 ++++++++---- agent_starter_pack/cli/utils/upgrade.py | 3 ++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/agent_starter_pack/cli/commands/upgrade.py b/agent_starter_pack/cli/commands/upgrade.py index a3968809..1fa63605 100644 --- a/agent_starter_pack/cli/commands/upgrade.py +++ b/agent_starter_pack/cli/commands/upgrade.py @@ -18,6 +18,7 @@ import logging import pathlib import re +import shlex import shutil import subprocess import tempfile @@ -41,6 +42,9 @@ console = Console() +# Maximum characters to display when showing diffs +MAX_DIFF_DISPLAY_CHARS = 2000 + def _ensure_uvx_available() -> bool: """Check if uvx is available.""" @@ -79,7 +83,7 @@ def _run_create_command( cmd.extend(["--auto-approve", "--skip-deps", "--skip-checks"]) cmd.extend(args) - logging.debug(f"Running command: {' '.join(cmd)}") + logging.debug(f"Running command: {shlex.join(cmd)}") try: result = subprocess.run( @@ -238,9 +242,9 @@ def _handle_conflict( console.print() if diff_output: - # Limit output to ~2000 characters - if len(diff_output) > 2000: - console.print(diff_output[:2000]) + # Limit output to a reasonable length + if len(diff_output) > MAX_DIFF_DISPLAY_CHARS: + console.print(diff_output[:MAX_DIFF_DISPLAY_CHARS]) console.print("[dim]... (truncated)[/dim]") else: console.print(diff_output) diff --git a/agent_starter_pack/cli/utils/upgrade.py b/agent_starter_pack/cli/utils/upgrade.py index 63d1742d..75f4c675 100644 --- a/agent_starter_pack/cli/utils/upgrade.py +++ b/agent_starter_pack/cli/utils/upgrade.py @@ -128,7 +128,8 @@ def _file_hash(file_path: pathlib.Path) -> str | None: try: content = file_path.read_bytes() return hashlib.sha256(content).hexdigest() - except Exception: + except Exception as e: + logging.warning(f"Could not hash file {file_path}: {e}") return None