From 5acd3a87868f1c5fbaa4676619819c3ce2a6e9a9 Mon Sep 17 00:00:00 2001 From: d-padmanabhan <88682350+d-padmanabhan@users.noreply.github.com> Date: Sun, 4 Jan 2026 23:06:50 -0500 Subject: [PATCH 1/4] fix(core): fix critical bugs and code quality issues - Fix security group egress loop incorrectly nested inside ingress loop in vpc.py - Add missing imports (Table, asdict, get_theme_dir) causing runtime errors - Replace bare except clauses with specific exception types - Remove unused imports and fix f-strings without placeholders - Fix ambiguous variable names (l -> lis) in test files - Mark intentionally unused variables with underscore prefix - Mark test fixture AWS example IDs as allowlisted secrets - Make scripts with shebangs executable Fixes security group data processing bug that caused egress rules to be processed N times (once per ingress rule) instead of once. --- docs/command-hierarchy-split.md | 2 +- scripts/automated_issue_resolver.py | 83 ++-- scripts/clean-output.py | 65 +-- scripts/fetch_issues.py | 40 +- scripts/issue_investigator.py | 372 +++++++++++------- scripts/run_issue_tests.py | 36 +- scripts/s2svpn | 212 ++++++---- scripts/shell_runner.py | 123 +++--- src/aws_network_tools/cli/runner.py | 71 ++-- src/aws_network_tools/config/__init__.py | 66 ++-- src/aws_network_tools/core/base.py | 8 +- src/aws_network_tools/core/cache_db.py | 146 ++++--- src/aws_network_tools/core/validators.py | 84 ++-- src/aws_network_tools/modules/anfw.py | 14 +- src/aws_network_tools/modules/cloudwan.py | 2 + src/aws_network_tools/modules/elb.py | 62 +-- src/aws_network_tools/modules/vpc.py | 85 ++-- src/aws_network_tools/shell/base.py | 118 ++++-- src/aws_network_tools/shell/graph.py | 30 +- .../shell/handlers/cloudwan.py | 22 +- src/aws_network_tools/shell/handlers/ec2.py | 24 +- .../shell/handlers/firewall.py | 97 +++-- src/aws_network_tools/shell/handlers/root.py | 203 ++++++---- src/aws_network_tools/shell/handlers/tgw.py | 4 +- .../shell/handlers/utilities.py | 3 +- src/aws_network_tools/shell/handlers/vpc.py | 4 +- src/aws_network_tools/shell/handlers/vpn.py | 9 +- src/aws_network_tools/shell/main.py | 14 +- src/aws_network_tools/themes/__init__.py | 144 +++---- tests/agent_test_runner.py | 137 ++++--- tests/fixtures/client_vpn.py | 4 +- tests/fixtures/cloudwan_connect.py | 4 +- tests/fixtures/command_fixtures.py | 321 ++++++++++----- tests/fixtures/ec2.py | 25 +- tests/fixtures/elb.py | 12 +- tests/fixtures/firewall.py | 20 +- tests/fixtures/fixture_generator.py | 6 +- tests/fixtures/gateways.py | 20 +- tests/fixtures/global_accelerator.py | 12 +- tests/fixtures/global_network.py | 5 +- tests/fixtures/peering.py | 32 +- tests/fixtures/route53_resolver.py | 18 +- tests/fixtures/tgw.py | 14 +- tests/fixtures/vpc.py | 61 ++- tests/fixtures/vpc_endpoints.py | 92 +++-- tests/generate_report.py | 30 +- tests/integration/test_github_issues.py | 56 +-- tests/integration/test_issue_9_10_simple.py | 116 +++--- tests/integration/test_workflows.py | 40 +- tests/interactive_routing_cache_test.py | 17 +- tests/test_cloudwan_branch.py | 143 +++++-- tests/test_cloudwan_handlers.py | 90 +++-- tests/test_cloudwan_issue3.py | 25 +- tests/test_cloudwan_issues.py | 8 +- tests/test_command_graph/base_context_test.py | 30 +- tests/test_command_graph/conftest.py | 226 +++++++---- tests/test_command_graph/test_base_context.py | 97 +++-- .../test_cloudwan_branch.py | 11 +- .../test_context_commands.py | 9 +- .../test_command_graph/test_data_generator.py | 120 +++--- .../test_top_level_commands.py | 13 +- tests/test_ec2_context.py | 57 +-- tests/test_elb_handler.py | 46 ++- tests/test_elb_module.py | 75 +++- tests/test_graph_commands.py | 368 +++++++++++------ tests/test_issue_2_show_detail.py | 15 +- tests/test_issue_5_tgw_rt_details.py | 86 ++-- tests/test_issue_8_vpc_set.py | 50 ++- tests/test_policy_change_events.py | 10 +- tests/test_refresh_command.py | 9 +- tests/test_shell_hierarchy.py | 30 +- tests/test_show_regions.py | 12 +- tests/test_tgw_issues.py | 13 +- tests/test_utils/context_state_manager.py | 8 +- tests/test_utils/data_format_adapter.py | 153 +++---- .../test_utils/test_context_state_manager.py | 43 +- tests/test_utils/test_data_format_adapter.py | 118 +++--- tests/test_validators.py | 19 +- tests/test_vpn_tunnels.py | 4 +- tests/unit/test_ec2_eni_filtering.py | 25 +- tests/unit/test_elb_commands.py | 33 +- themes/catppuccin-latte.json | 28 +- themes/catppuccin-macchiato.json | 28 +- themes/catppuccin-mocha-vibrant.json | 26 +- themes/catppuccin-mocha.json | 28 +- themes/dracula.json | 28 +- validate_issues.sh | 2 +- 87 files changed, 3244 insertions(+), 2027 deletions(-) mode change 100644 => 100755 scripts/automated_issue_resolver.py mode change 100644 => 100755 scripts/fetch_issues.py mode change 100644 => 100755 scripts/issue_investigator.py mode change 100644 => 100755 scripts/run_issue_tests.py mode change 100644 => 100755 scripts/shell_runner.py mode change 100644 => 100755 src/aws_network_tools/cli/runner.py mode change 100644 => 100755 tests/agent_test_runner.py diff --git a/docs/command-hierarchy-split.md b/docs/command-hierarchy-split.md index 0bc3c9a..7a11955 100644 --- a/docs/command-hierarchy-split.md +++ b/docs/command-hierarchy-split.md @@ -167,7 +167,7 @@ graph LR firewall_context --> firewall_show_networking firewall_set_rule_group{"set rule-group → rule-group"}:::set firewall_context --> firewall_set_rule_group - + rule_group_context["rule-group context"]:::context firewall_set_rule_group --> rule_group_context rule_group_show_rule_group["show rule-group"]:::show diff --git a/scripts/automated_issue_resolver.py b/scripts/automated_issue_resolver.py old mode 100644 new mode 100755 index 8bdedc4..390fddd --- a/scripts/automated_issue_resolver.py +++ b/scripts/automated_issue_resolver.py @@ -26,12 +26,13 @@ import sys from pathlib import Path from dataclasses import dataclass -from typing import List, Dict, Any +from typing import List +from dataclasses import asdict try: from rich.console import Console - from rich.progress import Progress, SpinnerColumn, TextColumn from rich.panel import Panel + from rich.table import Table except ImportError: print("Rich required. Install: pip install rich") sys.exit(1) @@ -42,6 +43,7 @@ @dataclass class IssueResolutionResult: """Result of attempting to resolve an issue.""" + issue_number: int issue_title: str agent_prompt_generated: bool @@ -80,7 +82,7 @@ def resolve_issue(self, issue_number: int) -> IssueResolutionResult: fix_attempted=False, tests_created=False, tests_passed=False, - pr_created=False + pr_created=False, ) try: @@ -90,12 +92,16 @@ def resolve_issue(self, issue_number: int) -> IssueResolutionResult: result.agent_prompt_generated = True if self.dry_run: - console.print(f"[dim]Dry run: Would execute agent prompt[/]") - console.print(Panel(agent_prompt[:500] + "...", title="Agent Prompt Preview")) + console.print("[dim]Dry run: Would execute agent prompt[/]") + console.print( + Panel(agent_prompt[:500] + "...", title="Agent Prompt Preview") + ) return result # Step 2: Execute agent prompt to implement fix - console.print("[yellow]Step 2: Executing agent prompt to implement fix...[/]") + console.print( + "[yellow]Step 2: Executing agent prompt to implement fix...[/]" + ) fix_applied = self._execute_agent_prompt(agent_prompt, issue_number) result.fix_attempted = True @@ -130,11 +136,15 @@ def resolve_issue(self, issue_number: int) -> IssueResolutionResult: def _generate_agent_prompt(self, issue_number: int) -> str: """Generate agent prompt using issue_investigator.py.""" cmd = [ - "uv", "run", "python", + "uv", + "run", + "python", str(self.scripts_dir / "issue_investigator.py"), - "--issue", str(issue_number), + "--issue", + str(issue_number), "--agent-prompt", - "--format", "xml" + "--format", + "xml", ] result = subprocess.run(cmd, capture_output=True, text=True, cwd=self.repo_root) @@ -156,7 +166,9 @@ def _execute_agent_prompt(self, agent_prompt: str, issue_number: int) -> bool: prompt_file.write_text(agent_prompt) console.print(f"[green]✓[/] Agent prompt saved to: {prompt_file}") - console.print("[dim]Execute this prompt with your AI agent to implement the fix[/]") + console.print( + "[dim]Execute this prompt with your AI agent to implement the fix[/]" + ) # TODO: Integrate with Claude Code API or other agent system # For now, this is a manual step @@ -166,27 +178,40 @@ def _create_issue_test(self, issue_number: int, agent_prompt: str) -> bool: """Create a test that validates the issue is fixed.""" # Extract workflow from agent prompt # Create YAML workflow file - workflow_file = self.repo_root / f"tests/integration/workflows/issue_{issue_number}_automated.yaml" + _workflow_file = ( + self.repo_root + / f"tests/integration/workflows/issue_{issue_number}_automated.yaml" + ) # TODO: Parse agent prompt to extract command sequence # For now, use issue_tests.yaml if exists + _ = _workflow_file # Placeholder for future implementation return False - def _run_tests_iteratively(self, issue_number: int, max_iterations: int = 3) -> bool: + def _run_tests_iteratively( + self, issue_number: int, max_iterations: int = 3 + ) -> bool: """Run tests iteratively, applying fixes on failures.""" for iteration in range(max_iterations): console.print(f"[cyan]Test iteration {iteration + 1}/{max_iterations}[/]") # Run test for this specific issue cmd = [ - ".venv/bin/python", "-m", "pytest", - f"tests/integration/test_workflows.py", - "-k", f"issue_{issue_number}", - "-v", "--tb=short", "--override-ini=addopts=" + ".venv/bin/python", + "-m", + "pytest", + "tests/integration/test_workflows.py", + "-k", + f"issue_{issue_number}", + "-v", + "--tb=short", + "--override-ini=addopts=", ] - result = subprocess.run(cmd, capture_output=True, text=True, cwd=self.repo_root) + result = subprocess.run( + cmd, capture_output=True, text=True, cwd=self.repo_root + ) if result.returncode == 0: console.print(f"[green]✓ Tests passed on iteration {iteration + 1}[/]") @@ -201,10 +226,11 @@ def _run_tests_iteratively(self, issue_number: int, max_iterations: int = 3) -> def _create_pull_request(self, issue_number: int) -> bool: """Create a pull request with the fix.""" - branch_name = f"fix/issue-{issue_number}-automated" + _branch_name = f"fix/issue-{issue_number}-automated" # TODO: Use gh CLI to create PR # gh pr create --title "Fix Issue #{issue_number}" --body "Automated fix" + _ = _branch_name # Placeholder for future implementation console.print("[dim]PR creation would happen here[/]") return False @@ -212,8 +238,17 @@ def _create_pull_request(self, issue_number: int) -> bool: def resolve_all_open_issues(self) -> List[IssueResolutionResult]: """Attempt to resolve all open issues.""" # Fetch open issues - cmd = ["gh", "issue", "list", "--repo", "NetDevAutomate/aws_network_shell", - "--state", "open", "--json", "number,title"] + cmd = [ + "gh", + "issue", + "list", + "--repo", + "NetDevAutomate/aws_network_shell", + "--state", + "open", + "--json", + "number,title", + ] result = subprocess.run(cmd, capture_output=True, text=True) if result.returncode != 0: @@ -234,11 +269,13 @@ def main(): parser = argparse.ArgumentParser( description="Automated issue resolution using agent prompts", formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=__doc__ + epilog=__doc__, ) parser.add_argument("--issue", type=int, help="Specific issue number to resolve") parser.add_argument("--all", action="store_true", help="Resolve all open issues") - parser.add_argument("--dry-run", action="store_true", help="Preview actions without executing") + parser.add_argument( + "--dry-run", action="store_true", help="Preview actions without executing" + ) parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output") parser.add_argument("--output", "-o", help="Save results to JSON file") @@ -271,7 +308,7 @@ def main(): "✓" if r.agent_prompt_generated else "✗", "✓" if r.fix_attempted else "✗", "✓" if r.tests_passed else "✗", - "✓" if r.pr_created else "✗" + "✓" if r.pr_created else "✗", ) console.print(table) diff --git a/scripts/clean-output.py b/scripts/clean-output.py index 98deec9..3b8ef7f 100755 --- a/scripts/clean-output.py +++ b/scripts/clean-output.py @@ -30,29 +30,45 @@ def clean_output(text: str, compact: bool = False) -> str: Returns: Cleaned text suitable for markdown code blocks """ - lines = text.split('\n') + lines = text.split("\n") cleaned = [] for line in lines: # Remove ANSI escape sequences (colors, cursor movements) - line = re.sub(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])', '', line) + line = re.sub(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])", "", line) # Replace box drawing characters with simple ASCII box_chars = { - '┏': '+', '┓': '+', '┗': '+', '┛': '+', - '┣': '+', '┫': '+', '┳': '+', '┻': '+', '╋': '+', - '├': '+', '┤': '+', '┬': '+', '┴': '+', - '┃': '|', '│': '|', - '━': '-', '─': '-', - '┡': '+', '┩': '+', - '╭': '+', '╮': '+', '╰': '+', '╯': '+', - '┼': '+', + "┏": "+", + "┓": "+", + "┗": "+", + "┛": "+", + "┣": "+", + "┫": "+", + "┳": "+", + "┻": "+", + "╋": "+", + "├": "+", + "┤": "+", + "┬": "+", + "┴": "+", + "┃": "|", + "│": "|", + "━": "-", + "─": "-", + "┡": "+", + "┩": "+", + "╭": "+", + "╮": "+", + "╰": "+", + "╯": "+", + "┼": "+", } for char, replacement in box_chars.items(): line = line.replace(char, replacement) # Normalize multiple spaces (but preserve indentation) - line = re.sub(r'(?<=\S) +', ' ', line) + line = re.sub(r"(?<=\S) +", " ", line) # Remove trailing whitespace line = line.rstrip() @@ -60,23 +76,23 @@ def clean_output(text: str, compact: bool = False) -> str: cleaned.append(line) # Join lines - result = '\n'.join(cleaned) + result = "\n".join(cleaned) if compact: # Remove multiple consecutive blank lines, keep max 1 - result = re.sub(r'\n\n\n+', '\n\n', result) + result = re.sub(r"\n\n\n+", "\n\n", result) # Remove leading/trailing blank lines result = result.strip() else: # Just remove excessive blank lines (keep max 2) - result = re.sub(r'\n\n\n+', '\n\n', result) + result = re.sub(r"\n\n\n+", "\n\n", result) return result def main(): parser = argparse.ArgumentParser( - description='Clean terminal output for git commit messages', + description="Clean terminal output for git commit messages", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: @@ -91,19 +107,20 @@ def main(): # From file python scripts/clean-output.py < terminal-output.txt > cleaned.txt - """ + """, ) parser.add_argument( - '--compact', '-c', - action='store_true', - help='Remove all blank lines for compact output' + "--compact", + "-c", + action="store_true", + help="Remove all blank lines for compact output", ) parser.add_argument( - 'input_file', - nargs='?', - type=argparse.FileType('r'), + "input_file", + nargs="?", + type=argparse.FileType("r"), default=sys.stdin, - help='Input file (default: stdin)' + help="Input file (default: stdin)", ) args = parser.parse_args() @@ -116,5 +133,5 @@ def main(): print(cleaned) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/scripts/fetch_issues.py b/scripts/fetch_issues.py old mode 100644 new mode 100755 index b509c3a..f60ce82 --- a/scripts/fetch_issues.py +++ b/scripts/fetch_issues.py @@ -48,45 +48,52 @@ def extract_commands(body: str) -> list[str]: """Extract shell commands from issue body.""" commands = [] lines = body.split("\n") - + for line in lines: line = line.strip() # Match lines that look like shell commands # e.g., "aws-net> show vpcs" or "aws-net/tr:xxx> show routes" if "aws-net" in line and ">" in line: # Extract command after the prompt - match = re.search(r'[>$]\s*(.+)$', line) + match = re.search(r"[>$]\s*(.+)$", line) if match: cmd = match.group(1).strip() - if cmd and not cmd.startswith(("EXCEPTION", "Error", "┏", "┃", "┡", "│", "└")): + if cmd and not cmd.startswith( + ("EXCEPTION", "Error", "┏", "┃", "┡", "│", "└") + ): commands.append(cmd) - + return commands def format_yaml(issues: list[dict]) -> str: """Format issues as YAML test definitions.""" - lines = ["# Auto-generated from GitHub issues", "# Review and adjust commands as needed", "", "issues:"] - + lines = [ + "# Auto-generated from GitHub issues", + "# Review and adjust commands as needed", + "", + "issues:", + ] + for issue in issues: num = issue["number"] title = issue["title"] body = issue.get("body", "") or "" commands = extract_commands(body) - + lines.append(f" {num}:") lines.append(f' title: "{title}"') lines.append(" commands:") - + if commands: for cmd in commands: lines.append(f" - {cmd}") else: lines.append(" # No commands extracted - add manually") lines.append(" - show global-networks") - + lines.append("") - + return "\n".join(lines) @@ -98,7 +105,7 @@ def format_commands(issues: list[dict]) -> str: title = issue["title"] body = issue.get("body", "") or "" commands = extract_commands(body) - + lines.append(f"# Issue #{num}: {title}") if commands: cmd_args = " ".join(f'"{c}"' for c in commands) @@ -106,15 +113,20 @@ def format_commands(issues: list[dict]) -> str: else: lines.append("# No commands extracted") lines.append("") - + return "\n".join(lines) def main(): parser = argparse.ArgumentParser(description="Fetch GitHub issues") parser.add_argument("--issue", "-i", type=int, help="Fetch specific issue") - parser.add_argument("--format", "-f", choices=["yaml", "commands", "json"], - default="yaml", help="Output format") + parser.add_argument( + "--format", + "-f", + choices=["yaml", "commands", "json"], + default="yaml", + help="Output format", + ) args = parser.parse_args() try: diff --git a/scripts/issue_investigator.py b/scripts/issue_investigator.py old mode 100644 new mode 100755 index 8600aa6..a663103 --- a/scripts/issue_investigator.py +++ b/scripts/issue_investigator.py @@ -47,7 +47,7 @@ from rich.console import Console from rich.table import Table from rich.panel import Panel - from rich.prompt import Prompt, IntPrompt + from rich.prompt import Prompt from rich.syntax import Syntax from rich.markdown import Markdown except ImportError: @@ -65,6 +65,7 @@ @dataclass class CommandResult: """Result of a single command execution.""" + command: str output: str duration_seconds: float = 0.0 @@ -76,6 +77,7 @@ class CommandResult: @dataclass class IssueInvestigation: """Complete investigation result for an issue.""" + issue_number: int issue_title: str issue_url: str @@ -97,7 +99,7 @@ def to_dict(self) -> dict: def to_agent_prompt(self, fmt: str = "xml") -> str: """Generate a structured prompt for an AI agent to work on this issue. - + Args: fmt: Output format - "xml" (recommended for agents) or "markdown" """ @@ -109,87 +111,97 @@ def _to_xml_prompt(self) -> str: """Generate XML-structured prompt (more efficient for AI agents).""" lines = [] lines.append("") - - lines.append(f" ") + + lines.append(f' ') lines.append(f" {self.issue_title}") lines.append(f" {self.issue_url}") lines.append(f" {self.status}") lines.append(f" {self.reproduced}") lines.append(" ") - + lines.append(" ") lines.append(self.issue_body or "No description provided") lines.append(" ") - + if self.commands_run: lines.append(" ") for i, result in enumerate(self.commands_run, 1): - lines.append(f" ") + lines.append(f' ') lines.append(f" {result.command}") if result.has_error: - lines.append(f" {result.error_message}") - output = result.output[:1500] + ('...' if len(result.output) > 1500 else '') + lines.append( + f' {result.error_message}' + ) + output = result.output[:1500] + ( + "..." if len(result.output) > 1500 else "" + ) lines.append(f" ") - lines.append(f" {result.duration_seconds:.2f}") + lines.append( + f" {result.duration_seconds:.2f}" + ) lines.append(" ") lines.append(" ") - + if self.actual_errors: lines.append(" ") for error in self.actual_errors: lines.append(f" {error}") lines.append(" ") - + if self.debug_info: lines.append(" ") for key, value in self.debug_info.items(): lines.append(f" <{key}>{value}") lines.append(" ") - + lines.append(" ") if self.reproduced: - lines.append(""" Fix the confirmed issue + lines.append( + """ Fix the confirmed issue Analyze the error messages and stack traces in commands_executed Search the codebase for relevant code handling these commands Identify the root cause of the issue Propose and implement a fix Add a test case to prevent regression - """) + """ + ) else: - lines.append(""" Investigate why issue could not be reproduced + lines.append( + """ Investigate why issue could not be reproduced Review the commands and output above Check if the issue might be environment-specific Look for any partial failures or warnings Determine if the issue was already fixed or needs different reproduction steps Update the issue status accordingly - """) + """ + ) lines.append(" ") - + if self.recommendations: lines.append(" ") for rec in self.recommendations: lines.append(f" {rec}") lines.append(" ") - + lines.append("") return "\n".join(lines) def _to_markdown_prompt(self) -> str: """Generate markdown prompt (better for human readability).""" sections = [] - + sections.append(f"# GitHub Issue #{self.issue_number}: {self.issue_title}") sections.append(f"\n**URL:** {self.issue_url}") sections.append(f"**Status:** {self.status.upper()}") sections.append(f"**Reproduced:** {'Yes' if self.reproduced else 'No'}") - + sections.append("\n## Original Issue Description") sections.append(self.issue_body or "_No description provided_") - + sections.append("\n## Investigation Results") - + if self.commands_run: sections.append("\n### Commands Executed") for i, result in enumerate(self.commands_run, 1): @@ -197,43 +209,49 @@ def _to_markdown_prompt(self) -> str: if result.has_error: sections.append(f"- Error Type: `{result.error_type}`") sections.append(f"- Error Message: `{result.error_message}`") - sections.append(f"```\n{result.output[:1000]}{'...' if len(result.output) > 1000 else ''}\n```") - + sections.append( + f"```\n{result.output[:1000]}{'...' if len(result.output) > 1000 else ''}\n```" + ) + if self.actual_errors: sections.append("\n### Errors Detected") for error in self.actual_errors: sections.append(f"- `{error}`") - + if self.debug_info: sections.append("\n### Debug Information") for key, value in self.debug_info.items(): sections.append(f"- **{key}:** {value}") - + sections.append("\n## Task for Agent") if self.reproduced: - sections.append(""" + sections.append( + """ The issue has been **confirmed reproducible**. Please: 1. Analyze the error messages and stack traces above 2. Search the codebase for relevant code handling these commands 3. Identify the root cause of the issue 4. Propose and implement a fix 5. Add a test case to prevent regression -""") +""" + ) else: - sections.append(""" + sections.append( + """ The issue could **not be reproduced**. Please: 1. Review the commands and output above 2. Check if the issue might be environment-specific 3. Look for any partial failures or warnings 4. Determine if the issue was already fixed or needs different reproduction steps 5. Update the issue status accordingly -""") - +""" + ) + if self.recommendations: sections.append("\n### Recommendations") for rec in self.recommendations: sections.append(f"- {rec}") - + return "\n".join(sections) @@ -266,21 +284,23 @@ def extract_commands_from_body(body: str) -> list[str]: """Extract shell commands from issue body.""" if not body: return [] - + commands = [] lines = body.split("\n") - + for line in lines: line = line.strip() # Match lines that look like shell commands # e.g., "aws-net> show vpcs" or "aws-net/tr:xxx> show routes" if "aws-net" in line and ">" in line: - match = re.search(r'[>$]\s*(.+)$', line) + match = re.search(r"[>$]\s*(.+)$", line) if match: cmd = match.group(1).strip() - if cmd and not cmd.startswith(("EXCEPTION", "Error", "┏", "┃", "┡", "│", "└", "No ", "Use ")): + if cmd and not cmd.startswith( + ("EXCEPTION", "Error", "┏", "┃", "┡", "│", "└", "No ", "Use ") + ): commands.append(cmd) - + return commands @@ -288,19 +308,24 @@ def extract_errors_from_body(body: str) -> list[str]: """Extract expected error patterns from issue body.""" if not body: return [] - + errors = [] - + # Look for EXCEPTION patterns - exception_matches = re.findall(r"EXCEPTION of type '(\w+)' occurred with message: (.+?)(?:\n|$)", body) + exception_matches = re.findall( + r"EXCEPTION of type '(\w+)' occurred with message: (.+?)(?:\n|$)", body + ) for exc_type, message in exception_matches: errors.append(f"{exc_type}: {message.strip()}") - + # Look for KeyError, TypeError, etc. - error_patterns = re.findall(r"(KeyError|TypeError|ValueError|AttributeError)[:\s]+['\"]?([^'\"]+)['\"]?", body) + error_patterns = re.findall( + r"(KeyError|TypeError|ValueError|AttributeError)[:\s]+['\"]?([^'\"]+)['\"]?", + body, + ) for error_type, detail in error_patterns: errors.append(f"{error_type}: {detail.strip()}") - + # Look for common error messages if "No data returned" in body or "No policy data" in body: errors.append("No data returned") @@ -308,92 +333,105 @@ def extract_errors_from_body(body: str) -> list[str]: match = re.search(r"Run '(show [^']+)' first", body) if match: errors.append(f"Context error: {match.group(0)}") - + return list(set(errors)) def detect_errors_in_output(output: str) -> tuple[bool, str | None, str | None]: """Detect if output contains an error.""" output_lower = output.lower() - + # Check for EXCEPTION - exc_match = re.search(r"EXCEPTION of type '(\w+)' occurred with message: (.+?)(?:\n|$)", output) + exc_match = re.search( + r"EXCEPTION of type '(\w+)' occurred with message: (.+?)(?:\n|$)", output + ) if exc_match: return True, exc_match.group(1), exc_match.group(2).strip() - + # Check for common Python exceptions in output - for exc_type in ["KeyError", "TypeError", "ValueError", "AttributeError", "IndexError"]: + for exc_type in [ + "KeyError", + "TypeError", + "ValueError", + "AttributeError", + "IndexError", + ]: if exc_type in output: match = re.search(rf"{exc_type}[:\s]+['\"]?([^'\"]+)['\"]?", output) if match: return True, exc_type, match.group(1).strip() return True, exc_type, "Unknown details" - + # Check for "Invalid:" messages if "Invalid:" in output: - return True, "InvalidCommand", output.split("Invalid:")[1].split("\n")[0].strip() - + return ( + True, + "InvalidCommand", + output.split("Invalid:")[1].split("\n")[0].strip(), + ) + # Check for traceback if "traceback" in output_lower: return True, "Traceback", "Python traceback detected" - + return False, None, None def display_issues_table(issues: list[dict]) -> None: """Display issues in a formatted table.""" - table = Table(title="Open GitHub Issues", show_header=True, header_style="bold cyan") + table = Table( + title="Open GitHub Issues", show_header=True, header_style="bold cyan" + ) table.add_column("#", style="dim", width=4) table.add_column("Title", style="white", max_width=50) table.add_column("Created", style="dim", width=12) table.add_column("Commands", width=8, justify="center") - + for issue in issues: body = issue.get("body", "") or "" commands = extract_commands_from_body(body) created = issue["created_at"][:10] - + table.add_row( str(issue["number"]), issue["title"][:50], created, - str(len(commands)) if commands else "-" + str(len(commands)) if commands else "-", ) - + console.print(table) def select_issue_interactive(issues: list[dict]) -> dict | None: """Interactively select an issue from the list.""" display_issues_table(issues) - + console.print("\n[dim]Enter issue number to investigate, or 'q' to quit[/dim]") - + issue_nums = [i["number"] for i in issues] while True: choice = Prompt.ask("Select issue", default="q") - if choice.lower() == 'q': + if choice.lower() == "q": return None try: num = int(choice) if num in issue_nums: return next(i for i in issues if i["number"] == num) - console.print(f"[yellow]Issue #{num} not in list. Valid: {issue_nums}[/yellow]") + console.print( + f"[yellow]Issue #{num} not in list. Valid: {issue_nums}[/yellow]" + ) except ValueError: console.print("[yellow]Please enter a number or 'q'[/yellow]") def investigate_issue( - issue: dict, - profile: str | None = None, - timeout: int = 60, - verbose: bool = False + issue: dict, profile: str | None = None, timeout: int = 60, verbose: bool = False ) -> IssueInvestigation: """Investigate a single issue by attempting to reproduce it.""" - + issue_num = issue["number"] body = issue.get("body", "") or "" - + investigation = IssueInvestigation( issue_number=issue_num, issue_title=issue["title"], @@ -401,89 +439,96 @@ def investigate_issue( issue_body=body, investigation_time=datetime.now().isoformat(), reproduced=False, - status="pending" + status="pending", ) - + # Extract commands and expected errors from issue body investigation.extracted_commands = extract_commands_from_body(body) investigation.expected_errors = extract_errors_from_body(body) - + if not investigation.extracted_commands: - console.print(f"[yellow]⚠️ No commands found in issue #{issue_num} body[/yellow]") + console.print( + f"[yellow]⚠️ No commands found in issue #{issue_num} body[/yellow]" + ) investigation.status = "no_commands" - investigation.recommendations.append("Manually review issue and add test commands to issue_tests.yaml") + investigation.recommendations.append( + "Manually review issue and add test commands to issue_tests.yaml" + ) return investigation - - console.print(Panel( - f"[bold]Issue #{issue_num}:[/bold] {issue['title']}\n\n" - f"[dim]Commands to run: {len(investigation.extracted_commands)}[/dim]\n" - f"[dim]Expected errors: {len(investigation.expected_errors)}[/dim]", - title="🔍 Investigation Starting" - )) - + + console.print( + Panel( + f"[bold]Issue #{issue_num}:[/bold] {issue['title']}\n\n" + f"[dim]Commands to run: {len(investigation.extracted_commands)}[/dim]\n" + f"[dim]Expected errors: {len(investigation.expected_errors)}[/dim]", + title="🔍 Investigation Starting", + ) + ) + if verbose: console.print("\n[dim]Extracted commands:[/dim]") for cmd in investigation.extracted_commands: console.print(f" [cyan]{cmd}[/cyan]") - + # Run the commands runner = ShellRunner(profile=profile, timeout=timeout) all_output = "" - + try: runner.start() - + for cmd in investigation.extracted_commands: import time + start_time = time.time() - + try: output = runner.run(cmd) duration = time.time() - start_time except Exception as e: output = f"RUNNER ERROR: {e}\n{traceback.format_exc()}" duration = time.time() - start_time - + all_output += f"\n> {cmd}\n{output}\n" - + # Check for errors in output has_error, error_type, error_msg = detect_errors_in_output(output) - + result = CommandResult( command=cmd, output=output, duration_seconds=duration, has_error=has_error, error_type=error_type, - error_message=error_msg + error_message=error_msg, ) investigation.commands_run.append(result) - + if has_error and error_type: investigation.actual_errors.append(f"{error_type}: {error_msg}") - + except Exception as e: investigation.debug_info["runner_error"] = str(e) investigation.debug_info["runner_traceback"] = traceback.format_exc() investigation.status = "error" finally: runner.close() - + investigation.raw_output = all_output - + # Analyze results _analyze_investigation(investigation) - + return investigation def _analyze_investigation(investigation: IssueInvestigation) -> None: """Analyze the investigation results and set status.""" - + # Check if any expected errors were found errors_found = len(investigation.actual_errors) > 0 expected_matched = False - + for expected in investigation.expected_errors: for actual in investigation.actual_errors: # Fuzzy match - check if key parts match @@ -492,81 +537,105 @@ def _analyze_investigation(investigation: IssueInvestigation) -> None: if any(part in actual_lower for part in expected_lower.split(":")): expected_matched = True break - + # Determine status if investigation.status == "error": - investigation.recommendations.append("Investigation encountered an error - check runner configuration") + investigation.recommendations.append( + "Investigation encountered an error - check runner configuration" + ) elif expected_matched: investigation.reproduced = True investigation.status = "confirmed" - investigation.recommendations.append("Issue is confirmed - analyze the error and fix the root cause") + investigation.recommendations.append( + "Issue is confirmed - analyze the error and fix the root cause" + ) elif errors_found: investigation.reproduced = True investigation.status = "confirmed" - investigation.recommendations.append("Errors detected (different from expected) - issue likely exists but symptoms may have changed") + investigation.recommendations.append( + "Errors detected (different from expected) - issue likely exists but symptoms may have changed" + ) elif investigation.expected_errors: investigation.status = "not_reproducible" - investigation.recommendations.append("Expected errors not found - issue may be fixed or environment-specific") + investigation.recommendations.append( + "Expected errors not found - issue may be fixed or environment-specific" + ) else: investigation.status = "partial" - investigation.recommendations.append("No explicit errors expected or found - manual review of output recommended") - + investigation.recommendations.append( + "No explicit errors expected or found - manual review of output recommended" + ) + # Add recommendations based on error types for error in investigation.actual_errors: if "KeyError" in error: - investigation.recommendations.append("KeyError suggests missing dictionary key - check API response structure") + investigation.recommendations.append( + "KeyError suggests missing dictionary key - check API response structure" + ) elif "TypeError" in error: - investigation.recommendations.append("TypeError suggests type mismatch - check data types in comparison/operations") + investigation.recommendations.append( + "TypeError suggests type mismatch - check data types in comparison/operations" + ) elif "AttributeError" in error: - investigation.recommendations.append("AttributeError suggests accessing missing attribute - check object structure") + investigation.recommendations.append( + "AttributeError suggests accessing missing attribute - check object structure" + ) elif "InvalidCommand" in error: - investigation.recommendations.append("Invalid command - check command registration and available commands") + investigation.recommendations.append( + "Invalid command - check command registration and available commands" + ) -def display_investigation_results(investigation: IssueInvestigation, verbose: bool = False) -> None: +def display_investigation_results( + investigation: IssueInvestigation, verbose: bool = False +) -> None: """Display investigation results in a formatted way.""" - + status_styles = { "confirmed": "[red]❌ CONFIRMED - Issue exists[/red]", "not_reproducible": "[green]✅ NOT REPRODUCIBLE - May be fixed[/green]", "partial": "[yellow]⚠️ PARTIAL - Manual review needed[/yellow]", "error": "[red]💥 ERROR - Investigation failed[/red]", - "no_commands": "[yellow]📝 NO COMMANDS - Cannot auto-investigate[/yellow]" + "no_commands": "[yellow]📝 NO COMMANDS - Cannot auto-investigate[/yellow]", } - + console.print("\n") - console.print(Panel( - f"[bold]Issue #{investigation.issue_number}:[/bold] {investigation.issue_title}\n\n" - f"Status: {status_styles.get(investigation.status, investigation.status)}", - title="📊 Investigation Results" - )) - + console.print( + Panel( + f"[bold]Issue #{investigation.issue_number}:[/bold] {investigation.issue_title}\n\n" + f"Status: {status_styles.get(investigation.status, investigation.status)}", + title="📊 Investigation Results", + ) + ) + if investigation.actual_errors: console.print("\n[bold red]Errors Detected:[/bold red]") for error in investigation.actual_errors: console.print(f" • {error}") - + if investigation.recommendations: console.print("\n[bold cyan]Recommendations:[/bold cyan]") for rec in investigation.recommendations: console.print(f" → {rec}") - + if verbose and investigation.raw_output: console.print("\n[bold]Raw Output:[/bold]") console.print(Panel(investigation.raw_output[:2000], title="Shell Output")) -def save_investigation(investigation: IssueInvestigation, output_path: Path, fmt: str = "xml") -> None: +def save_investigation( + investigation: IssueInvestigation, output_path: Path, fmt: str = "xml" +) -> None: """Save investigation results to a file.""" data = { "investigation": investigation.to_dict(), "agent_prompt_xml": investigation.to_agent_prompt(fmt="xml"), "agent_prompt_markdown": investigation.to_agent_prompt(fmt="markdown"), } - - with open(output_path, 'w') as f: + + with open(output_path, "w") as f: json.dump(data, f, indent=2, default=str) - + console.print(f"\n[green]✓ Investigation saved to: {output_path}[/green]") @@ -574,18 +643,35 @@ def main(): parser = argparse.ArgumentParser( description="Investigate GitHub issues for aws-net-shell", formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=__doc__ + epilog=__doc__, + ) + parser.add_argument( + "--issue", "-i", type=int, help="Specific issue number to investigate" ) - parser.add_argument("--issue", "-i", type=int, help="Specific issue number to investigate") parser.add_argument("--profile", "-p", help="AWS profile to use") - parser.add_argument("--timeout", "-t", type=int, default=60, help="Command timeout (default: 60s)") - parser.add_argument("--output", "-o", type=Path, help="Save investigation to JSON file") + parser.add_argument( + "--timeout", "-t", type=int, default=60, help="Command timeout (default: 60s)" + ) + parser.add_argument( + "--output", "-o", type=Path, help="Save investigation to JSON file" + ) parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output") - parser.add_argument("--agent-prompt", "-a", action="store_true", - help="Print agent-ready prompt to stdout") - parser.add_argument("--format", "-f", choices=["xml", "markdown"], default="xml", - help="Agent prompt format: xml (default, better for agents) or markdown") - parser.add_argument("--list", "-l", action="store_true", help="Just list open issues") + parser.add_argument( + "--agent-prompt", + "-a", + action="store_true", + help="Print agent-ready prompt to stdout", + ) + parser.add_argument( + "--format", + "-f", + choices=["xml", "markdown"], + default="xml", + help="Agent prompt format: xml (default, better for agents) or markdown", + ) + parser.add_argument( + "--list", "-l", action="store_true", help="Just list open issues" + ) args = parser.parse_args() try: @@ -595,47 +681,47 @@ def main(): issue = issues[0] else: issues = fetch_issues() - + if args.list: display_issues_table(issues) sys.exit(0) - + issue = select_issue_interactive(issues) if not issue: console.print("[dim]Cancelled[/dim]") sys.exit(0) - + # Investigate the issue investigation = investigate_issue( issue=issue, profile=args.profile, timeout=args.timeout, - verbose=args.verbose + verbose=args.verbose, ) - + # Display results display_investigation_results(investigation, verbose=args.verbose) - + # Save to file if requested if args.output: save_investigation(investigation, args.output) - + # Print agent prompt if requested if args.agent_prompt: - console.print("\n" + "="*60) + console.print("\n" + "=" * 60) console.print(f"[bold]AGENT PROMPT ({args.format.upper()})[/bold]") - console.print("="*60 + "\n") + console.print("=" * 60 + "\n") prompt = investigation.to_agent_prompt(fmt=args.format) if args.format == "markdown": console.print(Markdown(prompt)) else: console.print(Syntax(prompt, "xml", theme="monokai")) - + # Exit code based on status if investigation.reproduced: sys.exit(1) # Issue exists sys.exit(0) - + except KeyboardInterrupt: console.print("\n[dim]Interrupted[/dim]") sys.exit(130) diff --git a/scripts/run_issue_tests.py b/scripts/run_issue_tests.py old mode 100644 new mode 100755 index a1a7162..feaea14 --- a/scripts/run_issue_tests.py +++ b/scripts/run_issue_tests.py @@ -44,9 +44,9 @@ def print_commands(issue: dict): def run_issue_test(runner: ShellRunner, issue_num: int, issue: dict) -> bool: """Run a single issue test and check results.""" - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print(f"ISSUE #{issue_num}: {issue.get('title', 'Untitled')}") - print(f"{'='*60}") + print(f"{'=' * 60}") outputs = [] all_output = "" @@ -61,13 +61,15 @@ def run_issue_test(runner: ShellRunner, issue_num: int, issue: dict) -> bool: # Check expectations passed = True - + # Check for expected error if "expect_error" in issue: if issue["expect_error"] not in all_output: print(f"\n⚠️ Expected error not found: {issue['expect_error']}") else: - print(f"\n❌ CONFIRMED: Error '{issue['expect_error']}' present (issue exists)") + print( + f"\n❌ CONFIRMED: Error '{issue['expect_error']}' present (issue exists)" + ) passed = False # Check for strings that should be present (indicating bug) @@ -86,7 +88,7 @@ def run_issue_test(runner: ShellRunner, issue_num: int, issue: dict) -> bool: if passed: print(f"\n✅ Issue #{issue_num} appears FIXED or not reproducible") - + return passed @@ -95,10 +97,16 @@ def main(): parser.add_argument("--issue", "-i", type=int, help="Run specific issue number") parser.add_argument("--profile", "-p", help="AWS profile to use") parser.add_argument("--timeout", "-t", type=int, default=60, help="Command timeout") - parser.add_argument("--print-commands", action="store_true", - help="Just print commands for shell_runner.py") - parser.add_argument("--yaml", default=Path(__file__).parent / "issue_tests.yaml", - help="Path to issue tests YAML file") + parser.add_argument( + "--print-commands", + action="store_true", + help="Just print commands for shell_runner.py", + ) + parser.add_argument( + "--yaml", + default=Path(__file__).parent / "issue_tests.yaml", + help="Path to issue tests YAML file", + ) args = parser.parse_args() issues = load_issues(Path(args.yaml)) @@ -135,17 +143,17 @@ def main(): runner.close() # Summary - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print("SUMMARY") - print(f"{'='*60}") - + print(f"{'=' * 60}") + passed = sum(1 for v in results.values() if v) failed = sum(1 for v in results.values() if not v) - + for issue_num, result in results.items(): status = "✅ FIXED" if result else "❌ EXISTS" print(f" Issue #{issue_num}: {status}") - + print(f"\nTotal: {passed} fixed, {failed} still exist") sys.exit(0 if failed == 0 else 1) diff --git a/scripts/s2svpn b/scripts/s2svpn index d752fc2..9a68834 100755 --- a/scripts/s2svpn +++ b/scripts/s2svpn @@ -3,7 +3,6 @@ import argparse import json -import os import sys import time from pathlib import Path @@ -12,7 +11,6 @@ try: import boto3 from rich.console import Console from rich.table import Table - from rich.panel import Panel except ImportError: print("Missing dependencies. Run: pip install boto3 rich") sys.exit(1) @@ -24,7 +22,9 @@ RESOURCES_FILE = Path(__file__).parent.parent / "terraform" / "test_resources.js def load_resources(): if not RESOURCES_FILE.exists(): - console.print("[red]Error: test_resources.json not found. Run terraform apply first.[/red]") + console.print( + "[red]Error: test_resources.json not found. Run terraform apply first.[/red]" + ) sys.exit(1) return json.loads(RESOURCES_FILE.read_text()) @@ -40,14 +40,14 @@ def cmd_status(args): """Show VPN and tunnel status.""" session = get_session(args.profile) ec2 = session.client("ec2") - + if args.all: # Show all VPN connections vpns = ec2.describe_vpn_connections()["VpnConnections"] if not vpns: console.print("[yellow]No VPN connections found[/yellow]") return - + table = Table(title="All VPN Connections", show_header=True) table.add_column("#", style="dim") table.add_column("VPN ID", style="cyan") @@ -55,43 +55,55 @@ def cmd_status(args): table.add_column("State", style="bold") table.add_column("Gateway ID", style="blue") table.add_column("Tunnels", style="white", justify="right") - + for i, vpn in enumerate(vpns, 1): - name = next((t["Value"] for t in vpn.get("Tags", []) if t["Key"] == "Name"), "N/A") + name = next( + (t["Value"] for t in vpn.get("Tags", []) if t["Key"] == "Name"), "N/A" + ) gw_id = vpn.get("TransitGatewayId") or vpn.get("VpnGatewayId") or "N/A" tunnel_count = len(vpn.get("VgwTelemetry", [])) - + state = vpn["State"] style = "green" if state == "available" else "yellow" - + table.add_row( str(i), vpn["VpnConnectionId"], name, f"[{style}]{state}[/{style}]", gw_id, - str(tunnel_count) + str(tunnel_count), ) - + console.print(table) - console.print("\n[dim]To see details for a specific VPN, update test_resources.json or use:[/dim]") + console.print( + "\n[dim]To see details for a specific VPN, update test_resources.json or use:[/dim]" + ) console.print("[dim] terraform output vpn_id # Get current VPN ID[/dim]") return - + # Try specific VPN from resources file try: res = load_resources() vpn_id = res.get("strongswan_vpn_id") instance_id = res.get("strongswan_instance_id") - + if not vpn_id: - console.print("[red]Error: No strongswan_vpn_id in test_resources.json[/red]") - console.print("[yellow]Hint: Run with --all to see all VPN connections, or update test_resources.json[/yellow]") - console.print("[dim] terraform output vpn_id # Get current VPN ID from Terraform[/dim]") + console.print( + "[red]Error: No strongswan_vpn_id in test_resources.json[/red]" + ) + console.print( + "[yellow]Hint: Run with --all to see all VPN connections, or update test_resources.json[/yellow]" + ) + console.print( + "[dim] terraform output vpn_id # Get current VPN ID from Terraform[/dim]" + ) return - + try: - vpn = ec2.describe_vpn_connections(VpnConnectionIds=[vpn_id])["VpnConnections"][0] + vpn = ec2.describe_vpn_connections(VpnConnectionIds=[vpn_id])[ + "VpnConnections" + ][0] except Exception as e: if "InvalidVpnConnectionID.NotFound" in str(e): console.print(f"[red]Error: VPN connection {vpn_id} not found[/red]") @@ -103,7 +115,9 @@ def cmd_status(args): console.print("[dim]Solutions:[/dim]") console.print(" 1. [dim]Run:[/dim] terraform output -raw vpn_id") console.print(" 2. Update test_resources.json with the new VPN ID") - console.print(" 3. Or run with [dim]./scripts/s2svpn status --all[/dim] to see all connections") + console.print( + " 3. Or run with [dim]./scripts/s2svpn status --all[/dim] to see all connections" + ) return else: raise @@ -111,28 +125,30 @@ def cmd_status(args): console.print("[red]Error: test_resources.json not found[/red]") console.print("[yellow]Run terraform apply first, or use --all flag[/yellow]") return - - inst = ec2.describe_instances(InstanceIds=[instance_id])["Reservations"][0]["Instances"][0] - + + inst = ec2.describe_instances(InstanceIds=[instance_id])["Reservations"][0][ + "Instances" + ][0] + table = Table(title="Site-to-Site VPN Status", show_header=True) table.add_column("Property", style="cyan") table.add_column("Value", style="green") - + table.add_row("VPN ID", vpn_id) table.add_row("VPN State", vpn["State"]) table.add_row("Instance ID", instance_id) table.add_row("Instance State", inst["State"]["Name"]) table.add_row("Public IP", res.get("strongswan_public_ip", "N/A")) table.add_row("Customer Gateway", vpn.get("CustomerGatewayId", "N/A")) - + console.print(table) - + tunnel_table = Table(title="Tunnel Status", show_header=True) tunnel_table.add_column("Tunnel", style="cyan") tunnel_table.add_column("Outside IP", style="white") tunnel_table.add_column("Status", style="bold") tunnel_table.add_column("Details") - + for i, telem in enumerate(vpn.get("VgwTelemetry", []), 1): status = telem["Status"] style = "green" if status == "UP" else "red" @@ -140,9 +156,9 @@ def cmd_status(args): f"Tunnel {i}", telem["OutsideIpAddress"], f"[{style}]{status}[/{style}]", - telem.get("StatusMessage", "") + telem.get("StatusMessage", ""), ) - + console.print(tunnel_table) @@ -152,21 +168,27 @@ def cmd_stop(args): res = load_resources() instance_id = res["strongswan_instance_id"] except (FileNotFoundError, KeyError): - console.print("[red]Error: strongswan_instance_id not in test_resources.json[/red]") - console.print("[yellow]Hint: Run terraform apply first, or specify --all for VPN status[/yellow]") + console.print( + "[red]Error: strongswan_instance_id not in test_resources.json[/red]" + ) + console.print( + "[yellow]Hint: Run terraform apply first, or specify --all for VPN status[/yellow]" + ) return - + session = get_session(args.profile) ec2 = session.client("ec2") - + # Verify instance exists try: ec2.describe_instances(InstanceIds=[instance_id]) except Exception: console.print(f"[red]Error: Instance {instance_id} not found[/red]") - console.print("[yellow]The strongSwan instance may have been terminated[/yellow]") + console.print( + "[yellow]The strongSwan instance may have been terminated[/yellow]" + ) return - + console.print(f"[yellow]Stopping instance {instance_id}...[/yellow]") ec2.stop_instances(InstanceIds=[instance_id]) console.print("[green]Stop initiated. VPN tunnels will go DOWN.[/green]") @@ -178,16 +200,22 @@ def cmd_start(args): res = load_resources() instance_id = res["strongswan_instance_id"] except (FileNotFoundError, KeyError): - console.print("[red]Error: strongswan_instance_id not in test_resources.json[/red]") - console.print("[yellow]Hint: Run terraform apply first, or specify --all for VPN status[/yellow]") + console.print( + "[red]Error: strongswan_instance_id not in test_resources.json[/red]" + ) + console.print( + "[yellow]Hint: Run terraform apply first, or specify --all for VPN status[/yellow]" + ) return - + session = get_session(args.profile) ec2 = session.client("ec2") - + console.print(f"[yellow]Starting instance {instance_id}...[/yellow]") ec2.start_instances(InstanceIds=[instance_id]) - console.print("[green]Start initiated. VPN tunnels will come UP in ~2-3 minutes.[/green]") + console.print( + "[green]Start initiated. VPN tunnels will come UP in ~2-3 minutes.[/green]" + ) def cmd_restart(args): @@ -196,62 +224,80 @@ def cmd_restart(args): res = load_resources() instance_id = res["strongswan_instance_id"] except (FileNotFoundError, KeyError): - console.print("[red]Error: strongswan_instance_id not in test_resources.json[/red]") - console.print("[yellow]Hint: Run terraform apply first, or specify --all for VPN status[/yellow]") + console.print( + "[red]Error: strongswan_instance_id not in test_resources.json[/red]" + ) + console.print( + "[yellow]Hint: Run terraform apply first, or specify --all for VPN status[/yellow]" + ) return - + session = get_session(args.profile) ec2 = session.client("ec2") - + # Verify instance exists try: ec2.describe_instances(InstanceIds=[instance_id]) except Exception: console.print(f"[red]Error: Instance {instance_id} not found[/red]") - console.print("[yellow]The strongSwan instance may have been terminated[/yellow]") + console.print( + "[yellow]The strongSwan instance may have been terminated[/yellow]" + ) return - + console.print(f"[yellow]Rebooting instance {instance_id}...[/yellow]") ec2.reboot_instances(InstanceIds=[instance_id]) - console.print("[green]Reboot initiated. VPN tunnels will reconnect in ~1-2 minutes.[/green]") + console.print( + "[green]Reboot initiated. VPN tunnels will reconnect in ~1-2 minutes.[/green]" + ) def cmd_monitor(args): """Live monitor VPN tunnel status.""" session = get_session(args.profile) ec2 = session.client("ec2") - + if args.all: # Monitor all VPN connections console.print("[cyan]Monitoring all VPN tunnels (Ctrl+C to stop)...[/cyan]\n") try: while True: vpns = ec2.describe_vpn_connections()["VpnConnections"] - + console.clear() for vpn in vpns: vpn_id = vpn["VpnConnectionId"] - name = next((t["Value"] for t in vpn.get("Tags", []) if t["Key"] == "Name"), "N/A") - - table = Table(title=f"VPN: {name} ({vpn_id}) - {time.strftime('%H:%M:%S')}") + name = next( + (t["Value"] for t in vpn.get("Tags", []) if t["Key"] == "Name"), + "N/A", + ) + + table = Table( + title=f"VPN: {name} ({vpn_id}) - {time.strftime('%H:%M:%S')}" + ) table.add_column("Tunnel") table.add_column("IP") table.add_column("Status") table.add_column("Message") - + for i, t in enumerate(vpn.get("VgwTelemetry", []), 1): status = t["Status"] style = "green" if status == "UP" else "red" - table.add_row(f"T{i}", t["OutsideIpAddress"], f"[{style}]{status}[/{style}]", t.get("StatusMessage", "")) - + table.add_row( + f"T{i}", + t["OutsideIpAddress"], + f"[{style}]{status}[/{style}]", + t.get("StatusMessage", ""), + ) + console.print(table) console.print() - + time.sleep(5) except KeyboardInterrupt: console.print("\n[yellow]Monitoring stopped.[/yellow]") return - + # Monitor specific VPN from resources try: res = load_resources() @@ -260,24 +306,31 @@ def cmd_monitor(args): console.print("[red]Error: strongswan_vpn_id not in test_resources.json[/red]") console.print("[yellow]Hint: Use --all to monitor all VPN connections[/yellow]") return - + console.print("[cyan]Monitoring VPN tunnels (Ctrl+C to stop)...[/cyan]\n") - + try: while True: - vpn = ec2.describe_vpn_connections(VpnConnectionIds=[vpn_id])["VpnConnections"][0] - + vpn = ec2.describe_vpn_connections(VpnConnectionIds=[vpn_id])[ + "VpnConnections" + ][0] + table = Table(title=f"VPN {vpn_id} - {time.strftime('%H:%M:%S')}") table.add_column("Tunnel") table.add_column("IP") table.add_column("Status") table.add_column("Message") - + for i, t in enumerate(vpn.get("VgwTelemetry", []), 1): status = t["Status"] style = "green" if status == "UP" else "red" - table.add_row(f"T{i}", t["OutsideIpAddress"], f"[{style}]{status}[/{style}]", t.get("StatusMessage", "")) - + table.add_row( + f"T{i}", + t["OutsideIpAddress"], + f"[{style}]{status}[/{style}]", + t.get("StatusMessage", ""), + ) + console.clear() console.print(table) time.sleep(5) @@ -286,7 +339,9 @@ def cmd_monitor(args): except Exception as e: if "InvalidVpnConnectionID" in str(e): console.print(f"[red]Error: VPN {vpn_id} not found[/red]") - console.print("[yellow]The VPN connection may have been deleted. Use --all to see available VPNs.[/yellow]") + console.print( + "[yellow]The VPN connection may have been deleted. Use --all to see available VPNs.[/yellow]" + ) else: raise @@ -295,19 +350,28 @@ def main(): parser = argparse.ArgumentParser(description="Site-to-Site VPN management CLI") parser.add_argument("-p", "--profile", help="AWS profile name") subparsers = parser.add_subparsers(dest="command", required=True) - + status_parser = subparsers.add_parser("status", help="Show VPN and tunnel status") - status_parser.add_argument("-a", "--all", action="store_true", help="Show all VPN connections instead of just strongSwan") - + status_parser.add_argument( + "-a", + "--all", + action="store_true", + help="Show all VPN connections instead of just strongSwan", + ) + subparsers.add_parser("stop", help="Stop strongSwan instance") subparsers.add_parser("start", help="Start strongSwan instance") subparsers.add_parser("restart", help="Restart strongSwan instance") - - monitor_parser = subparsers.add_parser("monitor", help="Live monitor tunnel status (Ctrl+C to stop)") - monitor_parser.add_argument("-a", "--all", action="store_true", help="Monitor all VPN connections") - + + monitor_parser = subparsers.add_parser( + "monitor", help="Live monitor tunnel status (Ctrl+C to stop)" + ) + monitor_parser.add_argument( + "-a", "--all", action="store_true", help="Monitor all VPN connections" + ) + args = parser.parse_args() - + commands = { "status": cmd_status, "stop": cmd_stop, @@ -315,7 +379,7 @@ def main(): "restart": cmd_restart, "monitor": cmd_monitor, } - + commands[args.command](args) diff --git a/scripts/shell_runner.py b/scripts/shell_runner.py old mode 100644 new mode 100755 index 1b89265..63f164c --- a/scripts/shell_runner.py +++ b/scripts/shell_runner.py @@ -11,7 +11,7 @@ # With AWS profile uv run python scripts/shell_runner.py --profile my-profile "show vpcs" "set vpc 1" "show subnets" - + # With debug logging uv run python scripts/shell_runner.py --debug "show vpcs" "set vpc 1" "show subnets" """ @@ -29,15 +29,17 @@ class ShellRunner: """Run commands against aws-net-shell interactively.""" - def __init__(self, profile: str | None = None, timeout: int = 60, debug: bool = False): + def __init__( + self, profile: str | None = None, timeout: int = 60, debug: bool = False + ): self.profile = profile self.timeout = timeout self.debug = debug self.child: pexpect.spawn | None = None # Match prompt at end of line: "aws-net> " or "context $" - self.prompt_pattern = r'\n(?:aws-net>|.*\$)\s*$' + self.prompt_pattern = r"\n(?:aws-net>|.*\$)\s*$" self.logger = None - + if debug: self._setup_debug_logging() @@ -45,29 +47,29 @@ def _setup_debug_logging(self): """Initialize debug logging to timestamped file in /tmp/.""" timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") log_file = Path(f"/tmp/aws_net_runner_debug_{timestamp}.log") - + # Create logger self.logger = logging.getLogger("shell_runner") self.logger.setLevel(logging.DEBUG) - + # File handler with detailed formatting file_handler = logging.FileHandler(log_file) file_handler.setLevel(logging.DEBUG) formatter = logging.Formatter( "%(asctime)s.%(msecs)03d [%(levelname)s] %(message)s", - datefmt="%Y-%m-%d %H:%M:%S" + datefmt="%Y-%m-%d %H:%M:%S", ) file_handler.setFormatter(formatter) self.logger.addHandler(file_handler) - + # Log session start - self.logger.info("="*80) + self.logger.info("=" * 80) self.logger.info("AWS Network Shell Runner - Debug Session Started") self.logger.info(f"Log file: {log_file}") self.logger.info(f"Profile: {self.profile or 'default'}") self.logger.info(f"Timeout: {self.timeout}s") - self.logger.info("="*80) - + self.logger.info("=" * 80) + print(f"[DEBUG] Logging to: {log_file}") def _debug(self, message: str): @@ -87,19 +89,19 @@ def _error(self, message: str): def start(self): """Start the interactive shell.""" - cmd = 'aws-net-shell' + cmd = "aws-net-shell" if self.profile: - cmd += f' --profile {self.profile}' + cmd += f" --profile {self.profile}" self._info(f"Starting shell: {cmd}") - + try: - self.child = pexpect.spawn(cmd, timeout=self.timeout, encoding='utf-8') + self.child = pexpect.spawn(cmd, timeout=self.timeout, encoding="utf-8") self._debug(f"Shell process spawned (PID: {self.child.pid})") - + # Wait for initial prompt self._wait_for_prompt() - print(f"✓ Shell started\n{'='*60}") + print(f"✓ Shell started\n{'=' * 60}") self._info("Shell startup complete") except Exception as e: self._error(f"Shell startup failed: {e}") @@ -108,27 +110,29 @@ def start(self): def _wait_for_prompt(self): """Wait for shell prompt, handling spinners and partial output.""" import time - + self._debug("Waiting for prompt...") start_time = time.time() - + # Collect output until we see a stable prompt buffer = "" last_size = 0 stable_count = 0 iterations = 0 - + while stable_count < 3: # Need 3 stable checks (~0.3s of no new output) iterations += 1 try: # Non-blocking read with short timeout - self.child.expect(r'.+', timeout=0.1) + self.child.expect(r".+", timeout=0.1) buffer += self.child.after - self._debug(f"[iter {iterations}] Read {len(self.child.after)} chars, buffer size: {len(buffer)}") + self._debug( + f"[iter {iterations}] Read {len(self.child.after)} chars, buffer size: {len(buffer)}" + ) except pexpect.TIMEOUT: self._debug(f"[iter {iterations}] Read timeout (no new data)") pass - + # Check if buffer size is stable (no new output) if len(buffer) == last_size: stable_count += 1 @@ -137,19 +141,19 @@ def _wait_for_prompt(self): stable_count = 0 last_size = len(buffer) self._debug(f"[iter {iterations}] Buffer growing, reset stable count") - + # Check for prompt indicators clean = self._strip_ansi(buffer) - if clean.rstrip().endswith(('aws-net>', '$')) and stable_count >= 2: + if clean.rstrip().endswith(("aws-net>", "$")) and stable_count >= 2: self._debug(f"[iter {iterations}] Prompt detected: '{clean[-50:]}'") break - + time.sleep(0.1) - + elapsed = time.time() - start_time self._info(f"Prompt received after {elapsed:.2f}s ({iterations} iterations)") self._debug(f"Final buffer size: {len(buffer)} chars") - + return buffer def run(self, command: str) -> str: @@ -159,23 +163,24 @@ def run(self, command: str) -> str: print(f"\n> {command}") print("-" * 60) - + self._info(f"Executing command: '{command}'") cmd_start = datetime.now() self.child.sendline(command) - self._debug(f"Command sent to shell") - + self._debug("Command sent to shell") + # Wait for complete output import time + time.sleep(0.2) # Let command start self._debug("Waiting for command output...") - + try: output = self._wait_for_prompt() cmd_elapsed = (datetime.now() - cmd_start).total_seconds() self._info(f"Command completed in {cmd_elapsed:.2f}s") - + # Log raw output with ANSI codes self._debug(f"Raw output ({len(output)} chars):\n{output}") except Exception as e: @@ -185,24 +190,24 @@ def run(self, command: str) -> str: # Clean ANSI codes for display but keep structure clean_output = self._strip_ansi(output) self._debug(f"Cleaned output ({len(clean_output)} chars)") - + # Remove the echoed command from output - lines = clean_output.split('\n') + lines = clean_output.split("\n") if lines and command in lines[0]: lines = lines[1:] self._debug("Removed echoed command from output") - - result = '\n'.join(lines).strip() + + result = "\n".join(lines).strip() print(result) return result def run_sequence(self, commands: list[str]): """Run a sequence of commands.""" self._info(f"Running sequence of {len(commands)} commands") - + for idx, cmd in enumerate(commands, 1): cmd = cmd.strip() - if cmd and not cmd.startswith('#'): + if cmd and not cmd.startswith("#"): self._info(f"Command {idx}/{len(commands)}: {cmd}") try: self.run(cmd) @@ -213,21 +218,21 @@ def run_sequence(self, commands: list[str]): def close(self): """Close the shell.""" self._info("Closing shell...") - + if self.child and self.child.isalive(): - self.child.sendline('exit') + self.child.sendline("exit") try: self.child.expect(pexpect.EOF, timeout=5) self._debug("Shell exited cleanly") except Exception as e: self._debug(f"Shell exit timeout, forcing termination: {e}") self.child.terminate(force=True) - - print(f"\n{'='*60}\n✓ Shell closed") - + + print(f"\n{'=' * 60}\n✓ Shell closed") + if self.logger: self._info("Debug session complete") - self._info("="*80) + self._info("=" * 80) # Close all handlers for handler in self.logger.handlers[:]: handler.close() @@ -236,27 +241,33 @@ def close(self): @staticmethod def _strip_ansi(text: str) -> str: """Remove ANSI escape codes.""" - ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])') - return ansi_escape.sub('', text) + ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") + return ansi_escape.sub("", text) def main(): parser = argparse.ArgumentParser( - description='Run commands against aws-net-shell interactively', + description="Run commands against aws-net-shell interactively", formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=__doc__ + epilog=__doc__, + ) + parser.add_argument("commands", nargs="*", help="Commands to run") + parser.add_argument("--profile", "-p", help="AWS profile to use") + parser.add_argument( + "--timeout", "-t", type=int, default=30, help="Command timeout (default: 30s)" + ) + parser.add_argument( + "--debug", + "-d", + action="store_true", + help="Enable debug logging to /tmp/aws_net_runner_debug_.log", ) - parser.add_argument('commands', nargs='*', help='Commands to run') - parser.add_argument('--profile', '-p', help='AWS profile to use') - parser.add_argument('--timeout', '-t', type=int, default=30, help='Command timeout (default: 30s)') - parser.add_argument('--debug', '-d', action='store_true', - help='Enable debug logging to /tmp/aws_net_runner_debug_.log') args = parser.parse_args() # Get commands from args or stdin commands = args.commands if not commands and not sys.stdin.isatty(): - commands = sys.stdin.read().strip().split('\n') + commands = sys.stdin.read().strip().split("\n") if not commands: parser.print_help() @@ -275,5 +286,5 @@ def main(): runner.close() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/src/aws_network_tools/cli/runner.py b/src/aws_network_tools/cli/runner.py old mode 100644 new mode 100755 index b7fe2eb..0a18d72 --- a/src/aws_network_tools/cli/runner.py +++ b/src/aws_network_tools/cli/runner.py @@ -28,64 +28,64 @@ def __init__(self, profile: str | None = None, timeout: int = 60): self.timeout = timeout self.child: pexpect.spawn | None = None # Match prompt at end of line: "aws-net> " or "context $" - self.prompt_pattern = r'\n(?:aws-net>|.*\$)\s*$' + self.prompt_pattern = r"\n(?:aws-net>|.*\$)\s*$" def start(self): """Start the interactive shell.""" - import os - cmd = 'aws-net-shell' + cmd = "aws-net-shell" if self.profile: - cmd += f' --profile {self.profile}' + cmd += f" --profile {self.profile}" # Get actual terminal size or use wide default try: import shutil + cols, rows = shutil.get_terminal_size(fallback=(250, 50)) - except: + except (OSError, ValueError): cols, rows = 250, 50 # Extra wide for full data display self.child = pexpect.spawn( cmd, timeout=self.timeout, - encoding='utf-8', - dimensions=(rows, cols) # Set terminal size for Rich tables + encoding="utf-8", + dimensions=(rows, cols), # Set terminal size for Rich tables ) # Wait for initial prompt self._wait_for_prompt() - print(f"✓ Shell started\n{'='*60}") + print(f"✓ Shell started\n{'=' * 60}") def _wait_for_prompt(self): """Wait for shell prompt, handling spinners and partial output.""" import time - + # Collect output until we see a stable prompt buffer = "" last_size = 0 stable_count = 0 - + while stable_count < 3: # Need 3 stable checks (~0.3s of no new output) try: # Non-blocking read with short timeout - self.child.expect(r'.+', timeout=0.1) + self.child.expect(r".+", timeout=0.1) buffer += self.child.after except pexpect.TIMEOUT: pass - + # Check if buffer size is stable (no new output) if len(buffer) == last_size: stable_count += 1 else: stable_count = 0 last_size = len(buffer) - + # Check for prompt indicators clean = self._strip_ansi(buffer) - if clean.rstrip().endswith(('aws-net>', '$')) and stable_count >= 2: + if clean.rstrip().endswith(("aws-net>", "$")) and stable_count >= 2: break - + time.sleep(0.1) - + return buffer def run(self, command: str) -> str: @@ -97,21 +97,22 @@ def run(self, command: str) -> str: print("-" * 60) self.child.sendline(command) - + # Wait for complete output import time + time.sleep(0.2) # Let command start output = self._wait_for_prompt() # Clean ANSI codes for display but keep structure clean_output = self._strip_ansi(output) - + # Remove the echoed command from output - lines = clean_output.split('\n') + lines = clean_output.split("\n") if lines and command in lines[0]: lines = lines[1:] - - result = '\n'.join(lines).strip() + + result = "\n".join(lines).strip() print(result) return result @@ -119,41 +120,43 @@ def run_sequence(self, commands: list[str]): """Run a sequence of commands.""" for cmd in commands: cmd = cmd.strip() - if cmd and not cmd.startswith('#'): + if cmd and not cmd.startswith("#"): self.run(cmd) def close(self): """Close the shell.""" if self.child and self.child.isalive(): - self.child.sendline('exit') + self.child.sendline("exit") try: self.child.expect(pexpect.EOF, timeout=5) - except: + except (pexpect.TIMEOUT, pexpect.EOF): self.child.terminate(force=True) - print(f"\n{'='*60}\n✓ Shell closed") + print(f"\n{'=' * 60}\n✓ Shell closed") @staticmethod def _strip_ansi(text: str) -> str: """Remove ANSI escape codes.""" - ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])') - return ansi_escape.sub('', text) + ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") + return ansi_escape.sub("", text) def main(): parser = argparse.ArgumentParser( - description='Run commands against aws-net-shell interactively', + description="Run commands against aws-net-shell interactively", formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=__doc__ + epilog=__doc__, + ) + parser.add_argument("commands", nargs="*", help="Commands to run") + parser.add_argument("--profile", "-p", help="AWS profile to use") + parser.add_argument( + "--timeout", "-t", type=int, default=30, help="Command timeout (default: 30s)" ) - parser.add_argument('commands', nargs='*', help='Commands to run') - parser.add_argument('--profile', '-p', help='AWS profile to use') - parser.add_argument('--timeout', '-t', type=int, default=30, help='Command timeout (default: 30s)') args = parser.parse_args() # Get commands from args or stdin commands = args.commands if not commands and not sys.stdin.isatty(): - commands = sys.stdin.read().strip().split('\n') + commands = sys.stdin.read().strip().split("\n") if not commands: parser.print_help() @@ -167,5 +170,5 @@ def main(): runner.close() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/src/aws_network_tools/config/__init__.py b/src/aws_network_tools/config/__init__.py index 3a6745a..b7a66fd 100644 --- a/src/aws_network_tools/config/__init__.py +++ b/src/aws_network_tools/config/__init__.py @@ -8,25 +8,25 @@ class Config: """Configuration manager.""" - + def __init__(self, path: Optional[Path] = None): self.path = path or self._get_default_config_path() self.data = self._load() - + def _get_default_config_path(self) -> Path: """Get default config file path.""" return Path.home() / ".config" / "aws_network_shell" / "config.json" - + def _load(self) -> Dict[str, Any]: """Load config from file.""" if not self.path.exists(): return self._get_defaults() - + try: return json.loads(self.path.read_text()) except Exception: return self._get_defaults() - + def _get_defaults(self) -> Dict[str, Any]: """Get default configuration.""" return { @@ -45,53 +45,53 @@ def _get_defaults(self) -> Dict[str, Any]: "cache": { "enabled": True, "expire_minutes": 30, - } + }, } - + def save(self): """Save configuration to file.""" self.path.parent.mkdir(parents=True, exist_ok=True) self.path.write_text(json.dumps(self.data, indent=2)) - + def get(self, key: str, default: Any = None) -> Any: """Get configuration value using dot notation (e.g., "prompt.style").""" keys = key.split(".") value = self.data - + for k in keys: if isinstance(value, dict) and k in value: value = value[k] else: return default - + return value - + def set(self, key: str, value: Any): """Set configuration value using dot notation.""" keys = key.split(".") target = self.data - + # Navigate to the parent dict for k in keys[:-1]: if k not in target: target[k] = {} target = target[k] - + # Set the final value target[keys[-1]] = value - + def get_prompt_style(self) -> str: """Get prompt style (short or long).""" return self.get("prompt.style", "short") - + def get_theme_name(self) -> str: """Get theme name.""" return self.get("prompt.theme", "catppuccin") - + def show_indices(self) -> bool: """Whether to show indices in prompt.""" return self.get("prompt.show_indices", True) - + def get_max_length(self) -> int: """Get max length for long names.""" return self.get("prompt.max_length", 50) @@ -104,23 +104,23 @@ def get_config() -> Config: class RuntimeConfig: """Thread-safe singleton for runtime configuration. - + Used by modules to access shell runtime settings (profile, regions, no_cache) without explicit parameter passing. - + Usage: # In shell: RuntimeConfig.set_profile("production") RuntimeConfig.set_regions(["us-east-1", "eu-west-1"]) - + # In modules: client = MyClient(RuntimeConfig.get_profile()) regions = RuntimeConfig.get_regions() """ - + _instance = None _lock = Lock() - + def __new__(cls): if cls._instance is None: with cls._lock: @@ -128,7 +128,7 @@ def __new__(cls): cls._instance = super().__new__(cls) cls._instance._initialized = False return cls._instance - + def __init__(self): if self._initialized: return @@ -137,43 +137,43 @@ def __init__(self): self._no_cache: bool = False self._output_format: str = "table" self._initialized = True - + @classmethod def set_profile(cls, profile: Optional[str]) -> None: """Set AWS profile for all modules.""" instance = cls() instance._profile = profile - + @classmethod def get_profile(cls) -> Optional[str]: """Get current AWS profile.""" instance = cls() return instance._profile - + @classmethod def set_regions(cls, regions: list[str]) -> None: """Set target regions for discovery operations.""" instance = cls() instance._regions = regions if regions else [] - + @classmethod def get_regions(cls) -> list[str]: """Get target regions. Empty list means all regions.""" instance = cls() return instance._regions - + @classmethod def set_no_cache(cls, no_cache: bool) -> None: """Set cache disable flag.""" instance = cls() instance._no_cache = no_cache - + @classmethod def is_cache_disabled(cls) -> bool: """Check if caching is disabled.""" instance = cls() return instance._no_cache - + @classmethod def set_output_format(cls, format: str) -> None: """Set output format (table, json, yaml).""" @@ -181,13 +181,13 @@ def set_output_format(cls, format: str) -> None: if format not in ("table", "json", "yaml"): raise ValueError(f"Invalid format: {format}") instance._output_format = format - + @classmethod def get_output_format(cls) -> str: """Get current output format.""" instance = cls() return instance._output_format - + @classmethod def reset(cls) -> None: """Reset to defaults (mainly for testing).""" @@ -200,4 +200,4 @@ def reset(cls) -> None: def get_runtime_config() -> RuntimeConfig: """Get global runtime config instance.""" - return RuntimeConfig() \ No newline at end of file + return RuntimeConfig() diff --git a/src/aws_network_tools/core/base.py b/src/aws_network_tools/core/base.py index fd03fa9..6c1f2a4 100644 --- a/src/aws_network_tools/core/base.py +++ b/src/aws_network_tools/core/base.py @@ -41,7 +41,7 @@ def __init__( # If no profile provided, use RuntimeConfig if profile is None and session is None: profile = RuntimeConfig.get_profile() - + if session: self.session = session self.profile = profile @@ -70,17 +70,17 @@ def client(self, service: str, region_name: Optional[str] = None): ) # Fallback without custom config return self.session.client(service, region_name=region_name) - + def get_regions(self) -> list[str]: """Get target regions from RuntimeConfig or default to session region. - + Returns: list[str]: Target regions. Empty list if RuntimeConfig has empty regions. """ config_regions = RuntimeConfig.get_regions() if config_regions: return config_regions - + # Fallback to session's default region as single-item list default_region = self.session.region_name return [default_region] if default_region else [] diff --git a/src/aws_network_tools/core/cache_db.py b/src/aws_network_tools/core/cache_db.py index 97db5a6..fd13d50 100644 --- a/src/aws_network_tools/core/cache_db.py +++ b/src/aws_network_tools/core/cache_db.py @@ -6,9 +6,8 @@ import sqlite3 import json -from datetime import datetime, timezone from pathlib import Path -from typing import List, Dict, Any, Optional +from typing import Dict, Any, Optional class CacheDB: @@ -29,7 +28,8 @@ def __init__(self, db_path: Optional[Path] = None): def _init_schema(self): """Create database schema if not exists.""" with sqlite3.connect(self.db_path) as conn: - conn.execute(""" + conn.execute( + """ CREATE TABLE IF NOT EXISTS routing_cache ( id INTEGER PRIMARY KEY AUTOINCREMENT, source TEXT NOT NULL, @@ -45,25 +45,33 @@ def _init_schema(self): cached_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, profile TEXT ) - """) + """ + ) - conn.execute(""" + conn.execute( + """ CREATE INDEX IF NOT EXISTS idx_routing_source ON routing_cache(source) - """) + """ + ) - conn.execute(""" + conn.execute( + """ CREATE INDEX IF NOT EXISTS idx_routing_resource ON routing_cache(resource_id) - """) + """ + ) - conn.execute(""" + conn.execute( + """ CREATE INDEX IF NOT EXISTS idx_routing_destination ON routing_cache(destination) - """) + """ + ) # General topology cache - conn.execute(""" + conn.execute( + """ CREATE TABLE IF NOT EXISTS topology_cache ( id INTEGER PRIMARY KEY AUTOINCREMENT, cache_key TEXT NOT NULL, @@ -72,16 +80,21 @@ def _init_schema(self): profile TEXT, UNIQUE(cache_key, profile) ) - """) + """ + ) - conn.execute(""" + conn.execute( + """ CREATE INDEX IF NOT EXISTS idx_topology_key ON topology_cache(cache_key) - """) + """ + ) conn.commit() - def save_routing_cache(self, cache_data: Dict[str, Any], profile: str = "default") -> int: + def save_routing_cache( + self, cache_data: Dict[str, Any], profile: str = "default" + ) -> int: """Save routing cache to database. Args: @@ -100,29 +113,53 @@ def save_routing_cache(self, cache_data: Dict[str, Any], profile: str = "default routes = data.get("routes", []) for route in routes: - conn.execute(""" + conn.execute( + """ INSERT INTO routing_cache ( source, resource_id, resource_name, region, route_table_id, destination, target, state, type, metadata, profile ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - source, - route.get("vpc_id") or route.get("tgw_id") or route.get("core_network_id"), - route.get("vpc_name") or route.get("tgw_name") or route.get("core_network_name"), - route.get("region"), - route.get("route_table"), - route.get("destination"), - route.get("target"), - route.get("state"), - route.get("type"), - json.dumps({k: v for k, v in route.items() if k not in [ - "vpc_id", "tgw_id", "core_network_id", "vpc_name", "tgw_name", - "core_network_name", "region", "route_table", "destination", - "target", "state", "type", "source" - ]}), - profile - )) + """, + ( + source, + route.get("vpc_id") + or route.get("tgw_id") + or route.get("core_network_id"), + route.get("vpc_name") + or route.get("tgw_name") + or route.get("core_network_name"), + route.get("region"), + route.get("route_table"), + route.get("destination"), + route.get("target"), + route.get("state"), + route.get("type"), + json.dumps( + { + k: v + for k, v in route.items() + if k + not in [ + "vpc_id", + "tgw_id", + "core_network_id", + "vpc_name", + "tgw_name", + "core_network_name", + "region", + "route_table", + "destination", + "target", + "state", + "type", + "source", + ] + } + ), + profile, + ), + ) count += 1 conn.commit() @@ -139,13 +176,20 @@ def load_routing_cache(self, profile: str = "default") -> Dict[str, Any]: """ with sqlite3.connect(self.db_path) as conn: conn.row_factory = sqlite3.Row - cursor = conn.execute(""" + cursor = conn.execute( + """ SELECT * FROM routing_cache WHERE profile = ? ORDER BY source, resource_id, route_table_id - """, (profile,)) - - routes_by_source = {"vpc": {"routes": []}, "tgw": {"routes": []}, "cloudwan": {"routes": []}} + """, + (profile,), + ) + + routes_by_source = { + "vpc": {"routes": []}, + "tgw": {"routes": []}, + "cloudwan": {"routes": []}, + } for row in cursor: route = { @@ -187,13 +231,18 @@ def save_topology_cache(self, cache_key: str, data: Any, profile: str = "default profile: AWS profile name """ with sqlite3.connect(self.db_path) as conn: - conn.execute(""" + conn.execute( + """ INSERT OR REPLACE INTO topology_cache (cache_key, cache_data, profile, cached_at) VALUES (?, ?, ?, CURRENT_TIMESTAMP) - """, (cache_key, json.dumps(data), profile)) + """, + (cache_key, json.dumps(data), profile), + ) conn.commit() - def load_topology_cache(self, cache_key: str, profile: str = "default") -> Optional[Any]: + def load_topology_cache( + self, cache_key: str, profile: str = "default" + ) -> Optional[Any]: """Load topology cache entry. Args: @@ -204,10 +253,13 @@ def load_topology_cache(self, cache_key: str, profile: str = "default") -> Optio Cached data or None if not found/expired """ with sqlite3.connect(self.db_path) as conn: - cursor = conn.execute(""" + cursor = conn.execute( + """ SELECT cache_data, cached_at FROM topology_cache WHERE cache_key = ? AND profile = ? - """, (cache_key, profile)) + """, + (cache_key, profile), + ) row = cursor.fetchone() if not row: @@ -233,13 +285,15 @@ def clear_all(self, profile: Optional[str] = None): def get_stats(self) -> Dict[str, Any]: """Get cache statistics.""" with sqlite3.connect(self.db_path) as conn: - cursor = conn.execute(""" + cursor = conn.execute( + """ SELECT COUNT(*) as total_routes, COUNT(DISTINCT profile) as profiles, COUNT(DISTINCT source) as sources FROM routing_cache - """) + """ + ) routing_stats = dict(cursor.fetchone()) cursor = conn.execute("SELECT COUNT(*) FROM topology_cache") @@ -248,5 +302,7 @@ def get_stats(self) -> Dict[str, Any]: return { "routing_cache": routing_stats, "topology_cache": {"entries": topology_count}, - "db_size_bytes": self.db_path.stat().st_size if self.db_path.exists() else 0 + "db_size_bytes": self.db_path.stat().st_size + if self.db_path.exists() + else 0, } diff --git a/src/aws_network_tools/core/validators.py b/src/aws_network_tools/core/validators.py index 9e0e76f..c184b5e 100644 --- a/src/aws_network_tools/core/validators.py +++ b/src/aws_network_tools/core/validators.py @@ -12,28 +12,50 @@ # Known AWS regions (as of 2025) VALID_AWS_REGIONS = { # US regions - "us-east-1", "us-east-2", "us-west-1", "us-west-2", + "us-east-1", + "us-east-2", + "us-west-1", + "us-west-2", # Europe regions - "eu-west-1", "eu-west-2", "eu-west-3", "eu-central-1", - "eu-central-2", "eu-north-1", "eu-south-1", "eu-south-2", + "eu-west-1", + "eu-west-2", + "eu-west-3", + "eu-central-1", + "eu-central-2", + "eu-north-1", + "eu-south-1", + "eu-south-2", # Asia Pacific - "ap-south-1", "ap-south-2", "ap-northeast-1", "ap-northeast-2", - "ap-northeast-3", "ap-southeast-1", "ap-southeast-2", - "ap-southeast-3", "ap-southeast-4", "ap-east-1", + "ap-south-1", + "ap-south-2", + "ap-northeast-1", + "ap-northeast-2", + "ap-northeast-3", + "ap-southeast-1", + "ap-southeast-2", + "ap-southeast-3", + "ap-southeast-4", + "ap-east-1", # Other regions - "ca-central-1", "ca-west-1", + "ca-central-1", + "ca-west-1", "sa-east-1", - "me-south-1", "me-central-1", + "me-south-1", + "me-central-1", "af-south-1", "il-central-1", # GovCloud - "us-gov-east-1", "us-gov-west-1", + "us-gov-east-1", + "us-gov-west-1", # China (special) - "cn-north-1", "cn-northwest-1", + "cn-north-1", + "cn-northwest-1", } -def validate_regions(region_input: str) -> Tuple[bool, Optional[List[str]], Optional[str]]: +def validate_regions( + region_input: str, +) -> Tuple[bool, Optional[List[str]], Optional[str]]: """Validate region input string. Args: @@ -50,10 +72,14 @@ def validate_regions(region_input: str) -> Tuple[bool, Optional[List[str]], Opti # Check for space-separated (common mistake) if " " in region_input and "," not in region_input: - return False, None, ( - "Regions must be comma-separated, not space-separated.\n" - " ✗ Wrong: eu-west-1 eu-west-2\n" - " ✓ Right: eu-west-1,eu-west-2" + return ( + False, + None, + ( + "Regions must be comma-separated, not space-separated.\n" + " ✗ Wrong: eu-west-1 eu-west-2\n" + " ✓ Right: eu-west-1,eu-west-2" + ), ) # Parse comma-separated regions @@ -77,10 +103,10 @@ def validate_regions(region_input: str) -> Tuple[bool, Optional[List[str]], Opti suggestions = _suggest_regions(invalid_regions) error_msg = f"Invalid region codes: {', '.join(invalid_regions)}\n" if suggestions: - error_msg += f"\nDid you mean?\n" + error_msg += "\nDid you mean?\n" for invalid, suggested in suggestions.items(): error_msg += f" {invalid} → {', '.join(suggested)}\n" - error_msg += f"\nValid examples: us-east-1, eu-west-1, ap-southeast-1" + error_msg += "\nValid examples: us-east-1, eu-west-1, ap-southeast-1" return False, None, error_msg return True, regions, None @@ -130,15 +156,21 @@ def validate_profile(profile_input: str) -> Tuple[bool, Optional[str], Optional[ # Check for invalid characters (AWS profile names are alphanumeric + _-) if not re.match(r"^[a-zA-Z0-9_-]+$", profile): - return False, None, ( - f"Invalid profile name: '{profile}'\n" - "Profile names must contain only letters, numbers, hyphens, and underscores" + return ( + False, + None, + ( + f"Invalid profile name: '{profile}'\n" + "Profile names must contain only letters, numbers, hyphens, and underscores" + ), ) return True, profile, None -def validate_output_format(format_input: str) -> Tuple[bool, Optional[str], Optional[str]]: +def validate_output_format( + format_input: str, +) -> Tuple[bool, Optional[str], Optional[str]]: """Validate output format. Args: @@ -154,9 +186,13 @@ def validate_output_format(format_input: str) -> Tuple[bool, Optional[str], Opti valid_formats = {"table", "json", "yaml"} if fmt not in valid_formats: - return False, None, ( - f"Invalid format: '{fmt}'\n" - f"Valid formats: {', '.join(sorted(valid_formats))}" + return ( + False, + None, + ( + f"Invalid format: '{fmt}'\n" + f"Valid formats: {', '.join(sorted(valid_formats))}" + ), ) return True, fmt, None diff --git a/src/aws_network_tools/modules/anfw.py b/src/aws_network_tools/modules/anfw.py index 033b630..cbd8609 100644 --- a/src/aws_network_tools/modules/anfw.py +++ b/src/aws_network_tools/modules/anfw.py @@ -36,8 +36,16 @@ def context_commands(self) -> Dict[str, List[str]]: def show_commands(self) -> Dict[str, List[str]]: return { None: ["firewalls"], - "aws-network-firewall": ["firewall", "detail", "firewall-rule-groups", "rule-groups", "firewall-policy", "policy", "firewall-networking"], - "rule-group": ["rule-group"] + "aws-network-firewall": [ + "firewall", + "detail", + "firewall-rule-groups", + "rule-groups", + "firewall-policy", + "policy", + "firewall-networking", + ], + "rule-group": ["rule-group"], } def execute(self, shell, command: str, args: str): @@ -193,7 +201,7 @@ def _get_rule_group(self, client, rg_name: str, rg_type: str) -> dict: protocols = match_attrs.get("Protocols", []) source_ports = match_attrs.get("SourcePorts", []) dest_ports = match_attrs.get("DestinationPorts", []) - + rules.append( { "priority": sr.get("Priority", 0), diff --git a/src/aws_network_tools/modules/cloudwan.py b/src/aws_network_tools/modules/cloudwan.py index a468d48..ff6e03a 100644 --- a/src/aws_network_tools/modules/cloudwan.py +++ b/src/aws_network_tools/modules/cloudwan.py @@ -230,6 +230,7 @@ def get_policy_change_events(self, cn_id: str, max_results: int = 50) -> list[di # Sort by created_at, handling mixed datetime/None values from datetime import datetime, timezone + def sort_key(x): val = x.get("created_at") if val is None: @@ -240,6 +241,7 @@ def sort_key(x): return val.replace(tzinfo=timezone.utc) return val return datetime.min.replace(tzinfo=timezone.utc) + return sorted(events, key=sort_key, reverse=True) def get_policy_document( diff --git a/src/aws_network_tools/modules/elb.py b/src/aws_network_tools/modules/elb.py index 0b75e9e..3af5636 100644 --- a/src/aws_network_tools/modules/elb.py +++ b/src/aws_network_tools/modules/elb.py @@ -144,12 +144,13 @@ def get_listeners(self, elb_arn: str, region: str) -> list[dict]: Returns: List of listener dictionaries """ - client = self.session.client('elbv2', region_name=region) + client = self.session.client("elbv2", region_name=region) try: resp = client.describe_listeners(LoadBalancerArn=elb_arn) - return resp.get('Listeners', []) + return resp.get("Listeners", []) except Exception as e: import logging + logging.warning(f"Failed to get listeners for {elb_arn}: {e}") return [] @@ -163,7 +164,7 @@ def get_target_groups(self, elb_arn: str, region: str) -> list[dict]: Returns: List of target group dictionaries """ - client = self.session.client('elbv2', region_name=region) + client = self.session.client("elbv2", region_name=region) try: # Get listeners first to find target groups listeners = self.get_listeners(elb_arn, region) @@ -171,17 +172,18 @@ def get_target_groups(self, elb_arn: str, region: str) -> list[dict]: # Collect unique target group ARNs tg_arns = set() for listener in listeners: - for action in listener.get('DefaultActions', []): - if action.get('TargetGroupArn'): - tg_arns.add(action['TargetGroupArn']) + for action in listener.get("DefaultActions", []): + if action.get("TargetGroupArn"): + tg_arns.add(action["TargetGroupArn"]) # Fetch target group details if tg_arns: resp = client.describe_target_groups(TargetGroupArns=list(tg_arns)) - return resp.get('TargetGroups', []) + return resp.get("TargetGroups", []) return [] except Exception as e: import logging + logging.warning(f"Failed to get target groups for {elb_arn}: {e}") return [] @@ -195,15 +197,16 @@ def get_target_health(self, tg_arns: list[str], region: str) -> dict: Returns: Dict mapping target group ARN to list of health descriptions """ - client = self.session.client('elbv2', region_name=region) + client = self.session.client("elbv2", region_name=region) health_status = {} for tg_arn in tg_arns: try: resp = client.describe_target_health(TargetGroupArn=tg_arn) - health_status[tg_arn] = resp.get('TargetHealthDescriptions', []) + health_status[tg_arn] = resp.get("TargetHealthDescriptions", []) except Exception as e: import logging + logging.warning(f"Failed to get health for {tg_arn}: {e}") health_status[tg_arn] = [] @@ -243,7 +246,7 @@ def get_elb_detail(self, elb_arn: str, region: str) -> dict: if not listeners: # Still continue to check target groups even if no listeners pass - + for listener in listeners: listener_arn = listener["ListenerArn"] original_default_actions = listener.get("DefaultActions", []) @@ -275,19 +278,19 @@ def get_elb_detail(self, elb_arn: str, region: str) -> dict: detail["target_groups"].append(tg_detail) # Aggregate target_health from each target group for target in tg_detail.get("targets", []): - detail["target_health"].append({ - "target_group_arn": act["target_group_arn"], - "target_group_name": tg_detail.get("name"), - **target, - }) + detail["target_health"].append( + { + "target_group_arn": act["target_group_arn"], + "target_group_name": tg_detail.get("name"), + **target, + } + ) listener_data["default_actions"].append(act) # If ALB, get rules - use listener_arn (not shadowed variable) if lb["Type"] == "application": - rules_resp = client.describe_rules( - ListenerArn=listener_arn - ) + rules_resp = client.describe_rules(ListenerArn=listener_arn) for r in rules_resp.get("Rules", []): if r["IsDefault"]: continue # Skip default rule as it's covered in DefaultActions usually @@ -313,11 +316,17 @@ def get_elb_detail(self, elb_arn: str, region: str) -> dict: seen_target_groups.add(act["target_group_arn"]) detail["target_groups"].append(tg_detail) for target in tg_detail.get("targets", []): - detail["target_health"].append({ - "target_group_arn": act["target_group_arn"], - "target_group_name": tg_detail.get("name"), - **target, - }) + detail["target_health"].append( + { + "target_group_arn": act[ + "target_group_arn" + ], + "target_group_name": tg_detail.get( + "name" + ), + **target, + } + ) rule["actions"].append(act) listener_data["rules"].append(rule) @@ -326,13 +335,16 @@ def get_elb_detail(self, elb_arn: str, region: str) -> dict: except Exception as e: # Issue #10: Log error but don't abort - continue to return what we have import logging - logging.getLogger(__name__).warning(f"Error fetching listeners for {elb_arn}: {e}") + + logging.getLogger(__name__).warning( + f"Error fetching listeners for {elb_arn}: {e}" + ) # Issue #10: If no listeners found, still try to get target groups from ARN patterns # Some load balancers have target groups but listeners aren't attached if not detail["listeners"] and detail["target_groups"]: pass # We already have target groups from above or will get them below - + return detail def _get_target_group_detail(self, client, tg_arn: str) -> dict: diff --git a/src/aws_network_tools/modules/vpc.py b/src/aws_network_tools/modules/vpc.py index a798f75..6ae1d1b 100644 --- a/src/aws_network_tools/modules/vpc.py +++ b/src/aws_network_tools/modules/vpc.py @@ -262,22 +262,23 @@ def get_vpc_detail(self, vpc_id: str, region: str) -> dict: sgs = [] for sg in sg_resp.get("SecurityGroups", []): ingress, egress = [], [] - for r in sg.get("IpPermissions", []): - proto = r.get("IpProtocol", "all") + # Process ingress rules + for rule in sg.get("IpPermissions", []): + proto = rule.get("IpProtocol", "all") ports = ( - f"{r.get('FromPort', 'all')}-{r.get('ToPort', 'all')}" - if r.get("FromPort") + f"{rule.get('FromPort', 'all')}-{rule.get('ToPort', 'all')}" + if rule.get("FromPort") else "all" ) - for ip in r.get("IpRanges", []): + for ip_range in rule.get("IpRanges", []): ingress.append( { "protocol": proto, "ports": ports, - "source": ip.get("CidrIp", "N/A"), + "source": ip_range.get("CidrIp", "N/A"), } ) - for grp in r.get("UserIdGroupPairs", []): + for grp in rule.get("UserIdGroupPairs", []): ingress.append( { "protocol": proto, @@ -285,43 +286,43 @@ def get_vpc_detail(self, vpc_id: str, region: str) -> dict: "source": grp.get("GroupId", "N/A"), } ) - if not r.get("IpRanges") and not r.get("UserIdGroupPairs"): + if not rule.get("IpRanges") and not rule.get("UserIdGroupPairs"): ingress.append({"protocol": proto, "ports": ports, "source": "N/A"}) - for r in sg.get("IpPermissionsEgress", []): - proto = r.get("IpProtocol", "all") - ports = ( - f"{r.get('FromPort', 'all')}-{r.get('ToPort', 'all')}" - if r.get("FromPort") - else "all" + # Process egress rules (separate loop, not nested) + for egress_rule in sg.get("IpPermissionsEgress", []): + proto = egress_rule.get("IpProtocol", "all") + ports = ( + f"{egress_rule.get('FromPort', 'all')}-{egress_rule.get('ToPort', 'all')}" + if egress_rule.get("FromPort") + else "all" + ) + for ip_range in egress_rule.get("IpRanges", []): + egress.append( + { + "protocol": proto, + "ports": ports, + "dest": ip_range.get("CidrIp") or "0.0.0.0/0", + } + ) + # IPv6 ranges + for ip6 in egress_rule.get("Ipv6Ranges", []): + egress.append( + { + "protocol": proto, + "ports": ports, + "dest": ip6.get("CidrIpv6") or "::/0", + } + ) + if not egress_rule.get("IpRanges") and not egress_rule.get( + "Ipv6Ranges" + ): + egress.append( + { + "protocol": proto, + "ports": ports, + "dest": "0.0.0.0/0, ::/0", + } ) - for ip in r.get("IpRanges", []): - egress.append( - { - "protocol": proto, - "ports": ports, - "dest": ip.get("CidrIp") - or (".".join(map(str, (0, 0, 0, 0))) + "/0"), - } - ) - # IPv6 ranges - for ip6 in r.get("Ipv6Ranges", []): - egress.append( - { - "protocol": proto, - "ports": ports, - "dest": ip6.get("CidrIpv6") or ("::" + "/0"), - } - ) - if not r.get("IpRanges") and not r.get("Ipv6Ranges"): - egress.append( - { - "protocol": proto, - "ports": ports, - "dest": (".".join(map(str, (0, 0, 0, 0))) + "/0") - + ", " - + ("::" + "/0"), - } - ) sgs.append( { "id": sg["GroupId"], diff --git a/src/aws_network_tools/shell/base.py b/src/aws_network_tools/shell/base.py index 9e48245..788a858 100644 --- a/src/aws_network_tools/shell/base.py +++ b/src/aws_network_tools/shell/base.py @@ -5,7 +5,7 @@ from rich.console import Console from rich.text import Text from dataclasses import dataclass, field -from ..themes import load_theme, get_theme_dir +from ..themes import load_theme from ..config import get_config, RuntimeConfig console = Console() @@ -112,12 +112,27 @@ "rib", ], "set": ["route-table"], - "commands": ["show", "set", "find_prefix", "find_null_routes", "refresh", "exit", "end"], + "commands": [ + "show", + "set", + "find_prefix", + "find_null_routes", + "refresh", + "exit", + "end", + ], }, "route-table": { "show": ["routes"], "set": [], - "commands": ["show", "find_prefix", "find_null_routes", "refresh", "exit", "end"], + "commands": [ + "show", + "find_prefix", + "find_null_routes", + "refresh", + "exit", + "end", + ], }, "vpc": { "show": [ @@ -131,15 +146,39 @@ "endpoints", ], "set": ["route-table"], - "commands": ["show", "set", "find_prefix", "find_null_routes", "refresh", "exit", "end"], + "commands": [ + "show", + "set", + "find_prefix", + "find_null_routes", + "refresh", + "exit", + "end", + ], }, "transit-gateway": { "show": ["detail", "route-tables", "attachments"], "set": ["route-table"], - "commands": ["show", "set", "find_prefix", "find_null_routes", "refresh", "exit", "end"], + "commands": [ + "show", + "set", + "find_prefix", + "find_null_routes", + "refresh", + "exit", + "end", + ], }, "firewall": { - "show": ["firewall", "detail", "firewall-rule-groups", "rule-groups", "policy", "firewall-policy", "firewall-networking"], + "show": [ + "firewall", + "detail", + "firewall-rule-groups", + "rule-groups", + "policy", + "firewall-policy", + "firewall-networking", + ], "set": ["rule-group"], "commands": ["show", "set", "refresh", "exit", "end"], }, @@ -191,12 +230,12 @@ def __init__(self): self.watch_interval: int = 0 self.context_stack: list[Context] = [] self._cache: dict = {} - + # Load theme and config self.config = get_config() theme_name = self.config.get_theme_name() self.theme = load_theme(theme_name) - + # Initialize RuntimeConfig with shell defaults RuntimeConfig.set_profile(self.profile) RuntimeConfig.set_regions(self.regions) @@ -224,7 +263,7 @@ def __init__(self): ] ) self._update_prompt() - + def _sync_runtime_config(self): """Synchronize shell state with RuntimeConfig singleton.""" RuntimeConfig.set_profile(self.profile) @@ -253,18 +292,18 @@ def _update_prompt(self): if not self.context_stack: self.prompt = "aws-net> " return - + # Get prompt configuration style = self.config.get_prompt_style() # "short" or "long" show_indices = self.config.show_indices() max_length = self.config.get_max_length() - + prompt_parts = [] - + for i, ctx in enumerate(self.context_stack): # Get color for this context type color = self.theme.get(ctx.type, "white") - + # Get abbreviation for context type if ctx.type == "global-network": abbrev = "gl" @@ -276,7 +315,7 @@ def _update_prompt(self): abbrev = "ec" else: abbrev = ctx.type[:2] - + if style == "short": # Short format: use index number like gl:1, cn:1 ctx_name = f"{abbrev}:{ctx.selection_index}" @@ -291,35 +330,34 @@ def _update_prompt(self): # Long format without indices: gl:name display_name = ctx.name or ctx.ref if len(display_name) > max_length: - display_name = display_name[:max_length-3] + "..." + display_name = display_name[: max_length - 3] + "..." ctx_name = f"{abbrev}:{display_name}" - + # Create colored text part (no newlines embedded) - from rich.text import Text colored_part = Text(f"{ctx_name}", style=color) prompt_parts.append(colored_part) - + # Create the full prompt if style == "long": # Multi-line prompt with continuation markers prompt_text = Text("aws-net> ", style=self.theme.get("prompt_text")) separator_style = self.theme.get("prompt_separator") - + if prompt_parts: # First context on same line as aws-net> prompt_text.append(prompt_parts[0]) - + if len(prompt_parts) > 1: # Multiple contexts - use multi-line format prompt_text.append(Text(" >\n", style=separator_style)) - + # Middle contexts (all except first and last) for i, part in enumerate(prompt_parts[1:-1], 1): indent = " " * i # Two spaces per level prompt_text.append(Text(f"{indent}", style=separator_style)) prompt_text.append(part) prompt_text.append(Text(" >\n", style=separator_style)) - + # Last context last_idx = len(prompt_parts) - 1 indent = " " * last_idx @@ -338,20 +376,30 @@ def _update_prompt(self): prompt_text = Text("aws-net", style=self.theme.get("prompt_text")) if prompt_parts: for part in prompt_parts: - prompt_text.append(f">", style=separator_color) + prompt_text.append(">", style=separator_color) prompt_text.append(part) prompt_text.append("> ", style=separator_color) - + # Render Text to ANSI codes for cmd2 from rich.console import Console + render_console = Console(force_terminal=True, color_system="standard") with render_console.capture() as capture: render_console.print(prompt_text, end="") self.prompt = capture.get() - def _enter(self, ctx_type: str, res_id: str, name: str, data: dict = None, selection_index: int = 1): + def _enter( + self, + ctx_type: str, + res_id: str, + name: str, + data: dict = None, + selection_index: int = 1, + ): """Enter a new context.""" - self.context_stack.append(Context(ctx_type, res_id, name, data or {}, selection_index)) + self.context_stack.append( + Context(ctx_type, res_id, name, data or {}, selection_index) + ) self._update_prompt() def _resolve(self, items: list, val: str) -> Optional[dict]: @@ -421,14 +469,14 @@ def do_clear_cache(self, _): def do_refresh(self, args): """Refresh cached data. Usage: refresh [target|all] - + Examples: refresh - Refresh current context data refresh transit_gateways - Clear transit gateways cache refresh all - Clear all caches """ target = str(args).strip().lower().replace("-", "_") if args else "" - + if not target or target == "current": # Refresh current context by clearing relevant cache keys context_to_cache = { @@ -442,24 +490,24 @@ def do_refresh(self, args): "core-network": "core_networks", "route-table": "core_networks", # Route tables belong to core networks } - + cache_key = context_to_cache.get(self.ctx_type) if not cache_key: console.print("[yellow]No cache to refresh in current context[/]") return - + if cache_key in self._cache: del self._cache[cache_key] console.print(f"[green]Refreshed {cache_key} cache[/]") else: console.print(f"[yellow]No cached data for {cache_key}[/]") - + elif target == "all": # Clear entire cache count = len(self._cache) self._cache.clear() console.print(f"[green]Cleared {count} cache entries[/]") - + else: # Clear specific cache key with alias support cache_aliases = { @@ -482,7 +530,6 @@ def do_refresh(self, args): "client_vpn_endpoints": "client_vpn_endpoints", "global_accelerators": "global_accelerators", "vpc_endpoints": "vpc_endpoints", - # Singular aliases "vpc": "vpcs", "transit_gateway": "transit_gateways", @@ -501,7 +548,6 @@ def do_refresh(self, args): "client_vpn_endpoint": "client_vpn_endpoints", "global_accelerator": "global_accelerators", "vpc_endpoint": "vpc_endpoints", - # Common abbreviations "tgw": "transit_gateways", "tgws": "transit_gateways", @@ -512,9 +558,9 @@ def do_refresh(self, args): "ga": "global_accelerators", "ec2": "ec2_instances", } - + cache_key = cache_aliases.get(target, target) - + if cache_key in self._cache: del self._cache[cache_key] console.print(f"[green]Refreshed {cache_key} cache[/]") diff --git a/src/aws_network_tools/shell/graph.py b/src/aws_network_tools/shell/graph.py index 984f520..228e27c 100644 --- a/src/aws_network_tools/shell/graph.py +++ b/src/aws_network_tools/shell/graph.py @@ -603,7 +603,7 @@ def stats(self) -> dict: def find_command_path(self, command: str) -> list[dict]: """Find all paths to reach a command. - + Returns list of dicts with: - path: list of commands to reach the target - context: the context where command is available @@ -611,44 +611,48 @@ def find_command_path(self, command: str) -> list[dict]: """ results = [] command_lower = command.lower().strip() - + # Search all nodes for matching commands for node_id, node in self.nodes.items(): node_name_lower = node.name.lower() - + # Match by full name or partial if command_lower == node_name_lower or command_lower in node_name_lower: path_info = self._build_path_to_node(node) if path_info: results.append(path_info) - + return results def _build_path_to_node(self, target_node: CommandNode) -> Optional[dict]: """Build the path from root to a target node.""" if target_node.node_type == NodeType.ROOT: return None - + # Find path by traversing from root path = [] - + def find_path(node: CommandNode, current_path: list) -> bool: if node.id == target_node.id: path.extend(current_path + [node.name]) return True for child in node.children: - if find_path(child, current_path + ([node.name] if node.node_type != NodeType.ROOT else [])): + if find_path( + child, + current_path + + ([node.name] if node.node_type != NodeType.ROOT else []), + ): return True return False - + find_path(self.root, []) - + if not path: return None - + # Determine if global (root level) or requires context navigation is_global = target_node.context is None - + # Build prerequisite show command for context-entering sets prereq_show = None if not is_global and len(path) > 1: @@ -659,7 +663,7 @@ def find_path(node: CommandNode, current_path: list) -> bool: # Map to corresponding show command show_map = { "vpc": "show vpcs", - "transit-gateway": "show transit_gateways", + "transit-gateway": "show transit_gateways", "global-network": "show global-networks", "core-network": "show core-networks", "firewall": "show firewalls", @@ -669,7 +673,7 @@ def find_path(node: CommandNode, current_path: list) -> bool: "route-table": "show route-tables", } prereq_show = show_map.get(resource) - + return { "command": target_node.name, "path": path, diff --git a/src/aws_network_tools/shell/handlers/cloudwan.py b/src/aws_network_tools/shell/handlers/cloudwan.py index 894a548..045d137 100644 --- a/src/aws_network_tools/shell/handlers/cloudwan.py +++ b/src/aws_network_tools/shell/handlers/cloudwan.py @@ -154,10 +154,11 @@ def do_show(self, args): console.print( "[red]Usage: show policy document-diff [/]" ) - console.print("[dim]Use 'show policy-documents' to see available versions[/]") + console.print( + "[dim]Use 'show policy-documents' to see available versions[/]" + ) return # Fall back to default handler - from ...shell.main import AWSNetShell super(CloudWANHandlersMixin, self).do_show(args) def _show_policy_document_diff(self, version1: str, version2: str): @@ -177,11 +178,8 @@ def _show_policy_document_diff(self, version1: str, version2: str): def fetch_doc(version): try: - resp = ( - cloudwan.CloudWANClient(self.profile) - .nm.get_core_network_policy( - CoreNetworkId=cn_id, PolicyVersionId=version - ) + resp = cloudwan.CloudWANClient(self.profile).nm.get_core_network_policy( + CoreNetworkId=cn_id, PolicyVersionId=version ) return json.loads(resp["CoreNetworkPolicy"]["PolicyDocument"]) except Exception as e: @@ -207,7 +205,11 @@ def fetch_doc(version): diff = list( difflib.unified_diff( - doc1_str, doc2_str, lineterm="", fromfile=f"Version {v1}", tofile=f"Version {v2}" + doc1_str, + doc2_str, + lineterm="", + fromfile=f"Version {v1}", + tofile=f"Version {v2}", ) ) @@ -407,7 +409,9 @@ def _show_routes(self, _): ) console.print(table) console.print() - console.print(f"[dim]Total: {total_routes} routes across {len(rts)} route tables[/]") + console.print( + f"[dim]Total: {total_routes} routes across {len(rts)} route tables[/]" + ) else: console.print("[red]Must be in core-network or route-table context[/]") diff --git a/src/aws_network_tools/shell/handlers/ec2.py b/src/aws_network_tools/shell/handlers/ec2.py index 8b9ef4c..e78ee8a 100644 --- a/src/aws_network_tools/shell/handlers/ec2.py +++ b/src/aws_network_tools/shell/handlers/ec2.py @@ -54,13 +54,17 @@ def _set_ec2_instance(self, val): instances = self._cache.get(key, []) # If cache empty AND val looks like instance ID, fetch directly - if not instances and val.startswith('i-'): + if not instances and val.startswith("i-"): from ...modules.ec2 import EC2Client from ...core import run_with_spinner # Try to fetch this specific instance from all regions target_instance = None - regions_to_try = self.regions if self.regions else ['us-east-1', 'eu-west-1', 'ap-southeast-2'] + regions_to_try = ( + self.regions + if self.regions + else ["us-east-1", "eu-west-1", "ap-southeast-2"] + ) for region in regions_to_try: try: @@ -69,10 +73,10 @@ def _set_ec2_instance(self, val): target_instance = { "id": val, "name": detail.get("name", val), - "region": region + "region": region, } break - except: + except Exception: continue if not target_instance: @@ -89,9 +93,11 @@ def _set_ec2_instance(self, val): ) self._enter( - "ec2-instance", target_instance["id"], + "ec2-instance", + target_instance["id"], target_instance.get("name") or target_instance["id"], - detail, 1 + detail, + 1, ) print() return @@ -120,7 +126,11 @@ def _set_ec2_instance(self, val): except ValueError: selection_idx = 1 self._enter( - "ec2-instance", target["id"], target.get("name") or target["id"], detail, selection_idx + "ec2-instance", + target["id"], + target.get("name") or target["id"], + detail, + selection_idx, ) print() # Add blank line before next prompt diff --git a/src/aws_network_tools/shell/handlers/firewall.py b/src/aws_network_tools/shell/handlers/firewall.py index 5a5c1a5..2a7c287 100644 --- a/src/aws_network_tools/shell/handlers/firewall.py +++ b/src/aws_network_tools/shell/handlers/firewall.py @@ -25,7 +25,9 @@ def _set_firewall(self, val): selection_idx = int(val) except ValueError: selection_idx = 1 - self._enter("firewall", fw.get("arn", ""), fw.get("name", ""), fw, selection_idx) + self._enter( + "firewall", fw.get("arn", ""), fw.get("name", ""), fw, selection_idx + ) print() # Add blank line before next prompt def _show_firewall(self, _): @@ -33,6 +35,7 @@ def _show_firewall(self, _): if self.ctx_type != "firewall": return from ...modules.anfw import ANFWDisplay + ANFWDisplay(console).show_firewall_detail(self.ctx.data) # Alias for backward compatibility @@ -60,7 +63,7 @@ def _show_firewall_rule_groups(self, _): rg.get("name", ""), rg.get("type", ""), str(len(rg.get("rules", []))), - f"{rg.get('consumed_capacity', 0)}/{rg.get('capacity', 0)}" + f"{rg.get('consumed_capacity', 0)}/{rg.get('capacity', 0)}", ) console.print(table) console.print("[dim]Use 'set rule-group <#>' to select[/]") @@ -78,61 +81,59 @@ def _set_rule_group(self, val): if not val: console.print("[red]Usage: set rule-group <#>[/]") return - + rgs = self.ctx.data.get("rule_groups", []) if not rgs: console.print("[yellow]No rule groups available[/]") return - + # Resolve by index or name rg = self._resolve(rgs, val) if not rg: console.print(f"[red]Rule group not found: {val}[/]") return - + try: selection_idx = int(val) except ValueError: selection_idx = 1 - - self._enter("rule-group", rg.get("name", ""), rg.get("name", ""), rg, selection_idx) + + self._enter( + "rule-group", rg.get("name", ""), rg.get("name", ""), rg, selection_idx + ) print() def _show_rule_group(self, _): """Show detailed rule group information.""" if self.ctx_type != "rule-group": return - + rg = self.ctx.data from rich.panel import Panel - from rich.text import Text - + console.print( - Panel( - f"[bold]{rg['name']}[/] ({rg['type']})", - title="Rule Group" - ) + Panel(f"[bold]{rg['name']}[/] ({rg['type']})", title="Rule Group") ) - + if rg.get("error"): console.print(f"[red]Error: {rg['error']}[/]") return - + cap_info = f"[dim]Capacity: {rg.get('consumed_capacity', 0)}/{rg.get('capacity', 0)}[/]" console.print(cap_info) console.print() - + if rg["type"] == "STATEFUL": # Stateful rules (Suricata format, domain lists, or 5-tuple) rules = rg.get("rules", []) if not rules: console.print("[dim]No rules found[/]") return - + table = Table(show_header=True, header_style="bold") table.add_column("#", style="dim", justify="right") table.add_column("Rule", style="cyan") - + for i, rule in enumerate(rules, 1): if "rule" in rule: console.print(f" [dim]{i}.[/] [cyan]{rule['rule']}[/]") @@ -142,7 +143,7 @@ def _show_rule_group(self, _): if not rules: console.print("[dim]No rules found[/]") return - + table = Table(show_header=True, header_style="bold") table.add_column("#", style="dim", justify="right") table.add_column("Priority", style="yellow", justify="right") @@ -152,24 +153,38 @@ def _show_rule_group(self, _): table.add_column("Protocols", style="white") table.add_column("Source Ports", style="dim") table.add_column("Dest Ports", style="dim") - + for i, rule in enumerate(rules, 1): # Format source/dest ports src_ports = rule.get("source_ports", []) dst_ports = rule.get("dest_ports", []) - - src_port_str = ", ".join( - f"{p.get('FromPort', '')}-{p.get('ToPort', '')}" if p.get('FromPort') != p.get('ToPort') - else str(p.get('FromPort', '')) - for p in src_ports - ) if src_ports else "Any" - - dst_port_str = ", ".join( - f"{p.get('FromPort', '')}-{p.get('ToPort', '')}" if p.get('FromPort') != p.get('ToPort') - else str(p.get('FromPort', '')) - for p in dst_ports - ) if dst_ports else "Any" - + + src_port_str = ( + ", ".join( + ( + f"{p.get('FromPort', '')}-{p.get('ToPort', '')}" + if p.get("FromPort") != p.get("ToPort") + else str(p.get("FromPort", "")) + ) + for p in src_ports + ) + if src_ports + else "Any" + ) + + dst_port_str = ( + ", ".join( + ( + f"{p.get('FromPort', '')}-{p.get('ToPort', '')}" + if p.get("FromPort") != p.get("ToPort") + else str(p.get("FromPort", "")) + ) + for p in dst_ports + ) + if dst_ports + else "Any" + ) + table.add_row( str(i), str(rule.get("priority", "")), @@ -178,25 +193,25 @@ def _show_rule_group(self, _): ", ".join(rule.get("destinations", [])) or "Any", ", ".join(str(p) for p in rule.get("protocols", [])) or "Any", src_port_str, - dst_port_str + dst_port_str, ) - + console.print(table) def _show_policy(self, _): """Show firewall policy with rule groups summary.""" if self.ctx_type != "firewall": return - + policy = self.ctx.data.get("policy", {}) if not policy: console.print("[yellow]No policy data available[/]") return - + from rich.panel import Panel - + console.print(Panel(f"[bold]{policy.get('name', 'N/A')}[/]", title="Policy")) - + # Show rule groups in table format rgs = self.ctx.data.get("rule_groups", []) if rgs: @@ -212,6 +227,6 @@ def _show_policy(self, _): rg["name"], rg["type"], str(len(rg.get("rules", []))), - f"{rg.get('consumed_capacity', 0)}/{rg.get('capacity', 0)}" + f"{rg.get('consumed_capacity', 0)}/{rg.get('capacity', 0)}", ) console.print(table) diff --git a/src/aws_network_tools/shell/handlers/root.py b/src/aws_network_tools/shell/handlers/root.py index b68d2de..ebd9007 100644 --- a/src/aws_network_tools/shell/handlers/root.py +++ b/src/aws_network_tools/shell/handlers/root.py @@ -42,40 +42,46 @@ def _show_regions(self, _): """Show current region scope and available AWS regions.""" from ...core.validators import VALID_AWS_REGIONS from ...core import run_with_spinner - + # Show current scope if self.regions: console.print(f"[bold]Current Scope:[/] {', '.join(self.regions)}") - console.print(f"[dim]Discovery limited to {len(self.regions)} region(s)[/]\n") + console.print( + f"[dim]Discovery limited to {len(self.regions)} region(s)[/]\n" + ) else: console.print("[bold]Current Scope:[/] all regions") console.print("[dim]Discovery will scan all enabled regions[/]\n") - + # Try to fetch actual enabled regions from AWS account enabled_regions = None try: + def fetch_regions(): import boto3 + if self.profile: session = boto3.Session(profile_name=self.profile) else: session = boto3.Session() - ec2 = session.client('ec2', region_name='us-east-1') - response = ec2.describe_regions(AllRegions=False) # Only opted-in regions - return [r['RegionName'] for r in response['Regions']] - + ec2 = session.client("ec2", region_name="us-east-1") + response = ec2.describe_regions( + AllRegions=False + ) # Only opted-in regions + return [r["RegionName"] for r in response["Regions"]] + enabled_regions = run_with_spinner( fetch_regions, "Fetching enabled regions from AWS account", - console=console + console=console, ) except Exception as e: console.print(f"[yellow]Could not fetch enabled regions: {e}[/]") console.print("[dim]Showing all known AWS regions instead[/]\n") - + # Use enabled regions if available, otherwise fall back to static list regions_to_show = set(enabled_regions) if enabled_regions else VALID_AWS_REGIONS - + # Show available AWS regions grouped by area region_groups = { "US": [], @@ -83,7 +89,7 @@ def fetch_regions(): "Asia Pacific": [], "Other": [], } - + for region in sorted(regions_to_show): if region.startswith("us-"): region_groups["US"].append(region) @@ -95,28 +101,28 @@ def fetch_regions(): continue # Skip China regions else: region_groups["Other"].append(region) - + console.print("[bold]Available Regions:[/]") if enabled_regions: console.print("[dim]Showing only regions enabled in your AWS account[/]\n") else: console.print("[dim]Showing all known AWS regions[/]\n") - + for group_name, regions in region_groups.items(): if regions: console.print(f"[cyan]{group_name}:[/]") # Display in rows of 4 for i in range(0, len(regions), 4): - chunk = regions[i:i+4] + chunk = regions[i : i + 4] line = " " + " ".join(f"{r:20}" for r in chunk) console.print(line.rstrip()) console.print() # Blank line between groups - - console.print("[dim]Usage: set regions or set regions all[/]") - def _show_cache(self, _): - from datetime import datetime, timezone + console.print( + "[dim]Usage: set regions or set regions all[/]" + ) + def _show_cache(self, _): table = Table(title="Cache Status") table.add_column("Cache") table.add_column("Entries") @@ -184,7 +190,9 @@ def _show_vpcs(self, _): from ...modules import vpc vpcs = self._cached( - "vpcs", lambda: vpc.VPCClient(self.profile).discover(self.regions), "Fetching VPCs" + "vpcs", + lambda: vpc.VPCClient(self.profile).discover(self.regions), + "Fetching VPCs", ) if not vpcs: console.print("[yellow]No VPCs found[/]") @@ -270,13 +278,16 @@ def _show_enis(self, arg): # EC2HandlersMixin._show_enis shows instance-specific ENIs from ctx.data if self.ctx_type == "ec2-instance": from .ec2 import EC2HandlersMixin + EC2HandlersMixin._show_enis(self, arg) return from ...modules import eni enis_list = self._cached( - "enis", lambda: eni.ENIClient(self.profile).discover(self.regions), "Fetching ENIs" + "enis", + lambda: eni.ENIClient(self.profile).discover(self.regions), + "Fetching ENIs", ) eni.ENIDisplay(console).show_list(enis_list) @@ -296,6 +307,7 @@ def _show_security_groups(self, arg): # VPCHandlersMixin._show_security_groups shows context-specific SGs from ctx.data if self.ctx_type in ("vpc", "ec2-instance"): from .vpc import VPCHandlersMixin + VPCHandlersMixin._show_security_groups(self, arg) return @@ -331,7 +343,9 @@ def _show_resolver_endpoints(self, _): data = self._cached( "route53_resolver", - lambda: route53_resolver.Route53ResolverClient(self.profile).discover(self.regions), + lambda: route53_resolver.Route53ResolverClient(self.profile).discover( + self.regions + ), "Fetching Route 53 Resolver", ) route53_resolver.Route53ResolverDisplay(console).show_endpoints(data) @@ -341,7 +355,9 @@ def _show_resolver_rules(self, _): data = self._cached( "route53_resolver", - lambda: route53_resolver.Route53ResolverClient(self.profile).discover(self.regions), + lambda: route53_resolver.Route53ResolverClient(self.profile).discover( + self.regions + ), "Fetching Route 53 Resolver", ) route53_resolver.Route53ResolverDisplay(console).show_rules(data) @@ -351,7 +367,9 @@ def _show_query_logs(self, _): data = self._cached( "route53_resolver", - lambda: route53_resolver.Route53ResolverClient(self.profile).discover(self.regions), + lambda: route53_resolver.Route53ResolverClient(self.profile).discover( + self.regions + ), "Fetching Route 53 Resolver", ) route53_resolver.Route53ResolverDisplay(console).show_query_logs(data) @@ -381,7 +399,9 @@ def _show_network_alarms(self, _): data = self._cached( "network_alarms", - lambda: network_alarms.NetworkAlarmsClient(self.profile).discover(self.regions), + lambda: network_alarms.NetworkAlarmsClient(self.profile).discover( + self.regions + ), "Fetching network alarms", ) network_alarms.NetworkAlarmsDisplay(console).show_alarms(data) @@ -391,7 +411,9 @@ def _show_alarms_critical(self, _): data = self._cached( "network_alarms", - lambda: network_alarms.NetworkAlarmsClient(self.profile).discover(self.regions), + lambda: network_alarms.NetworkAlarmsClient(self.profile).discover( + self.regions + ), "Fetching network alarms", ) network_alarms.NetworkAlarmsDisplay(console).show_alarms( @@ -413,7 +435,9 @@ def _show_global_accelerators(self, _): data = self._cached( "global_accelerators", - lambda: global_accelerator.GlobalAcceleratorClient(self.profile).discover(self.regions), + lambda: global_accelerator.GlobalAcceleratorClient(self.profile).discover( + self.regions + ), "Fetching Global Accelerators", ) global_accelerator.GlobalAcceleratorDisplay(console).show_accelerators(data) @@ -423,7 +447,9 @@ def _show_ga_endpoint_health(self, _): data = self._cached( "global_accelerators", - lambda: global_accelerator.GlobalAcceleratorClient(self.profile).discover(self.regions), + lambda: global_accelerator.GlobalAcceleratorClient(self.profile).discover( + self.regions + ), "Fetching Global Accelerators", ) global_accelerator.GlobalAcceleratorDisplay(console).show_endpoint_health(data) @@ -451,45 +477,49 @@ def _show_vpc_endpoints(self, _): # Root set handlers def _set_profile(self, val): from ...core.validators import validate_profile - + is_valid, profile, error = validate_profile(val) if not is_valid: console.print(f"[red]{error}[/]") return - + old_profile = self.profile self.profile = profile console.print(f"[green]Profile: {self.profile or '(default)'}[/]") self._sync_runtime_config() - + # Auto-clear cache when profile changes if old_profile != self.profile: count = len(self._cache) if count > 0: self._cache.clear() - console.print(f"[dim]Cleared {count} cache entries (profile changed)[/]") + console.print( + f"[dim]Cleared {count} cache entries (profile changed)[/]" + ) def _set_regions(self, val): from ...core.validators import validate_regions - + is_valid, regions, error = validate_regions(val) if not is_valid: console.print(f"[red]{error}[/]") return - + old_regions = self.regions.copy() self.regions = regions if regions else [] console.print( f"[green]Regions: {', '.join(self.regions) if self.regions else 'all'}[/]" ) self._sync_runtime_config() - + # Auto-clear cache when regions change if old_regions != self.regions: count = len(self._cache) if count > 0: self._cache.clear() - console.print(f"[dim]Cleared {count} cache entries (regions changed)[/]") + console.print( + f"[dim]Cleared {count} cache entries (regions changed)[/]" + ) def _set_no_cache(self, val): self.no_cache = val and val.lower() in ("on", "true", "1", "yes") @@ -498,12 +528,12 @@ def _set_no_cache(self, val): def _set_output_format(self, val): from ...core.validators import validate_output_format - + is_valid, fmt, error = validate_output_format(val) if not is_valid: console.print(f"[red]{error}[/]") return - + self.output_format = fmt console.print(f"[green]Output-format: {fmt}[/]") self._sync_runtime_config() @@ -534,14 +564,14 @@ def _set_watch(self, val): def _set_theme(self, theme_name): """Set color theme (dracula, catppuccin, or custom).""" + from ..themes import load_theme, get_theme_dir + if not theme_name: - console.print(f"[red]Usage: set theme [/]") - console.print(f"[dim]Available themes: dracula, catppuccin[/]") + console.print("[red]Usage: set theme [/]") + console.print("[dim]Available themes: dracula, catppuccin[/]") console.print(f"[dim]Custom themes in: {get_theme_dir()}[/]") return - - from ..themes import load_theme - + try: self.theme = load_theme(theme_name) self.config.set("prompt.theme", theme_name) @@ -560,7 +590,7 @@ def _set_prompt(self, style): console.print("[dim] short: Compact prompt with indices (gl:1>co:1>)[/]") console.print("[dim] long: Multi-line with full names[/]") return - + self.config.set("prompt.style", style) self.config.save() console.print(f"[green]Prompt style set to: {style}[/]") @@ -588,7 +618,7 @@ def _set_global_network(self, val): # Routing cache commands def complete_routing_cache(self, text, line, begidx, endidx): """Tab completion for routing-cache arguments.""" - return ['vpc', 'transit-gateway', 'cloud-wan', 'all'] + return ["vpc", "transit-gateway", "cloud-wan", "all"] def _show_routing_cache(self, arg): """Show routing cache status or detailed routes. @@ -623,13 +653,19 @@ def _show_routing_cache(self, arg): for source, data in cache.items(): routes = data.get("routes", []) regions = set(r.get("region", "?") for r in routes) - source_display = source.replace("tgw", "Transit Gateway").replace("cloudwan", "Cloud WAN").upper() + source_display = ( + source.replace("tgw", "Transit Gateway") + .replace("cloudwan", "Cloud WAN") + .upper() + ) table.add_row(source_display, str(len(routes)), ", ".join(sorted(regions))) console.print(table) total = sum(len(d.get("routes", [])) for d in cache.values()) console.print(f"\n[bold]Total routes cached:[/] {total}") - console.print("\n[dim]Use 'show routing-cache vpc|transit-gateway|cloud-wan|all' for details[/]") + console.print( + "\n[dim]Use 'show routing-cache vpc|transit-gateway|cloud-wan|all' for details[/]" + ) def _show_routing_cache_detail(self, cache, filter_source): """Show detailed routing cache entries.""" @@ -666,7 +702,9 @@ def _show_routing_cache_detail(self, cache, filter_source): if vpc_routes and (filter_source in ["vpc", "all"]): self._show_vpc_routes_table(vpc_routes) - if tgw_routes and (filter_source in ["transit-gateway", "transitgateway", "all"]): + if tgw_routes and ( + filter_source in ["transit-gateway", "transitgateway", "all"] + ): self._show_transit_gateway_routes_table(tgw_routes) if cloudwan_routes and (filter_source in ["cloud-wan", "cloudwan", "all"]): @@ -684,14 +722,16 @@ def _show_vpc_routes_table(self, routes): title=f"VPC Routes ({len(routes)} total)", show_header=True, header_style="bold cyan", - expand=True + expand=True, ) # Balanced column widths (no_wrap + ratio control) table.add_column("VPC Name", style="cyan", no_wrap=not allow_truncate, ratio=2) table.add_column("VPC ID", style="dim", no_wrap=not allow_truncate, ratio=2) table.add_column("Region", style="blue", no_wrap=True, ratio=2) - table.add_column("Route Table", style="yellow", no_wrap=not allow_truncate, ratio=2) + table.add_column( + "Route Table", style="yellow", no_wrap=not allow_truncate, ratio=2 + ) table.add_column("Destination", style="green", no_wrap=True, ratio=2) table.add_column("Target", style="magenta", no_wrap=not allow_truncate, ratio=3) table.add_column("State", style="bold green", no_wrap=True, ratio=1) @@ -704,13 +744,15 @@ def _show_vpc_routes_table(self, routes): r.get("route_table") or "", r.get("destination") or "", r.get("target") or "", - r.get("state") or "" + r.get("state") or "", ) console.print(table) if len(routes) > display_limit: - console.print(f"[dim]Showing first {display_limit} of {len(routes)} routes[/]") - console.print(f"[dim]Set 'pager: true' in config to enable pagination[/]") + console.print( + f"[dim]Showing first {display_limit} of {len(routes)} routes[/]" + ) + console.print("[dim]Set 'pager: true' in config to enable pagination[/]") def _show_transit_gateway_routes_table(self, routes): """Display Transit Gateway routes in detailed table.""" @@ -721,7 +763,7 @@ def _show_transit_gateway_routes_table(self, routes): title=f"Transit Gateway Routes ({len(routes)} total)", show_header=True, header_style="bold cyan", - expand=True # Use full terminal width + expand=True, # Use full terminal width ) # Add columns with proper styling and no width limits @@ -743,7 +785,7 @@ def _show_transit_gateway_routes_table(self, routes): r.get("destination") or "", r.get("target") or "", r.get("state") or "", - r.get("type") or "" + r.get("type") or "", ) console.print(table) @@ -758,7 +800,7 @@ def _show_cloud_wan_routes_table(self, routes): title=f"Cloud WAN Routes ({len(routes)} total)", show_header=True, header_style="bold cyan", - expand=True + expand=True, ) table.add_column("Core Network", style="cyan", no_wrap=not allow_truncate) @@ -779,7 +821,7 @@ def _show_cloud_wan_routes_table(self, routes): r.get("region") or "", r.get("destination") or "", r.get("target") or "", - r.get("state") or "" + r.get("state") or "", ) console.print(table) @@ -853,9 +895,9 @@ def fetch_tgw_routes(): "region": tgw_region, "route_table": rt_id, "destination": r.get("prefix", ""), # Lowercase - "target": r.get("target", ""), # Already lowercase - "state": r.get("state", ""), # Already lowercase - "type": r.get("type", ""), # Already lowercase + "target": r.get("target", ""), # Already lowercase + "state": r.get("state", ""), # Already lowercase + "type": r.get("type", ""), # Already lowercase } ) return routes @@ -916,9 +958,12 @@ def fetch_cloudwan_routes(): if self.config.get("cache.use_local_cache", False): try: from ...core.cache_db import CacheDB + db = CacheDB() saved_count = db.save_routing_cache(cache, self.profile or "default") - console.print(f"[dim] → Saved {saved_count} routes to local database[/]") + console.print( + f"[dim] → Saved {saved_count} routes to local database[/]" + ) except Exception as e: console.print(f"[yellow] → Local DB save failed: {e}[/]") total = sum(len(d.get("routes", [])) for d in cache.values()) @@ -928,11 +973,14 @@ def do_save_routing_cache(self, _): """Save routing cache to local SQLite database.""" cache = self._cache.get("routing-cache", {}) if not cache: - console.print("[yellow]No routing cache to save. Run 'create_routing_cache' first.[/]") + console.print( + "[yellow]No routing cache to save. Run 'create_routing_cache' first.[/]" + ) return try: from ...core.cache_db import CacheDB + db = CacheDB() saved_count = db.save_routing_cache(cache, self.profile or "default") console.print(f"[green]✓ Saved {saved_count} routes to local database[/]") @@ -944,6 +992,7 @@ def do_load_routing_cache(self, _): """Load routing cache from local SQLite database.""" try: from ...core.cache_db import CacheDB + db = CacheDB() cache = db.load_routing_cache(self.profile or "default") @@ -959,7 +1008,11 @@ def do_load_routing_cache(self, _): for source, data in cache.items(): route_count = len(data.get("routes", [])) if route_count > 0: - source_display = source.replace("tgw", "Transit Gateway").replace("cloudwan", "Cloud WAN").upper() + source_display = ( + source.replace("tgw", "Transit Gateway") + .replace("cloudwan", "Cloud WAN") + .upper() + ) console.print(f" {source_display}: {route_count} routes") except Exception as e: @@ -1106,41 +1159,43 @@ def _show_graph(self, arg): def _show_command_path(self, graph, command: str): """Show the path to reach a specific command.""" results = graph.find_command_path(command) - + if not results: console.print(f"[yellow]No command found matching '{command}'[/]") return - + console.print(f"[bold]Paths to '{command}':[/]\n") - + for result in results: marker = "✓" if result["implemented"] else "○" - + if result["is_global"]: console.print(f"{marker} [cyan]{result['command']}[/]") console.print(" [green]Global command[/] - available at root level") else: console.print(f"{marker} [cyan]{result['command']}[/]") console.print(f" Context: [yellow]{result['context']}[/]") - + # Build the full navigation path path_parts = [] if result.get("prereq_show"): path_parts.append(result["prereq_show"]) - + for p in result["path"][:-1]: # Exclude the command itself path_parts.append(p) - + if path_parts: - console.print(f" Path: [blue]{' → '.join(path_parts)} → {result['command']}[/]") - + console.print( + f" Path: [blue]{' → '.join(path_parts)} → {result['command']}[/]" + ) + console.print() def _print_graph_tree(self, node, depth: int): """Print graph as tree with prerequisite show commands for context-entering sets.""" indent = " " * depth marker = "✓" if node.implemented else "○" - + # Map set commands to their prerequisite show commands prereq_show_map = { "set vpc": "show vpcs", @@ -1153,7 +1208,7 @@ def _print_graph_tree(self, node, depth: int): "set vpn": "show vpns", "set route-table": "show route-tables", } - + if node.enters_context: # Show prerequisite show command before set command prereq = prereq_show_map.get(node.name) @@ -1162,7 +1217,7 @@ def _print_graph_tree(self, node, depth: int): console.print(f"{indent}{marker} {node.name} →") else: console.print(f"{indent}{marker} {node.name}") - + for child in node.children: self._print_graph_tree(child, depth + 1) diff --git a/src/aws_network_tools/shell/handlers/tgw.py b/src/aws_network_tools/shell/handlers/tgw.py index 4ea1bdb..8eac4d4 100644 --- a/src/aws_network_tools/shell/handlers/tgw.py +++ b/src/aws_network_tools/shell/handlers/tgw.py @@ -25,7 +25,9 @@ def _set_transit_gateway(self, val): selection_idx = int(val) except ValueError: selection_idx = 1 - self._enter("transit-gateway", t["id"], t.get("name", t["id"]), t, selection_idx) + self._enter( + "transit-gateway", t["id"], t.get("name", t["id"]), t, selection_idx + ) print() # Add blank line before next prompt def _show_transit_gateway_route_tables(self): diff --git a/src/aws_network_tools/shell/handlers/utilities.py b/src/aws_network_tools/shell/handlers/utilities.py index 658d7e3..81d0b6c 100644 --- a/src/aws_network_tools/shell/handlers/utilities.py +++ b/src/aws_network_tools/shell/handlers/utilities.py @@ -1,7 +1,8 @@ """Utility command handlers (trace, find_ip, run, cache, write).""" -from rich.console import Console import boto3 +from rich.console import Console +from rich.table import Table console = Console() diff --git a/src/aws_network_tools/shell/handlers/vpc.py b/src/aws_network_tools/shell/handlers/vpc.py index 1c1dd20..c931d13 100644 --- a/src/aws_network_tools/shell/handlers/vpc.py +++ b/src/aws_network_tools/shell/handlers/vpc.py @@ -57,7 +57,9 @@ def _set_vpc_route_table(self, val): selection_idx = int(val) except ValueError: selection_idx = 1 - self._enter("route-table", rt["id"], rt.get("name") or rt["id"], rt, selection_idx) + self._enter( + "route-table", rt["id"], rt.get("name") or rt["id"], rt, selection_idx + ) print() # Add blank line before next prompt def _show_vpc_route_tables(self): diff --git a/src/aws_network_tools/shell/handlers/vpn.py b/src/aws_network_tools/shell/handlers/vpn.py index 6b156ba..69a5af8 100644 --- a/src/aws_network_tools/shell/handlers/vpn.py +++ b/src/aws_network_tools/shell/handlers/vpn.py @@ -25,11 +25,12 @@ def _set_vpn(self, val): selection_idx = int(val) except ValueError: selection_idx = 1 - + # Fetch full VPN details including tunnels from ...modules import vpn + vpn_detail = vpn.VPNClient(self.profile).get_vpn_detail(v["id"], v["region"]) - + self._enter("vpn", v["id"], v.get("name", v["id"]), vpn_detail, selection_idx) print() # Add blank line before next prompt @@ -38,7 +39,9 @@ def _show_vpns(self, _): from ...modules import vpn vpns = self._cached( - "vpns", lambda: vpn.VPNClient(self.profile).discover(self.regions), "Fetching VPNs" + "vpns", + lambda: vpn.VPNClient(self.profile).discover(self.regions), + "Fetching VPNs", ) if not vpns: console.print("[yellow]No VPN connections found[/]") diff --git a/src/aws_network_tools/shell/main.py b/src/aws_network_tools/shell/main.py index 0d4a630..5d149f3 100644 --- a/src/aws_network_tools/shell/main.py +++ b/src/aws_network_tools/shell/main.py @@ -399,10 +399,16 @@ def do_find_null_routes(self, _): def main(): import argparse - parser = argparse.ArgumentParser(description='AWS Network Tools Interactive Shell') - parser.add_argument('--profile', '-p', help='AWS profile to use') - parser.add_argument('--no-cache', action='store_true', help='Disable caching') - parser.add_argument('--format', choices=['table', 'json', 'yaml'], default='table', help='Output format') + + parser = argparse.ArgumentParser(description="AWS Network Tools Interactive Shell") + parser.add_argument("--profile", "-p", help="AWS profile to use") + parser.add_argument("--no-cache", action="store_true", help="Disable caching") + parser.add_argument( + "--format", + choices=["table", "json", "yaml"], + default="table", + help="Output format", + ) args, unknown = parser.parse_known_args() diff --git a/src/aws_network_tools/themes/__init__.py b/src/aws_network_tools/themes/__init__.py index 8138ba1..e7bfea9 100644 --- a/src/aws_network_tools/themes/__init__.py +++ b/src/aws_network_tools/themes/__init__.py @@ -1,83 +1,95 @@ """Theme system for AWS Network Shell.""" -from pathlib import Path -from typing import Dict, Any, Optional import json +from pathlib import Path +from typing import Dict, Optional class Theme: """Color theme for prompts and UI.""" - + def __init__(self, name: str, colors: Dict[str, str]): self.name = name self.colors = colors - + def get(self, key: str, default: str = "white") -> str: """Get color for a context type.""" return self.colors.get(key, default) # Built-in themes -DRACULA_THEME = Theme("dracula", { - "root": "#f8f8f2", # Foreground - "global-network": "#bd93f9", # Purple - "core-network": "#ff79c6", # Pink - "route-table": "#8be9fd", # Cyan - "vpc": "#50fa7b", # Green - "transit-gateway": "#ffb86c", # Orange - "firewall": "#ff5555", # Red - "elb": "#f1fa8c", # Yellow - "vpn": "#6272a4", # Comment - "ec2-instance": "#bd93f9", # Purple - "prompt_separator": "#6272a4", - "prompt_text": "#f8f8f2", -}) +DRACULA_THEME = Theme( + "dracula", + { + "root": "#f8f8f2", # Foreground + "global-network": "#bd93f9", # Purple + "core-network": "#ff79c6", # Pink + "route-table": "#8be9fd", # Cyan + "vpc": "#50fa7b", # Green + "transit-gateway": "#ffb86c", # Orange + "firewall": "#ff5555", # Red + "elb": "#f1fa8c", # Yellow + "vpn": "#6272a4", # Comment + "ec2-instance": "#bd93f9", # Purple + "prompt_separator": "#6272a4", + "prompt_text": "#f8f8f2", + }, +) # Catppuccin variants -CATPPUCCIN_LATTE_THEME = Theme("catppuccin-latte", { - "root": "#4c4f69", - "global-network": "#8839ef", - "core-network": "#ea76cb", - "route-table": "#04a5e5", - "vpc": "#40a02b", - "transit-gateway": "#fe640b", - "firewall": "#d20f39", - "elb": "#df8e1d", - "vpn": "#6c6f85", - "ec2-instance": "#8839ef", - "prompt_separator": "#6c6f85", - "prompt_text": "#4c4f69", -}) - -CATPPUCCIN_MACCHIATO_THEME = Theme("catppuccin-macchiato", { - "root": "#cad3f5", - "global-network": "#c6a0f6", - "core-network": "#f5bde6", - "route-table": "#7dc4e4", - "vpc": "#a6da95", - "transit-gateway": "#f5a97f", - "firewall": "#ed8796", - "elb": "#eed49f", - "vpn": "#939ab7", - "ec2-instance": "#c6a0f6", - "prompt_separator": "#939ab7", - "prompt_text": "#cad3f5", -}) - -CATPPUCCIN_MOCHA_THEME = Theme("catppuccin-mocha", { - "root": "#89b4fa", # Blue (more vibrant than grey) - "global-network": "#cba6f7", # Mauve - "core-network": "#f5c2e7", # Pink - "route-table": "#94e2d5", # Teal (brighter than sky) - "vpc": "#a6e3a1", # Green - "transit-gateway": "#fab387", # Peach - "firewall": "#f38ba8", # Red - "elb": "#f9e2af", # Yellow - "vpn": "#b4befe", # Lavender (brighter than overlay) - "ec2-instance": "#cba6f7", # Mauve - "prompt_separator": "#9399b2", # Overlay1 (brighter) - "prompt_text": "#89b4fa", # Blue -}) +CATPPUCCIN_LATTE_THEME = Theme( + "catppuccin-latte", + { + "root": "#4c4f69", + "global-network": "#8839ef", + "core-network": "#ea76cb", + "route-table": "#04a5e5", + "vpc": "#40a02b", + "transit-gateway": "#fe640b", + "firewall": "#d20f39", + "elb": "#df8e1d", + "vpn": "#6c6f85", + "ec2-instance": "#8839ef", + "prompt_separator": "#6c6f85", + "prompt_text": "#4c4f69", + }, +) + +CATPPUCCIN_MACCHIATO_THEME = Theme( + "catppuccin-macchiato", + { + "root": "#cad3f5", + "global-network": "#c6a0f6", + "core-network": "#f5bde6", + "route-table": "#7dc4e4", + "vpc": "#a6da95", + "transit-gateway": "#f5a97f", + "firewall": "#ed8796", + "elb": "#eed49f", + "vpn": "#939ab7", + "ec2-instance": "#c6a0f6", + "prompt_separator": "#939ab7", + "prompt_text": "#cad3f5", + }, +) + +CATPPUCCIN_MOCHA_THEME = Theme( + "catppuccin-mocha", + { + "root": "#89b4fa", # Blue (more vibrant than grey) + "global-network": "#cba6f7", # Mauve + "core-network": "#f5c2e7", # Pink + "route-table": "#94e2d5", # Teal (brighter than sky) + "vpc": "#a6e3a1", # Green + "transit-gateway": "#fab387", # Peach + "firewall": "#f38ba8", # Red + "elb": "#f9e2af", # Yellow + "vpn": "#b4befe", # Lavender (brighter than overlay) + "ec2-instance": "#cba6f7", # Mauve + "prompt_separator": "#9399b2", # Overlay1 (brighter) + "prompt_text": "#89b4fa", # Blue + }, +) # Default theme (Mocha - the darkest Catppuccin variant) DEFAULT_THEME = CATPPUCCIN_MOCHA_THEME @@ -87,7 +99,7 @@ def load_theme_from_file(path: Path) -> Optional[Theme]: """Load theme from JSON file.""" if not path.exists(): return None - + try: data = json.loads(path.read_text()) return Theme(data.get("name", "custom"), data.get("colors", {})) @@ -108,7 +120,7 @@ def load_theme(name: Optional[str] = None) -> Theme: return DRACULA_THEME if name.lower() in {"catppuccin", "catpuccin"}: # Common misspelling return CATPPUCCIN_MOCHA_THEME # Default to Mocha variant - + # Check custom themes directory theme_dir = get_theme_dir() theme_file = theme_dir / f"{name}.json" @@ -116,6 +128,6 @@ def load_theme(name: Optional[str] = None) -> Theme: custom_theme = load_theme_from_file(theme_file) if custom_theme: return custom_theme - + # Fall back to default - return DEFAULT_THEME \ No newline at end of file + return DEFAULT_THEME diff --git a/tests/agent_test_runner.py b/tests/agent_test_runner.py old mode 100644 new mode 100755 index b333b36..74d686e --- a/tests/agent_test_runner.py +++ b/tests/agent_test_runner.py @@ -13,12 +13,12 @@ from dataclasses import dataclass, field from datetime import datetime from pathlib import Path -from typing import Any @dataclass class TestResult: """Result of a single test execution.""" + test_id: str command: str passed: bool = False @@ -32,6 +32,7 @@ class TestResult: @dataclass class TestState: """Persistent state across test execution.""" + baseline: dict = field(default_factory=dict) extracted: dict = field(default_factory=dict) context_stack: list[str] = field(default_factory=list) @@ -73,8 +74,6 @@ def _load_baselines(self): def start_shell(self): """Start the interactive shell process using PTY for proper TTY behavior.""" import pty - import os - import time cmd = f"cd {self.working_dir} && source .venv/bin/activate && aws-net-shell --profile {self.profile}" @@ -88,7 +87,7 @@ def start_shell(self): stdout=slave_fd, stderr=slave_fd, executable="/bin/bash", - close_fds=True + close_fds=True, ) # Close slave fd in parent (child keeps it open) @@ -102,11 +101,12 @@ def start_shell(self): print(f"✓ Shell started (captured {len(startup_output)} chars during init)") - def _read_until_prompt_pty(self, timeout: float = 30.0, idle_timeout: float = 2.0) -> str: + def _read_until_prompt_pty( + self, timeout: float = 30.0, idle_timeout: float = 2.0 + ) -> str: """Read shell output from PTY until we see a prompt.""" import select import time - import os output = [] start_time = time.time() @@ -125,13 +125,17 @@ def _read_until_prompt_pty(self, timeout: float = 30.0, idle_timeout: float = 2. chunk = os.read(self.master_fd, 1024) if not chunk: break - decoded = chunk.decode('utf-8', errors='replace') + decoded = chunk.decode("utf-8", errors="replace") output.append(decoded) last_char_time = time.time() # Check for common prompts current = "".join(output)[-50:] - if current.endswith("> ") or current.endswith("# ") or current.endswith("$ "): + if ( + current.endswith("> ") + or current.endswith("# ") + or current.endswith("$ ") + ): break except OSError: break @@ -145,14 +149,13 @@ def _read_until_prompt_pty(self, timeout: float = 30.0, idle_timeout: float = 2. def run_command(self, command: str) -> str: """Send command to shell via PTY and capture output.""" - import os import time if not self.shell_process: self.start_shell() # Send command via PTY - os.write(self.master_fd, (command + "\n").encode('utf-8')) + os.write(self.master_fd, (command + "\n").encode("utf-8")) time.sleep(0.5) # Allow command to start processing # Use 5min overall timeout for slow AWS API calls, 10s idle for table rendering @@ -179,16 +182,18 @@ def run_aws_cli(self, command: str) -> dict | list | None: print(f" AWS CLI error: {e}") return None - def validate_count(self, output: str, expected_count: int, item_type: str) -> tuple[bool, str]: + def validate_count( + self, output: str, expected_count: int, item_type: str + ) -> tuple[bool, str]: """Validate row count in table output.""" # Count table rows - Rich tables use │ as column separator # Strip ANSI escape codes for accurate parsing - ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])') + ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") lines = output.strip().split("\n") table_rows = [] for line in lines: # Strip ANSI codes for parsing - clean_line = ansi_escape.sub('', line) + clean_line = ansi_escape.sub("", line) # Match Rich table rows: │ 1 │ or standard rows starting with number if "│" in clean_line: @@ -202,9 +207,14 @@ def validate_count(self, output: str, expected_count: int, item_type: str) -> tu if actual_count == expected_count: return True, f"✓ {item_type} count: {actual_count}" else: - return False, f"✗ {item_type} count: expected {expected_count}, got {actual_count}" + return ( + False, + f"✗ {item_type} count: expected {expected_count}, got {actual_count}", + ) - def validate_ids_present(self, output: str, expected_ids: list[str]) -> tuple[bool, str]: + def validate_ids_present( + self, output: str, expected_ids: list[str] + ) -> tuple[bool, str]: """Validate that expected IDs appear in output.""" missing = [id_ for id_ in expected_ids if id_ not in output] if not missing: @@ -228,10 +238,10 @@ def extract_table_values(self, output: str, column_index: int = 0) -> list[str]: def extract_first_row_number(self, output: str) -> str | None: """Extract the row number from first data row.""" # Strip ANSI escape codes - ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])') + ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") lines = output.strip().split("\n") for line in lines: - clean_line = ansi_escape.sub('', line) + clean_line = ansi_escape.sub("", line) # Handle Rich table format: │ 1 │ ... if "│" in clean_line: @@ -244,18 +254,22 @@ def extract_first_row_number(self, output: str) -> str | None: return match.group(1) return None - def run_test(self, test_id: str, command: str, - baseline_key: str | None = None, - baseline_command: str | None = None, - validations: list[dict] | None = None, - extractions: list[dict] | None = None) -> TestResult: + def run_test( + self, + test_id: str, + command: str, + baseline_key: str | None = None, + baseline_command: str | None = None, + validations: list[dict] | None = None, + extractions: list[dict] | None = None, + ) -> TestResult: """Execute a single test with validation.""" result = TestResult(test_id=test_id, command=command) - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print(f"TEST: {test_id}") print(f"COMMAND: {command}") - print(f"{'='*60}") + print(f"{'=' * 60}") # Substitute variables in command for key, value in self.state.extracted.items(): @@ -292,7 +306,9 @@ def run_test(self, test_id: str, command: str, if validations: for v in validations: if v["type"] == "count": - passed, msg = self.validate_count(result.output, v["expected"], v.get("item", "items")) + passed, msg = self.validate_count( + result.output, v["expected"], v.get("item", "items") + ) result.details.append(msg) if not passed: result.passed = False @@ -323,9 +339,9 @@ def run_test(self, test_id: str, command: str, def run_phase_1_baseline(self): """Phase 1: Verify baselines are loaded.""" - print("\n" + "="*60) + print("\n" + "=" * 60) print("PHASE 1: INFRASTRUCTURE BASELINE") - print("="*60) + print("=" * 60) baseline = self.state.baseline @@ -337,15 +353,20 @@ def run_phase_1_baseline(self): if "tgws" in baseline and baseline["tgws"]: self.state.extracted["tgw_count"] = len(baseline["tgws"]) - self.state.extracted["tgw_ids"] = [t["TransitGatewayId"] for t in baseline["tgws"]] + self.state.extracted["tgw_ids"] = [ + t["TransitGatewayId"] for t in baseline["tgws"] + ] print(f"✓ TGWs: {len(baseline['tgws'])}") if "vpns" in baseline and baseline["vpns"]: self.state.extracted["vpn_count"] = len(baseline["vpns"]) - self.state.extracted["vpn_ids"] = [v["VpnConnectionId"] for v in baseline["vpns"]] + self.state.extracted["vpn_ids"] = [ + v["VpnConnectionId"] for v in baseline["vpns"] + ] # Check tunnel status tunnels_up = sum( - 1 for v in baseline["vpns"] + 1 + for v in baseline["vpns"] for t in (v.get("VgwTelemetry") or []) if t.get("Status") == "UP" ) @@ -371,9 +392,9 @@ def run_phase_1_baseline(self): def run_phase_2_root_commands(self): """Phase 2: Test root level commands.""" - print("\n" + "="*60) + print("\n" + "=" * 60) print("PHASE 2: ROOT LEVEL COMMANDS") - print("="*60) + print("=" * 60) # Test: show vpcs (multi-region - just check no errors and extraction works) self.run_test( @@ -382,7 +403,10 @@ def run_phase_2_root_commands(self): baseline_key="vpcs", validations=[ # Shell queries all regions, so just check baseline VPC is present - {"type": "ids_present", "ids": self.state.extracted.get("vpc_ids", [])[:5]}, + { + "type": "ids_present", + "ids": self.state.extracted.get("vpc_ids", [])[:5], + }, ], extractions=[{"type": "first_row_number", "key": "first_vpc_number"}], ) @@ -415,16 +439,20 @@ def run_phase_2_root_commands(self): test_id="ROOT_005", command="show global-networks", validations=[ - {"type": "count", "expected": self.state.extracted.get("gn_count", 0), "item": "Global Networks"}, + { + "type": "count", + "expected": self.state.extracted.get("gn_count", 0), + "item": "Global Networks", + }, ], extractions=[{"type": "first_row_number", "key": "first_gn_number"}], ) def run_phase_3_vpc_context(self): """Phase 3: Test VPC context commands.""" - print("\n" + "="*60) + print("\n" + "=" * 60) print("PHASE 3: VPC CONTEXT") - print("="*60) + print("=" * 60) first_vpc = self.state.extracted.get("first_vpc_number") if not first_vpc: @@ -475,9 +503,9 @@ def run_phase_3_vpc_context(self): def run_phase_4_tgw_context(self): """Phase 4: Test Transit Gateway context commands.""" - print("\n" + "="*60) + print("\n" + "=" * 60) print("PHASE 4: TRANSIT GATEWAY CONTEXT") - print("="*60) + print("=" * 60) first_tgw = self.state.extracted.get("first_tgw_number") if not first_tgw: @@ -514,9 +542,9 @@ def run_phase_4_tgw_context(self): def run_phase_5_vpn_context(self): """Phase 5: Test VPN context commands (requires active tunnel).""" - print("\n" + "="*60) + print("\n" + "=" * 60) print("PHASE 5: VPN CONTEXT") - print("="*60) + print("=" * 60) tunnels_up = self.state.extracted.get("tunnels_up", 0) if tunnels_up == 0: @@ -551,9 +579,9 @@ def run_phase_5_vpn_context(self): def run_phase_6_firewall_context(self): """Phase 6: Test Firewall context commands.""" - print("\n" + "="*60) + print("\n" + "=" * 60) print("PHASE 6: FIREWALL CONTEXT") - print("="*60) + print("=" * 60) first_fw = self.state.extracted.get("first_fw_number") if not first_fw: @@ -582,9 +610,9 @@ def run_phase_6_firewall_context(self): def run_phase_7_cloudwan_context(self): """Phase 7: Test CloudWAN context commands.""" - print("\n" + "="*60) + print("\n" + "=" * 60) print("PHASE 7: CLOUDWAN CONTEXT") - print("="*60) + print("=" * 60) first_gn = self.state.extracted.get("first_gn_number") if not first_gn: @@ -662,32 +690,35 @@ def generate_report(self) -> dict: def cleanup(self): """Clean up shell process and PTY file descriptor.""" - import os if self.shell_process: try: # Try to send exit command via PTY os.write(self.master_fd, b"exit\n") - except: + except OSError: pass self.shell_process.terminate() try: self.shell_process.wait(timeout=5) - except: + except subprocess.TimeoutExpired: self.shell_process.kill() # Close PTY file descriptor - if hasattr(self, 'master_fd'): + if hasattr(self, "master_fd"): try: os.close(self.master_fd) - except: + except OSError: pass def main(): parser = argparse.ArgumentParser(description="AWS Network Shell Test Runner") - parser.add_argument("--profile", default="taylaand+net-dev-Admin", help="AWS profile") + parser.add_argument( + "--profile", default="taylaand+net-dev-Admin", help="AWS profile" + ) parser.add_argument("--working-dir", default=".", help="Working directory") - parser.add_argument("--baseline-dir", default="/tmp", help="Baseline files directory") + parser.add_argument( + "--baseline-dir", default="/tmp", help="Baseline files directory" + ) parser.add_argument("--output", default=None, help="Output report file") args = parser.parse_args() @@ -712,9 +743,9 @@ def main(): # Generate report report = runner.generate_report() - print("\n" + "="*60) + print("\n" + "=" * 60) print("FINAL REPORT") - print("="*60) + print("=" * 60) print(f"Total: {report['summary']['total']}") print(f"Passed: {report['summary']['passed']}") print(f"Failed: {report['summary']['failed']}") diff --git a/tests/fixtures/client_vpn.py b/tests/fixtures/client_vpn.py index 2b92a05..c88abdb 100644 --- a/tests/fixtures/client_vpn.py +++ b/tests/fixtures/client_vpn.py @@ -539,9 +539,7 @@ def get_endpoints_by_vpc(vpc_id: str) -> list[dict[str, Any]]: List of endpoints in the VPC """ return [ - ep - for ep in CLIENT_VPN_ENDPOINT_FIXTURES.values() - if ep.get("VpcId") == vpc_id + ep for ep in CLIENT_VPN_ENDPOINT_FIXTURES.values() if ep.get("VpcId") == vpc_id ] diff --git a/tests/fixtures/cloudwan_connect.py b/tests/fixtures/cloudwan_connect.py index c2c1bcf..7a2b5c2 100644 --- a/tests/fixtures/cloudwan_connect.py +++ b/tests/fixtures/cloudwan_connect.py @@ -551,9 +551,7 @@ def get_connect_peers_by_edge_location(edge_location: str) -> list[dict[str, Any def get_connect_peers_by_state(state: str) -> list[dict[str, Any]]: """Get all Connect Peers with a specific state (AVAILABLE, CREATING, DELETING).""" - return [ - peer for peer in CONNECT_PEER_FIXTURES.values() if peer["State"] == state - ] + return [peer for peer in CONNECT_PEER_FIXTURES.values() if peer["State"] == state] def get_connect_peers_by_asn(peer_asn: int) -> list[dict[str, Any]]: diff --git a/tests/fixtures/command_fixtures.py b/tests/fixtures/command_fixtures.py index 0da2254..6365dd5 100644 --- a/tests/fixtures/command_fixtures.py +++ b/tests/fixtures/command_fixtures.py @@ -3,18 +3,26 @@ Maps every shell command to its fixture data and mock targets. Uses actual module Client.discover() pattern. """ + from typing import Any from . import ( - VPC_FIXTURES, SUBNET_FIXTURES, ROUTE_TABLE_FIXTURES, - SECURITY_GROUP_FIXTURES, NACL_FIXTURES, - TGW_FIXTURES, TGW_ATTACHMENT_FIXTURES, TGW_ROUTE_TABLE_FIXTURES, - CLOUDWAN_FIXTURES, CLOUDWAN_ATTACHMENT_FIXTURES, - EC2_INSTANCE_FIXTURES, ENI_FIXTURES, - ELB_FIXTURES, TARGET_GROUP_FIXTURES, LISTENER_FIXTURES, + VPC_FIXTURES, + SUBNET_FIXTURES, + ROUTE_TABLE_FIXTURES, + SECURITY_GROUP_FIXTURES, + NACL_FIXTURES, + TGW_FIXTURES, + TGW_ATTACHMENT_FIXTURES, + TGW_ROUTE_TABLE_FIXTURES, + CLOUDWAN_FIXTURES, + EC2_INSTANCE_FIXTURES, + ELB_FIXTURES, VPN_CONNECTION_FIXTURES, - NETWORK_FIREWALL_FIXTURES, FIREWALL_POLICY_FIXTURES, RULE_GROUP_FIXTURES, - IGW_FIXTURES, NAT_GATEWAY_FIXTURES, - get_vpc_detail, get_tgw_detail, get_elb_detail, get_vpn_detail, get_firewall_detail, + NETWORK_FIREWALL_FIXTURES, + RULE_GROUP_FIXTURES, + IGW_FIXTURES, + NAT_GATEWAY_FIXTURES, + get_vpc_detail, ) @@ -30,74 +38,130 @@ def get_tag_value(resource: dict, key: str = "Name") -> str: # FIXTURE DATA GENERATORS - Match shell's expected format # ============================================================================= + def _vpcs_list(): """Generate VPC list in shell's expected format.""" return [ - {"id": vpc_id, "name": get_tag_value(vpc) or vpc_id, "region": "eu-west-1", - "cidr": vpc["CidrBlock"], "cidrs": [vpc["CidrBlock"]], "state": vpc["State"]} + { + "id": vpc_id, + "name": get_tag_value(vpc) or vpc_id, + "region": "eu-west-1", + "cidr": vpc["CidrBlock"], + "cidrs": [vpc["CidrBlock"]], + "state": vpc["State"], + } for vpc_id, vpc in VPC_FIXTURES.items() ] + def _tgws_list(): """Generate TGW list in shell's expected format.""" return [ - {"id": tgw_id, "name": get_tag_value(tgw) or tgw_id, "region": "eu-west-1", - "state": tgw["State"], "attachments": [], "route_tables": []} + { + "id": tgw_id, + "name": get_tag_value(tgw) or tgw_id, + "region": "eu-west-1", + "state": tgw["State"], + "attachments": [], + "route_tables": [], + } for tgw_id, tgw in TGW_FIXTURES.items() ] + def _global_networks_list(): """Generate global networks list.""" return [ - {"id": "global-network-0prod123456789", "name": "production-global-network", - "state": "AVAILABLE", "description": "Production global network"} + { + "id": "global-network-0prod123456789", + "name": "production-global-network", + "state": "AVAILABLE", + "description": "Production global network", + } ] + def _core_networks_list(): """Generate core networks list.""" return [ - {"id": cn_id, "name": get_tag_value(cn) or cn_id, - "global_network_id": cn.get("GlobalNetworkId", "global-network-123"), - "state": cn.get("State", "AVAILABLE"), "segments": ["production", "development"], - "regions": ["eu-west-1", "us-east-1"], "nfgs": [], "route_tables": [], "policy": {}} + { + "id": cn_id, + "name": get_tag_value(cn) or cn_id, + "global_network_id": cn.get("GlobalNetworkId", "global-network-123"), + "state": cn.get("State", "AVAILABLE"), + "segments": ["production", "development"], + "regions": ["eu-west-1", "us-east-1"], + "nfgs": [], + "route_tables": [], + "policy": {}, + } for cn_id, cn in CLOUDWAN_FIXTURES.items() ] + def _firewalls_list(): """Generate firewall list.""" return [ - {"name": fw["FirewallName"], "arn": fw["FirewallArn"], "region": "eu-west-1", - "vpc_id": fw["VpcId"], "status": fw.get("FirewallStatus", {}).get("Status", "READY"), - "policy_arn": fw.get("FirewallPolicyArn", ""), "rule_groups": []} + { + "name": fw["FirewallName"], + "arn": fw["FirewallArn"], + "region": "eu-west-1", + "vpc_id": fw["VpcId"], + "status": fw.get("FirewallStatus", {}).get("Status", "READY"), + "policy_arn": fw.get("FirewallPolicyArn", ""), + "rule_groups": [], + } for fw in NETWORK_FIREWALL_FIXTURES.values() ] + def _ec2_instances_list(): """Generate EC2 instances list.""" return [ - {"id": inst_id, "name": get_tag_value(inst) or inst_id, "region": "eu-west-1", - "type": inst["InstanceType"], "state": inst["State"]["Name"], - "private_ip": inst.get("PrivateIpAddress"), "public_ip": inst.get("PublicIpAddress"), - "vpc_id": inst.get("VpcId"), "subnet_id": inst.get("SubnetId")} + { + "id": inst_id, + "name": get_tag_value(inst) or inst_id, + "region": "eu-west-1", + "type": inst["InstanceType"], + "state": inst["State"]["Name"], + "private_ip": inst.get("PrivateIpAddress"), + "public_ip": inst.get("PublicIpAddress"), + "vpc_id": inst.get("VpcId"), + "subnet_id": inst.get("SubnetId"), + } for inst_id, inst in EC2_INSTANCE_FIXTURES.items() ] + def _elbs_list(): """Generate ELB list.""" return [ - {"arn": elb["LoadBalancerArn"], "name": elb["LoadBalancerName"], - "type": elb["Type"], "scheme": elb["Scheme"], "state": elb["State"]["Code"], - "dns_name": elb["DNSName"], "vpc_id": elb["VpcId"], "region": "eu-west-1"} + { + "arn": elb["LoadBalancerArn"], + "name": elb["LoadBalancerName"], + "type": elb["Type"], + "scheme": elb["Scheme"], + "state": elb["State"]["Code"], + "dns_name": elb["DNSName"], + "vpc_id": elb["VpcId"], + "region": "eu-west-1", + } for elb in ELB_FIXTURES.values() ] + def _vpns_list(): """Generate VPN list.""" return [ - {"id": vpn_id, "name": get_tag_value(vpn) or vpn_id, "region": "eu-west-1", - "state": vpn["State"], "type": vpn["Type"], - "customer_gateway_id": vpn["CustomerGatewayId"], - "tunnels": vpn.get("VgwTelemetry", [])} + { + "id": vpn_id, + "name": get_tag_value(vpn) or vpn_id, + "region": "eu-west-1", + "state": vpn["State"], + "type": vpn["Type"], + "customer_gateway_id": vpn["CustomerGatewayId"], + "tunnels": vpn.get("VgwTelemetry", []), + } for vpn_id, vpn in VPN_CONNECTION_FIXTURES.items() ] @@ -138,72 +202,120 @@ def _vpns_list(): "mock_target": "aws_network_tools.modules.vpn.VPNClient.discover", "fixture_data": _vpns_list, }, - # ========================================================================= # VPC CONTEXT COMMANDS # ========================================================================= "vpc.show detail": { "mock_target": "aws_network_tools.modules.vpc.VPCClient.get_vpc_detail", - "fixture_data": lambda vpc_id=None: get_vpc_detail(vpc_id or list(VPC_FIXTURES.keys())[0]) or { - "id": vpc_id, "name": "test-vpc", "region": "eu-west-1", - "cidr": "10.0.0.0/16", "cidrs": ["10.0.0.0/16"], "state": "available", - "route_tables": [], "security_groups": [], "nacls": [] + "fixture_data": lambda vpc_id=None: get_vpc_detail( + vpc_id or list(VPC_FIXTURES.keys())[0] + ) + or { + "id": vpc_id, + "name": "test-vpc", + "region": "eu-west-1", + "cidr": "10.0.0.0/16", + "cidrs": ["10.0.0.0/16"], + "state": "available", + "route_tables": [], + "security_groups": [], + "nacls": [], }, }, "vpc.show subnets": { "mock_target": "aws_network_tools.modules.vpc.VPCClient.get_vpc_detail", - "fixture_data": lambda: {"subnets": [ - {"id": s["SubnetId"], "name": get_tag_value(s), "cidr": s.get("CidrBlock", "10.0.0.0/24"), - "az": s["AvailabilityZone"], "state": s["State"]} - for s in list(SUBNET_FIXTURES.values())[:5] - ]}, + "fixture_data": lambda: { + "subnets": [ + { + "id": s["SubnetId"], + "name": get_tag_value(s), + "cidr": s.get("CidrBlock", "10.0.0.0/24"), + "az": s["AvailabilityZone"], + "state": s["State"], + } + for s in list(SUBNET_FIXTURES.values())[:5] + ] + }, }, "vpc.show route-tables": { "mock_target": "aws_network_tools.modules.vpc.VPCClient.get_vpc_detail", - "fixture_data": lambda: {"route_tables": [ - {"id": rt["RouteTableId"], "name": get_tag_value(rt), - "is_main": any(a.get("Main") for a in rt.get("Associations", [])), - "routes": rt.get("Routes", []), - "subnets": [a["SubnetId"] for a in rt.get("Associations", []) if a.get("SubnetId")]} - for rt in list(ROUTE_TABLE_FIXTURES.values())[:3] - ]}, + "fixture_data": lambda: { + "route_tables": [ + { + "id": rt["RouteTableId"], + "name": get_tag_value(rt), + "is_main": any(a.get("Main") for a in rt.get("Associations", [])), + "routes": rt.get("Routes", []), + "subnets": [ + a["SubnetId"] + for a in rt.get("Associations", []) + if a.get("SubnetId") + ], + } + for rt in list(ROUTE_TABLE_FIXTURES.values())[:3] + ] + }, }, "vpc.show security-groups": { "mock_target": "aws_network_tools.modules.vpc.VPCClient.get_vpc_detail", - "fixture_data": lambda: {"security_groups": [ - {"id": sg["GroupId"], "name": sg["GroupName"], "description": sg.get("Description", ""), - "ingress": [], "egress": []} - for sg in list(SECURITY_GROUP_FIXTURES.values())[:3] - ]}, + "fixture_data": lambda: { + "security_groups": [ + { + "id": sg["GroupId"], + "name": sg["GroupName"], + "description": sg.get("Description", ""), + "ingress": [], + "egress": [], + } + for sg in list(SECURITY_GROUP_FIXTURES.values())[:3] + ] + }, }, "vpc.show nacls": { "mock_target": "aws_network_tools.modules.vpc.VPCClient.get_vpc_detail", - "fixture_data": lambda: {"nacls": [ - {"id": nacl["NetworkAclId"], "name": get_tag_value(nacl), - "is_default": nacl.get("IsDefault", False), "entries": []} - for nacl in list(NACL_FIXTURES.values())[:2] - ]}, + "fixture_data": lambda: { + "nacls": [ + { + "id": nacl["NetworkAclId"], + "name": get_tag_value(nacl), + "is_default": nacl.get("IsDefault", False), + "entries": [], + } + for nacl in list(NACL_FIXTURES.values())[:2] + ] + }, }, "vpc.show internet-gateways": { "mock_target": "aws_network_tools.modules.vpc.VPCClient.get_vpc_detail", - "fixture_data": lambda: {"internet_gateways": [ - {"id": igw["InternetGatewayId"], "name": get_tag_value(igw), "state": "attached"} - for igw in list(IGW_FIXTURES.values())[:1] - ]}, + "fixture_data": lambda: { + "internet_gateways": [ + { + "id": igw["InternetGatewayId"], + "name": get_tag_value(igw), + "state": "attached", + } + for igw in list(IGW_FIXTURES.values())[:1] + ] + }, }, "vpc.show nat-gateways": { "mock_target": "aws_network_tools.modules.vpc.VPCClient.get_vpc_detail", - "fixture_data": lambda: {"nat_gateways": [ - {"id": nat["NatGatewayId"], "name": get_tag_value(nat), "state": nat["State"], - "subnet_id": nat["SubnetId"]} - for nat in list(NAT_GATEWAY_FIXTURES.values())[:2] - ]}, + "fixture_data": lambda: { + "nat_gateways": [ + { + "id": nat["NatGatewayId"], + "name": get_tag_value(nat), + "state": nat["State"], + "subnet_id": nat["SubnetId"], + } + for nat in list(NAT_GATEWAY_FIXTURES.values())[:2] + ] + }, }, "vpc.show endpoints": { "mock_target": "aws_network_tools.modules.vpc.VPCClient.get_vpc_detail", "fixture_data": lambda: {"endpoints": []}, }, - # ========================================================================= # TRANSIT GATEWAY CONTEXT COMMANDS # ========================================================================= @@ -213,21 +325,38 @@ def _vpns_list(): }, "transit-gateway.show route-tables": { "mock_target": "aws_network_tools.modules.tgw.TGWClient.discover", - "fixture_data": lambda: [{"id": "tgw-123", "route_tables": [ - {"id": rt["TransitGatewayRouteTableId"], "name": get_tag_value(rt), - "state": rt.get("State", "available"), "routes": []} - for rt in list(TGW_ROUTE_TABLE_FIXTURES.values())[:2] - ]}], + "fixture_data": lambda: [ + { + "id": "tgw-123", + "route_tables": [ + { + "id": rt["TransitGatewayRouteTableId"], + "name": get_tag_value(rt), + "state": rt.get("State", "available"), + "routes": [], + } + for rt in list(TGW_ROUTE_TABLE_FIXTURES.values())[:2] + ], + } + ], }, "transit-gateway.show attachments": { "mock_target": "aws_network_tools.modules.tgw.TGWClient.discover", - "fixture_data": lambda: [{"id": "tgw-123", "attachments": [ - {"id": att["TransitGatewayAttachmentId"], "type": att["ResourceType"], - "resource_id": att.get("ResourceId"), "state": att["State"]} - for att in list(TGW_ATTACHMENT_FIXTURES.values())[:3] - ]}], + "fixture_data": lambda: [ + { + "id": "tgw-123", + "attachments": [ + { + "id": att["TransitGatewayAttachmentId"], + "type": att["ResourceType"], + "resource_id": att.get("ResourceId"), + "state": att["State"], + } + for att in list(TGW_ATTACHMENT_FIXTURES.values())[:3] + ], + } + ], }, - # ========================================================================= # FIREWALL CONTEXT COMMANDS # ========================================================================= @@ -237,16 +366,24 @@ def _vpns_list(): }, "firewall.show rule-groups": { "mock_target": "aws_network_tools.modules.anfw.ANFWClient.discover", - "fixture_data": lambda: [{"name": "test-fw", "rule_groups": [ - {"name": rg["RuleGroupName"], "arn": rg["RuleGroupArn"], "type": rg["Type"]} - for rg in list(RULE_GROUP_FIXTURES.values())[:2] - ]}], + "fixture_data": lambda: [ + { + "name": "test-fw", + "rule_groups": [ + { + "name": rg["RuleGroupName"], + "arn": rg["RuleGroupArn"], + "type": rg["Type"], + } + for rg in list(RULE_GROUP_FIXTURES.values())[:2] + ], + } + ], }, "firewall.show policy": { "mock_target": "aws_network_tools.modules.anfw.ANFWClient.discover", "fixture_data": _firewalls_list, }, - # ========================================================================= # EC2 INSTANCE CONTEXT COMMANDS # ========================================================================= @@ -266,7 +403,6 @@ def _vpns_list(): "mock_target": "aws_network_tools.modules.ec2.EC2Client.discover", "fixture_data": _ec2_instances_list, }, - # ========================================================================= # ELB CONTEXT COMMANDS # ========================================================================= @@ -286,7 +422,6 @@ def _vpns_list(): "mock_target": "aws_network_tools.modules.elb.ELBClient.discover", "fixture_data": _elbs_list, }, - # ========================================================================= # VPN CONTEXT COMMANDS # ========================================================================= @@ -298,7 +433,6 @@ def _vpns_list(): "mock_target": "aws_network_tools.modules.vpn.VPNClient.discover", "fixture_data": _vpns_list, }, - # ========================================================================= # CORE NETWORK CONTEXT COMMANDS # ========================================================================= @@ -344,7 +478,10 @@ def _vpns_list(): "global-network.set core-network": ["show global-networks"], "core-network.set route-table": ["core-network.show route-tables"], "vpc.set route-table": ["show vpcs", "vpc.show route-tables"], - "transit-gateway.set route-table": ["show transit_gateways", "transit-gateway.show route-tables"], + "transit-gateway.set route-table": [ + "show transit_gateways", + "transit-gateway.show route-tables", + ], } @@ -353,10 +490,10 @@ def get_mock_for_command(command: str, **context_args) -> dict[str, Any] | None: config = COMMAND_MOCKS.get(command) if not config: return None - + fixture_fn = config["fixture_data"] return_value = fixture_fn() if callable(fixture_fn) else fixture_fn - + return { "target": config["mock_target"], "return_value": return_value, diff --git a/tests/fixtures/ec2.py b/tests/fixtures/ec2.py index 82de6d0..35b7a96 100644 --- a/tests/fixtures/ec2.py +++ b/tests/fixtures/ec2.py @@ -77,7 +77,7 @@ ], "IamInstanceProfile": { "Arn": "arn:aws:iam::123456789012:instance-profile/production-web-profile", - "Id": "AKIAIOSFODNN7EXAMPLE", + "Id": "AKIAIOSFODNN7EXAMPLE", # pragma: allowlist secret }, "Tags": [ {"Key": "Name", "Value": "production-web-1a"}, @@ -151,7 +151,7 @@ ], "IamInstanceProfile": { "Arn": "arn:aws:iam::123456789012:instance-profile/production-web-profile", - "Id": "AKIAIOSFODNN7EXAMPLE", + "Id": "AKIAIOSFODNN7EXAMPLE", # pragma: allowlist secret }, "Tags": [ {"Key": "Name", "Value": "production-web-1b"}, @@ -233,7 +233,7 @@ ], "IamInstanceProfile": { "Arn": "arn:aws:iam::123456789012:instance-profile/bastion-profile", - "Id": "AKIAIOSFODNN7EXAMPLE", + "Id": "AKIAIOSFODNN7EXAMPLE", # pragma: allowlist secret }, "Tags": [ {"Key": "Name", "Value": "shared-bastion-1a"}, @@ -306,7 +306,7 @@ ], "IamInstanceProfile": { "Arn": "arn:aws:iam::123456789012:instance-profile/staging-app-profile", - "Id": "AKIAIOSFODNN7EXAMPLE", + "Id": "AKIAIOSFODNN7EXAMPLE", # pragma: allowlist secret }, "Tags": [ {"Key": "Name", "Value": "staging-app-1a"}, @@ -368,7 +368,10 @@ "DeleteOnTermination": True, }, "Groups": [ - {"GroupId": "sg-0devall12345678901", "GroupName": "development-all-sg"} + { + "GroupId": "sg-0devall12345678901", + "GroupName": "development-all-sg", + } ], "SourceDestCheck": True, "Status": "in-use", @@ -455,7 +458,9 @@ "Status": "in-use", "SourceDestCheck": True, "InterfaceType": "network_load_balancer", - "Groups": [{"GroupId": "sg-0prodweb123456789", "GroupName": "production-web-alb-sg"}], + "Groups": [ + {"GroupId": "sg-0prodweb123456789", "GroupName": "production-web-alb-sg"} + ], "Attachment": { "AttachmentId": "ela-attach-0alb1a123", "DeviceIndex": 1, @@ -494,7 +499,9 @@ "Status": "in-use", "SourceDestCheck": True, "InterfaceType": "network_load_balancer", - "Groups": [{"GroupId": "sg-0prodweb123456789", "GroupName": "production-web-alb-sg"}], + "Groups": [ + {"GroupId": "sg-0prodweb123456789", "GroupName": "production-web-alb-sg"} + ], "Attachment": { "AttachmentId": "ela-attach-0alb1b123", "DeviceIndex": 1, @@ -596,7 +603,9 @@ "Status": "in-use", "SourceDestCheck": True, "InterfaceType": "interface", - "Groups": [{"GroupId": "sg-0proddb1234567890", "GroupName": "production-db-sg"}], + "Groups": [ + {"GroupId": "sg-0proddb1234567890", "GroupName": "production-db-sg"} + ], "Attachment": { "AttachmentId": "eni-attach-0rds1a1234", "DeviceIndex": 1, diff --git a/tests/fixtures/elb.py b/tests/fixtures/elb.py index 68fddbb..e4d3516 100644 --- a/tests/fixtures/elb.py +++ b/tests/fixtures/elb.py @@ -490,11 +490,15 @@ def get_elb_detail(elb_arn: str) -> dict[str, Any] | None: return None # Gather associated listeners - listeners = [l for l in LISTENER_FIXTURES.values() if l["LoadBalancerArn"] == elb_arn] + listeners = [ + lis for lis in LISTENER_FIXTURES.values() if lis["LoadBalancerArn"] == elb_arn + ] # Gather associated target groups target_groups = [ - tg for tg in TARGET_GROUP_FIXTURES.values() if elb_arn in tg.get("LoadBalancerArns", []) + tg + for tg in TARGET_GROUP_FIXTURES.values() + if elb_arn in tg.get("LoadBalancerArns", []) ] # Gather target health for each target group @@ -523,7 +527,9 @@ def get_target_health(tg_arn: str) -> list[dict[str, Any]]: def get_listeners_by_elb(elb_arn: str) -> list[dict[str, Any]]: """Get all listeners for a load balancer.""" - return [l for l in LISTENER_FIXTURES.values() if l["LoadBalancerArn"] == elb_arn] + return [ + lis for lis in LISTENER_FIXTURES.values() if lis["LoadBalancerArn"] == elb_arn + ] def get_target_groups_by_vpc(vpc_id: str) -> list[dict[str, Any]]: diff --git a/tests/fixtures/firewall.py b/tests/fixtures/firewall.py index c13f5e9..ba0a736 100644 --- a/tests/fixtures/firewall.py +++ b/tests/fixtures/firewall.py @@ -218,9 +218,13 @@ "RuleDefinition": { "MatchAttributes": { "Sources": [{"AddressDefinition": "10.0.0.0/16"}], - "Destinations": [{"AddressDefinition": "0.0.0.0/0"}], + "Destinations": [ + {"AddressDefinition": "0.0.0.0/0"} + ], "SourcePorts": [{"FromPort": 0, "ToPort": 65535}], - "DestinationPorts": [{"FromPort": 443, "ToPort": 443}], + "DestinationPorts": [ + {"FromPort": 443, "ToPort": 443} + ], "Protocols": [6], }, "Actions": ["aws:pass"], @@ -231,9 +235,13 @@ "RuleDefinition": { "MatchAttributes": { "Sources": [{"AddressDefinition": "10.0.0.0/16"}], - "Destinations": [{"AddressDefinition": "0.0.0.0/0"}], + "Destinations": [ + {"AddressDefinition": "0.0.0.0/0"} + ], "SourcePorts": [{"FromPort": 0, "ToPort": 65535}], - "DestinationPorts": [{"FromPort": 80, "ToPort": 80}], + "DestinationPorts": [ + {"FromPort": 80, "ToPort": 80} + ], "Protocols": [6], }, "Actions": ["aws:pass"], @@ -566,6 +574,4 @@ def get_rule_group_by_arn(rule_group_arn: str) -> dict[str, Any] | None: def get_firewalls_by_vpc(vpc_id: str) -> list[dict[str, Any]]: """Get all firewalls in a VPC.""" - return [ - fw for fw in NETWORK_FIREWALL_FIXTURES.values() if fw["VpcId"] == vpc_id - ] + return [fw for fw in NETWORK_FIREWALL_FIXTURES.values() if fw["VpcId"] == vpc_id] diff --git a/tests/fixtures/fixture_generator.py b/tests/fixtures/fixture_generator.py index 6ee63ab..003043c 100644 --- a/tests/fixtures/fixture_generator.py +++ b/tests/fixtures/fixture_generator.py @@ -367,9 +367,7 @@ def main(): "--resource", required=True, help="Resource type (vpc, nat-gateway, etc.)" ) parser.add_argument("--count", type=int, default=1, help="Number of fixtures") - parser.add_argument( - "--from-api", action="store_true", help="Fetch from AWS API" - ) + parser.add_argument("--from-api", action="store_true", help="Fetch from AWS API") parser.add_argument("--resource-id", help="AWS resource ID (for --from-api)") parser.add_argument( "--template-only", action="store_true", help="Generate file template only" @@ -409,7 +407,7 @@ def main(): # Validate errors = generator.validate_fixture(sanitized, args.resource) if errors: - print(f"\n⚠️ Validation warnings:") + print("\n⚠️ Validation warnings:") for error in errors: print(f" - {error}") else: diff --git a/tests/fixtures/gateways.py b/tests/fixtures/gateways.py index 244c4fa..c6ea952 100644 --- a/tests/fixtures/gateways.py +++ b/tests/fixtures/gateways.py @@ -633,10 +633,20 @@ def get_gateway_summary() -> dict[str, int]: "nat_gateways": len(NAT_GATEWAY_FIXTURES), "elastic_ips": len(EIP_FIXTURES), "egress_only_igws": len(EGRESS_ONLY_IGW_FIXTURES), - "nat_available": len([n for n in NAT_GATEWAY_FIXTURES.values() if n["State"] == "available"]), - "nat_pending": len([n for n in NAT_GATEWAY_FIXTURES.values() if n["State"] == "pending"]), - "nat_deleting": len([n for n in NAT_GATEWAY_FIXTURES.values() if n["State"] == "deleting"]), - "nat_failed": len([n for n in NAT_GATEWAY_FIXTURES.values() if n["State"] == "failed"]), - "eip_allocated": len([e for e in EIP_FIXTURES.values() if e.get("NetworkInterfaceId")]), + "nat_available": len( + [n for n in NAT_GATEWAY_FIXTURES.values() if n["State"] == "available"] + ), + "nat_pending": len( + [n for n in NAT_GATEWAY_FIXTURES.values() if n["State"] == "pending"] + ), + "nat_deleting": len( + [n for n in NAT_GATEWAY_FIXTURES.values() if n["State"] == "deleting"] + ), + "nat_failed": len( + [n for n in NAT_GATEWAY_FIXTURES.values() if n["State"] == "failed"] + ), + "eip_allocated": len( + [e for e in EIP_FIXTURES.values() if e.get("NetworkInterfaceId")] + ), "eip_unallocated": len(get_unallocated_eips()), } diff --git a/tests/fixtures/global_accelerator.py b/tests/fixtures/global_accelerator.py index 67d7125..1873917 100644 --- a/tests/fixtures/global_accelerator.py +++ b/tests/fixtures/global_accelerator.py @@ -350,9 +350,7 @@ def get_enabled_accelerators() -> list[dict[str, Any]]: List of enabled accelerators """ return [ - acc - for acc in GLOBAL_ACCELERATOR_FIXTURES.values() - if acc.get("Enabled", False) + acc for acc in GLOBAL_ACCELERATOR_FIXTURES.values() if acc.get("Enabled", False) ] @@ -366,7 +364,9 @@ def get_accelerators_by_status(status: str) -> list[dict[str, Any]]: List of accelerators with matching status """ return [ - acc for acc in GLOBAL_ACCELERATOR_FIXTURES.values() if acc.get("Status") == status + acc + for acc in GLOBAL_ACCELERATOR_FIXTURES.values() + if acc.get("Status") == status ] @@ -485,9 +485,7 @@ def get_accelerators_with_flow_logs() -> list[dict[str, Any]]: if attrs.get("FlowLogsEnabled", False) ] return [ - acc - for arn, acc in GLOBAL_ACCELERATOR_FIXTURES.items() - if arn in enabled_arns + acc for arn, acc in GLOBAL_ACCELERATOR_FIXTURES.items() if arn in enabled_arns ] diff --git a/tests/fixtures/global_network.py b/tests/fixtures/global_network.py index 2f59981..ba43079 100644 --- a/tests/fixtures/global_network.py +++ b/tests/fixtures/global_network.py @@ -73,7 +73,4 @@ def get_all_global_networks() -> list[dict[str, Any]]: def get_global_networks_by_state(state: str) -> list[dict[str, Any]]: """Get Global Networks by state.""" - return [ - gn for gn in GLOBAL_NETWORK_FIXTURES.values() - if gn.get("State") == state - ] + return [gn for gn in GLOBAL_NETWORK_FIXTURES.values() if gn.get("State") == state] diff --git a/tests/fixtures/peering.py b/tests/fixtures/peering.py index f5ce4cd..92f68eb 100644 --- a/tests/fixtures/peering.py +++ b/tests/fixtures/peering.py @@ -294,7 +294,8 @@ def get_peering_by_id(pcx_id: str) -> dict[str, Any] | None: def get_active_peerings() -> list[dict[str, Any]]: """Get all active VPC peering connections.""" return [ - pcx for pcx in VPC_PEERING_FIXTURES.values() + pcx + for pcx in VPC_PEERING_FIXTURES.values() if pcx["Status"]["Code"] == "active" ] @@ -305,15 +306,15 @@ def get_peerings_by_status(status: str) -> list[dict[str, Any]]: Valid statuses: active, pending-acceptance, provisioning, failed, deleting, deleted """ return [ - pcx for pcx in VPC_PEERING_FIXTURES.values() - if pcx["Status"]["Code"] == status + pcx for pcx in VPC_PEERING_FIXTURES.values() if pcx["Status"]["Code"] == status ] def get_peerings_by_vpc(vpc_id: str) -> list[dict[str, Any]]: """Get all VPC peering connections for a specific VPC (requester or accepter).""" return [ - pcx for pcx in VPC_PEERING_FIXTURES.values() + pcx + for pcx in VPC_PEERING_FIXTURES.values() if pcx["RequesterVpcInfo"]["VpcId"] == vpc_id or pcx["AccepterVpcInfo"]["VpcId"] == vpc_id ] @@ -322,7 +323,8 @@ def get_peerings_by_vpc(vpc_id: str) -> list[dict[str, Any]]: def get_intra_region_peerings(region: str) -> list[dict[str, Any]]: """Get intra-region VPC peering connections.""" return [ - pcx for pcx in VPC_PEERING_FIXTURES.values() + pcx + for pcx in VPC_PEERING_FIXTURES.values() if pcx["RequesterVpcInfo"]["Region"] == region and pcx["AccepterVpcInfo"]["Region"] == region ] @@ -331,7 +333,8 @@ def get_intra_region_peerings(region: str) -> list[dict[str, Any]]: def get_cross_region_peerings() -> list[dict[str, Any]]: """Get cross-region VPC peering connections.""" return [ - pcx for pcx in VPC_PEERING_FIXTURES.values() + pcx + for pcx in VPC_PEERING_FIXTURES.values() if pcx["RequesterVpcInfo"]["Region"] != pcx["AccepterVpcInfo"]["Region"] ] @@ -339,7 +342,8 @@ def get_cross_region_peerings() -> list[dict[str, Any]]: def get_cross_account_peerings() -> list[dict[str, Any]]: """Get cross-account VPC peering connections.""" return [ - pcx for pcx in VPC_PEERING_FIXTURES.values() + pcx + for pcx in VPC_PEERING_FIXTURES.values() if pcx["RequesterVpcInfo"]["OwnerId"] != pcx["AccepterVpcInfo"]["OwnerId"] ] @@ -347,16 +351,22 @@ def get_cross_account_peerings() -> list[dict[str, Any]]: def get_peerings_with_dns_resolution() -> list[dict[str, Any]]: """Get VPC peering connections with DNS resolution enabled.""" return [ - pcx for pcx in VPC_PEERING_FIXTURES.values() - if (pcx["RequesterVpcInfo"]["PeeringOptions"]["AllowDnsResolutionFromRemoteVpc"] - or pcx["AccepterVpcInfo"]["PeeringOptions"]["AllowDnsResolutionFromRemoteVpc"]) + pcx + for pcx in VPC_PEERING_FIXTURES.values() + if ( + pcx["RequesterVpcInfo"]["PeeringOptions"]["AllowDnsResolutionFromRemoteVpc"] + or pcx["AccepterVpcInfo"]["PeeringOptions"][ + "AllowDnsResolutionFromRemoteVpc" + ] + ) ] def get_peerings_by_owner(owner_id: str) -> list[dict[str, Any]]: """Get VPC peering connections where account is requester or accepter.""" return [ - pcx for pcx in VPC_PEERING_FIXTURES.values() + pcx + for pcx in VPC_PEERING_FIXTURES.values() if pcx["RequesterVpcInfo"]["OwnerId"] == owner_id or pcx["AccepterVpcInfo"]["OwnerId"] == owner_id ] diff --git a/tests/fixtures/route53_resolver.py b/tests/fixtures/route53_resolver.py index 6940d41..3a33a4b 100644 --- a/tests/fixtures/route53_resolver.py +++ b/tests/fixtures/route53_resolver.py @@ -521,23 +521,22 @@ def get_resolver_endpoint_by_id(endpoint_id: str) -> dict[str, Any] | None: def get_resolver_endpoints_by_vpc(vpc_id: str) -> list[dict[str, Any]]: """Get all resolver endpoints for a specific VPC.""" return [ - ep for ep in RESOLVER_ENDPOINT_FIXTURES.values() - if ep["HostVPCId"] == vpc_id + ep for ep in RESOLVER_ENDPOINT_FIXTURES.values() if ep["HostVPCId"] == vpc_id ] def get_resolver_endpoints_by_direction(direction: str) -> list[dict[str, Any]]: """Get resolver endpoints by direction (INBOUND or OUTBOUND).""" return [ - ep for ep in RESOLVER_ENDPOINT_FIXTURES.values() - if ep["Direction"] == direction + ep for ep in RESOLVER_ENDPOINT_FIXTURES.values() if ep["Direction"] == direction ] def get_operational_endpoints() -> list[dict[str, Any]]: """Get all operational resolver endpoints.""" return [ - ep for ep in RESOLVER_ENDPOINT_FIXTURES.values() + ep + for ep in RESOLVER_ENDPOINT_FIXTURES.values() if ep["Status"] == "OPERATIONAL" ] @@ -555,7 +554,8 @@ def get_resolver_rule_by_id(rule_id: str) -> dict[str, Any] | None: def get_resolver_rules_by_type(rule_type: str) -> list[dict[str, Any]]: """Get resolver rules by type (FORWARD, SYSTEM, RECURSIVE).""" return [ - rule for rule in RESOLVER_RULE_FIXTURES.values() + rule + for rule in RESOLVER_RULE_FIXTURES.values() if rule["RuleType"] == rule_type ] @@ -568,7 +568,8 @@ def get_forward_rules() -> list[dict[str, Any]]: def get_resolver_rules_by_endpoint(endpoint_id: str) -> list[dict[str, Any]]: """Get all resolver rules associated with an endpoint.""" return [ - rule for rule in RESOLVER_RULE_FIXTURES.values() + rule + for rule in RESOLVER_RULE_FIXTURES.values() if rule.get("ResolverEndpointId") == endpoint_id ] @@ -599,7 +600,8 @@ def get_query_log_config_by_id(config_id: str) -> dict[str, Any] | None: def get_query_log_configs_by_status(status: str) -> list[dict[str, Any]]: """Get query logging configurations by status (CREATED, CREATING, DELETING).""" return [ - config for config in QUERY_LOG_CONFIG_FIXTURES.values() + config + for config in QUERY_LOG_CONFIG_FIXTURES.values() if config["Status"] == status ] diff --git a/tests/fixtures/tgw.py b/tests/fixtures/tgw.py index 96bf11c..6750bec 100644 --- a/tests/fixtures/tgw.py +++ b/tests/fixtures/tgw.py @@ -751,7 +751,9 @@ def get_tgw_detail(tgw_id: str) -> dict[str, Any] | None: # Gather associated route tables route_tables = [ - rt for rt in TGW_ROUTE_TABLE_FIXTURES.values() if rt["TransitGatewayId"] == tgw_id + rt + for rt in TGW_ROUTE_TABLE_FIXTURES.values() + if rt["TransitGatewayId"] == tgw_id ] # Gather associated peerings @@ -772,12 +774,18 @@ def get_tgw_detail(tgw_id: str) -> dict[str, Any] | None: def get_attachments_by_tgw(tgw_id: str) -> list[dict[str, Any]]: """Get all attachments for a Transit Gateway.""" - return [a for a in TGW_ATTACHMENT_FIXTURES.values() if a["TransitGatewayId"] == tgw_id] + return [ + a for a in TGW_ATTACHMENT_FIXTURES.values() if a["TransitGatewayId"] == tgw_id + ] def get_route_tables_by_tgw(tgw_id: str) -> list[dict[str, Any]]: """Get all route tables for a Transit Gateway.""" - return [rt for rt in TGW_ROUTE_TABLE_FIXTURES.values() if rt["TransitGatewayId"] == tgw_id] + return [ + rt + for rt in TGW_ROUTE_TABLE_FIXTURES.values() + if rt["TransitGatewayId"] == tgw_id + ] def get_attachment_by_id(attachment_id: str) -> dict[str, Any] | None: diff --git a/tests/fixtures/vpc.py b/tests/fixtures/vpc.py index ebb65bc..7e94735 100644 --- a/tests/fixtures/vpc.py +++ b/tests/fixtures/vpc.py @@ -665,14 +665,19 @@ "FromPort": 80, "ToPort": 80, "IpRanges": [ - {"CidrIp": "0.0.0.0/0", "Description": "HTTP redirect from internet"} + { + "CidrIp": "0.0.0.0/0", + "Description": "HTTP redirect from internet", + } ], }, ], "IpPermissionsEgress": [ { "IpProtocol": "-1", - "IpRanges": [{"CidrIp": "0.0.0.0/0", "Description": "Allow all outbound"}], + "IpRanges": [ + {"CidrIp": "0.0.0.0/0", "Description": "Allow all outbound"} + ], }, ], "Tags": [ @@ -704,14 +709,19 @@ "FromPort": 22, "ToPort": 22, "IpRanges": [ - {"CidrIp": "10.100.0.0/16", "Description": "SSH from shared services"} + { + "CidrIp": "10.100.0.0/16", + "Description": "SSH from shared services", + } ], }, ], "IpPermissionsEgress": [ { "IpProtocol": "-1", - "IpRanges": [{"CidrIp": "0.0.0.0/0", "Description": "Allow all outbound"}], + "IpRanges": [ + {"CidrIp": "0.0.0.0/0", "Description": "Allow all outbound"} + ], }, ], "Tags": [ @@ -743,14 +753,19 @@ "FromPort": 5432, "ToPort": 5432, "IpRanges": [ - {"CidrIp": "10.100.10.0/24", "Description": "PostgreSQL from bastion"} + { + "CidrIp": "10.100.10.0/24", + "Description": "PostgreSQL from bastion", + } ], }, ], "IpPermissionsEgress": [ { "IpProtocol": "-1", - "IpRanges": [{"CidrIp": "0.0.0.0/0", "Description": "Allow all outbound"}], + "IpRanges": [ + {"CidrIp": "0.0.0.0/0", "Description": "Allow all outbound"} + ], }, ], "Tags": [ @@ -771,7 +786,10 @@ "FromPort": 22, "ToPort": 22, "IpRanges": [ - {"CidrIp": "203.0.113.0/24", "Description": "SSH from corporate IP range"} + { + "CidrIp": "203.0.113.0/24", + "Description": "SSH from corporate IP range", + } ], }, ], @@ -781,7 +799,10 @@ "FromPort": 22, "ToPort": 22, "IpRanges": [ - {"CidrIp": "10.0.0.0/8", "Description": "SSH to all internal networks"} + { + "CidrIp": "10.0.0.0/8", + "Description": "SSH to all internal networks", + } ], }, { @@ -808,19 +829,25 @@ "IpProtocol": "tcp", "FromPort": 8080, "ToPort": 8080, - "IpRanges": [{"CidrIp": "0.0.0.0/0", "Description": "HTTP from anywhere"}], + "IpRanges": [ + {"CidrIp": "0.0.0.0/0", "Description": "HTTP from anywhere"} + ], }, { "IpProtocol": "tcp", "FromPort": 22, "ToPort": 22, - "IpRanges": [{"CidrIp": "10.0.0.0/8", "Description": "SSH from internal"}], + "IpRanges": [ + {"CidrIp": "10.0.0.0/8", "Description": "SSH from internal"} + ], }, ], "IpPermissionsEgress": [ { "IpProtocol": "-1", - "IpRanges": [{"CidrIp": "0.0.0.0/0", "Description": "Allow all outbound"}], + "IpRanges": [ + {"CidrIp": "0.0.0.0/0", "Description": "Allow all outbound"} + ], }, ], "Tags": [ @@ -837,13 +864,17 @@ "IpPermissions": [ { "IpProtocol": "-1", - "IpRanges": [{"CidrIp": "10.0.0.0/8", "Description": "All from internal"}], + "IpRanges": [ + {"CidrIp": "10.0.0.0/8", "Description": "All from internal"} + ], }, ], "IpPermissionsEgress": [ { "IpProtocol": "-1", - "IpRanges": [{"CidrIp": "0.0.0.0/0", "Description": "Allow all outbound"}], + "IpRanges": [ + {"CidrIp": "0.0.0.0/0", "Description": "Allow all outbound"} + ], }, ], "Tags": [ @@ -1141,7 +1172,9 @@ def get_vpc_detail(vpc_id: str) -> dict[str, Any] | None: route_tables = [rt for rt in ROUTE_TABLE_FIXTURES.values() if rt["VpcId"] == vpc_id] # Gather associated security groups - security_groups = [sg for sg in SECURITY_GROUP_FIXTURES.values() if sg["VpcId"] == vpc_id] + security_groups = [ + sg for sg in SECURITY_GROUP_FIXTURES.values() if sg["VpcId"] == vpc_id + ] # Gather associated NACLs nacls = [nacl for nacl in NACL_FIXTURES.values() if nacl["VpcId"] == vpc_id] diff --git a/tests/fixtures/vpc_endpoints.py b/tests/fixtures/vpc_endpoints.py index 6cfc50a..29fae6b 100644 --- a/tests/fixtures/vpc_endpoints.py +++ b/tests/fixtures/vpc_endpoints.py @@ -68,7 +68,9 @@ "DnsName": "vpce-0prods3iface123456-abc123-eu-west-1c.s3.eu-west-1.vpce.amazonaws.com", "HostedZoneId": "Z7HUB22UULQXV", }, - {"DnsName": "bucket.vpce-0prods3iface123456-abc123.s3.eu-west-1.vpce.amazonaws.com"}, + { + "DnsName": "bucket.vpce-0prods3iface123456-abc123.s3.eu-west-1.vpce.amazonaws.com" + }, {"DnsName": "s3.eu-west-1.amazonaws.com"}, ], "CreationTimestamp": datetime(2024, 1, 15, 10, 0, 0, tzinfo=timezone.utc), @@ -302,7 +304,9 @@ "PolicyDocument": None, "RouteTableIds": [], "SubnetIds": ["subnet-0devpriv2a1234567"], - "Groups": [{"GroupId": "sg-0devall12345678901", "GroupName": "development-all-sg"}], + "Groups": [ + {"GroupId": "sg-0devall12345678901", "GroupName": "development-all-sg"} + ], "PrivateDnsEnabled": True, "RequesterManaged": False, "NetworkInterfaceIds": ["eni-0devlambep2a123456"], @@ -615,7 +619,9 @@ "VpcId": "vpc-0prod1234567890ab", "AvailabilityZone": "eu-west-1a", "Description": "VPC Endpoint Interface vpce-0prods3iface123456", - "Groups": [{"GroupId": "sg-0prodapp123456789", "GroupName": "production-app-sg"}], + "Groups": [ + {"GroupId": "sg-0prodapp123456789", "GroupName": "production-app-sg"} + ], "PrivateIpAddress": "10.0.10.100", "PrivateIpAddresses": [ { @@ -626,7 +632,7 @@ "RequesterManaged": True, "Status": "in-use", "InterfaceType": "vpc_endpoint", - "RequesterId": "AKIAIOSFODNN7EXAMPLE", + "RequesterId": "AKIAIOSFODNN7EXAMPLE", # pragma: allowlist secret "VpcEndpointId": "vpce-0prods3iface123456", }, "eni-0prods3ep1b1234567": { @@ -635,7 +641,9 @@ "VpcId": "vpc-0prod1234567890ab", "AvailabilityZone": "eu-west-1b", "Description": "VPC Endpoint Interface vpce-0prods3iface123456", - "Groups": [{"GroupId": "sg-0prodapp123456789", "GroupName": "production-app-sg"}], + "Groups": [ + {"GroupId": "sg-0prodapp123456789", "GroupName": "production-app-sg"} + ], "PrivateIpAddress": "10.0.11.100", "PrivateIpAddresses": [ { @@ -646,7 +654,7 @@ "RequesterManaged": True, "Status": "in-use", "InterfaceType": "vpc_endpoint", - "RequesterId": "AKIAIOSFODNN7EXAMPLE", + "RequesterId": "AKIAIOSFODNN7EXAMPLE", # pragma: allowlist secret "VpcEndpointId": "vpce-0prods3iface123456", }, "eni-0prods3ep1c1234567": { @@ -655,7 +663,9 @@ "VpcId": "vpc-0prod1234567890ab", "AvailabilityZone": "eu-west-1c", "Description": "VPC Endpoint Interface vpce-0prods3iface123456", - "Groups": [{"GroupId": "sg-0prodapp123456789", "GroupName": "production-app-sg"}], + "Groups": [ + {"GroupId": "sg-0prodapp123456789", "GroupName": "production-app-sg"} + ], "PrivateIpAddress": "10.0.12.100", "PrivateIpAddresses": [ { @@ -666,7 +676,7 @@ "RequesterManaged": True, "Status": "in-use", "InterfaceType": "vpc_endpoint", - "RequesterId": "AKIAIOSFODNN7EXAMPLE", + "RequesterId": "AKIAIOSFODNN7EXAMPLE", # pragma: allowlist secret "VpcEndpointId": "vpce-0prods3iface123456", }, # Production DynamoDB Interface Endpoint ENIs @@ -676,7 +686,9 @@ "VpcId": "vpc-0prod1234567890ab", "AvailabilityZone": "eu-west-1a", "Description": "VPC Endpoint Interface vpce-0proddynamoiface12", - "Groups": [{"GroupId": "sg-0prodapp123456789", "GroupName": "production-app-sg"}], + "Groups": [ + {"GroupId": "sg-0prodapp123456789", "GroupName": "production-app-sg"} + ], "PrivateIpAddress": "10.0.10.101", "PrivateIpAddresses": [ { @@ -687,7 +699,7 @@ "RequesterManaged": True, "Status": "in-use", "InterfaceType": "vpc_endpoint", - "RequesterId": "AKIAIOSFODNN7EXAMPLE", + "RequesterId": "AKIAIOSFODNN7EXAMPLE", # pragma: allowlist secret "VpcEndpointId": "vpce-0proddynamoiface12", }, "eni-0proddynep1b123456": { @@ -696,7 +708,9 @@ "VpcId": "vpc-0prod1234567890ab", "AvailabilityZone": "eu-west-1b", "Description": "VPC Endpoint Interface vpce-0proddynamoiface12", - "Groups": [{"GroupId": "sg-0prodapp123456789", "GroupName": "production-app-sg"}], + "Groups": [ + {"GroupId": "sg-0prodapp123456789", "GroupName": "production-app-sg"} + ], "PrivateIpAddress": "10.0.11.101", "PrivateIpAddresses": [ { @@ -707,7 +721,7 @@ "RequesterManaged": True, "Status": "in-use", "InterfaceType": "vpc_endpoint", - "RequesterId": "AKIAIOSFODNN7EXAMPLE", + "RequesterId": "AKIAIOSFODNN7EXAMPLE", # pragma: allowlist secret "VpcEndpointId": "vpce-0proddynamoiface12", }, # Production Lambda Interface Endpoint ENIs @@ -717,7 +731,9 @@ "VpcId": "vpc-0prod1234567890ab", "AvailabilityZone": "eu-west-1a", "Description": "VPC Endpoint Interface vpce-0prodlambdaiface12", - "Groups": [{"GroupId": "sg-0prodapp123456789", "GroupName": "production-app-sg"}], + "Groups": [ + {"GroupId": "sg-0prodapp123456789", "GroupName": "production-app-sg"} + ], "PrivateIpAddress": "10.0.10.102", "PrivateIpAddresses": [ { @@ -728,7 +744,7 @@ "RequesterManaged": True, "Status": "in-use", "InterfaceType": "vpc_endpoint", - "RequesterId": "AKIAIOSFODNN7EXAMPLE", + "RequesterId": "AKIAIOSFODNN7EXAMPLE", # pragma: allowlist secret "VpcEndpointId": "vpce-0prodlambdaiface12", }, "eni-0prodlambep1b12345": { @@ -737,7 +753,9 @@ "VpcId": "vpc-0prod1234567890ab", "AvailabilityZone": "eu-west-1b", "Description": "VPC Endpoint Interface vpce-0prodlambdaiface12", - "Groups": [{"GroupId": "sg-0prodapp123456789", "GroupName": "production-app-sg"}], + "Groups": [ + {"GroupId": "sg-0prodapp123456789", "GroupName": "production-app-sg"} + ], "PrivateIpAddress": "10.0.11.102", "PrivateIpAddresses": [ { @@ -748,7 +766,7 @@ "RequesterManaged": True, "Status": "in-use", "InterfaceType": "vpc_endpoint", - "RequesterId": "AKIAIOSFODNN7EXAMPLE", + "RequesterId": "AKIAIOSFODNN7EXAMPLE", # pragma: allowlist secret "VpcEndpointId": "vpce-0prodlambdaiface12", }, "eni-0prodlambep1c12345": { @@ -757,7 +775,9 @@ "VpcId": "vpc-0prod1234567890ab", "AvailabilityZone": "eu-west-1c", "Description": "VPC Endpoint Interface vpce-0prodlambdaiface12", - "Groups": [{"GroupId": "sg-0prodapp123456789", "GroupName": "production-app-sg"}], + "Groups": [ + {"GroupId": "sg-0prodapp123456789", "GroupName": "production-app-sg"} + ], "PrivateIpAddress": "10.0.12.102", "PrivateIpAddresses": [ { @@ -768,7 +788,7 @@ "RequesterManaged": True, "Status": "in-use", "InterfaceType": "vpc_endpoint", - "RequesterId": "AKIAIOSFODNN7EXAMPLE", + "RequesterId": "AKIAIOSFODNN7EXAMPLE", # pragma: allowlist secret "VpcEndpointId": "vpce-0prodlambdaiface12", }, # Production EC2 API Interface Endpoint ENIs @@ -778,7 +798,9 @@ "VpcId": "vpc-0prod1234567890ab", "AvailabilityZone": "eu-west-1a", "Description": "VPC Endpoint Interface vpce-0prodec2apiface123", - "Groups": [{"GroupId": "sg-0prodapp123456789", "GroupName": "production-app-sg"}], + "Groups": [ + {"GroupId": "sg-0prodapp123456789", "GroupName": "production-app-sg"} + ], "PrivateIpAddress": "10.0.10.103", "PrivateIpAddresses": [ { @@ -789,7 +811,7 @@ "RequesterManaged": True, "Status": "in-use", "InterfaceType": "vpc_endpoint", - "RequesterId": "AKIAIOSFODNN7EXAMPLE", + "RequesterId": "AKIAIOSFODNN7EXAMPLE", # pragma: allowlist secret "VpcEndpointId": "vpce-0prodec2apiface123", }, "eni-0prodec2ep1b123456": { @@ -798,7 +820,9 @@ "VpcId": "vpc-0prod1234567890ab", "AvailabilityZone": "eu-west-1b", "Description": "VPC Endpoint Interface vpce-0prodec2apiface123", - "Groups": [{"GroupId": "sg-0prodapp123456789", "GroupName": "production-app-sg"}], + "Groups": [ + {"GroupId": "sg-0prodapp123456789", "GroupName": "production-app-sg"} + ], "PrivateIpAddress": "10.0.11.103", "PrivateIpAddresses": [ { @@ -809,7 +833,7 @@ "RequesterManaged": True, "Status": "in-use", "InterfaceType": "vpc_endpoint", - "RequesterId": "AKIAIOSFODNN7EXAMPLE", + "RequesterId": "AKIAIOSFODNN7EXAMPLE", # pragma: allowlist secret "VpcEndpointId": "vpce-0prodec2apiface123", }, # Shared Services - Custom App Service ENIs @@ -819,7 +843,9 @@ "VpcId": "vpc-0prod1234567890ab", "AvailabilityZone": "eu-west-1a", "Description": "VPC Endpoint Interface vpce-0sharedappservice12", - "Groups": [{"GroupId": "sg-0prodapp123456789", "GroupName": "production-app-sg"}], + "Groups": [ + {"GroupId": "sg-0prodapp123456789", "GroupName": "production-app-sg"} + ], "PrivateIpAddress": "10.0.10.150", "PrivateIpAddresses": [ { @@ -830,7 +856,7 @@ "RequesterManaged": True, "Status": "in-use", "InterfaceType": "vpc_endpoint", - "RequesterId": "AKIAIOSFODNN7EXAMPLE", + "RequesterId": "AKIAIOSFODNN7EXAMPLE", # pragma: allowlist secret "VpcEndpointId": "vpce-0sharedappservice12", }, "eni-0sharedapp1b123456": { @@ -839,7 +865,9 @@ "VpcId": "vpc-0prod1234567890ab", "AvailabilityZone": "eu-west-1b", "Description": "VPC Endpoint Interface vpce-0sharedappservice12", - "Groups": [{"GroupId": "sg-0prodapp123456789", "GroupName": "production-app-sg"}], + "Groups": [ + {"GroupId": "sg-0prodapp123456789", "GroupName": "production-app-sg"} + ], "PrivateIpAddress": "10.0.11.150", "PrivateIpAddresses": [ { @@ -850,7 +878,7 @@ "RequesterManaged": True, "Status": "in-use", "InterfaceType": "vpc_endpoint", - "RequesterId": "AKIAIOSFODNN7EXAMPLE", + "RequesterId": "AKIAIOSFODNN7EXAMPLE", # pragma: allowlist secret "VpcEndpointId": "vpce-0sharedappservice12", }, # Staging S3 Interface Endpoint ENI (pending state) @@ -871,7 +899,7 @@ "RequesterManaged": True, "Status": "in-use", "InterfaceType": "vpc_endpoint", - "RequesterId": "AKIAIOSFODNN7EXAMPLE", + "RequesterId": "AKIAIOSFODNN7EXAMPLE", # pragma: allowlist secret "VpcEndpointId": "vpce-0stags3iface123456", }, # Development Lambda Interface Endpoint ENI (deleting state) @@ -881,7 +909,9 @@ "VpcId": "vpc-0dev01234567890ab", "AvailabilityZone": "ap-southeast-2a", "Description": "VPC Endpoint Interface vpce-0devlambdaiface123", - "Groups": [{"GroupId": "sg-0devall12345678901", "GroupName": "development-all-sg"}], + "Groups": [ + {"GroupId": "sg-0devall12345678901", "GroupName": "development-all-sg"} + ], "PrivateIpAddress": "10.2.10.100", "PrivateIpAddresses": [ { @@ -892,7 +922,7 @@ "RequesterManaged": True, "Status": "available", "InterfaceType": "vpc_endpoint", - "RequesterId": "AKIAIOSFODNN7EXAMPLE", + "RequesterId": "AKIAIOSFODNN7EXAMPLE", # pragma: allowlist secret "VpcEndpointId": "vpce-0devlambdaiface123", }, } @@ -927,9 +957,7 @@ def get_gateway_endpoints_by_vpc(vpc_id: str) -> list[dict[str, Any]]: def get_all_endpoints_by_vpc(vpc_id: str) -> list[dict[str, Any]]: """Get all VPC endpoints (interface and gateway) in a VPC.""" - return get_interface_endpoints_by_vpc(vpc_id) + get_gateway_endpoints_by_vpc( - vpc_id - ) + return get_interface_endpoints_by_vpc(vpc_id) + get_gateway_endpoints_by_vpc(vpc_id) def get_endpoint_service_by_id(service_id: str) -> dict[str, Any] | None: diff --git a/tests/generate_report.py b/tests/generate_report.py index 4dbc4ca..db6102b 100755 --- a/tests/generate_report.py +++ b/tests/generate_report.py @@ -29,9 +29,9 @@ def generate_summary_section(report: dict) -> str: |--------|-------| | **Timestamp** | {timestamp} | | **Profile** | {profile} | -| **Total Tests** | {summary.get('total', 0)} | -| **Passed** | ✅ {summary.get('passed', 0)} | -| **Failed** | ❌ {summary.get('failed', 0)} | +| **Total Tests** | {summary.get("total", 0)} | +| **Passed** | ✅ {summary.get("passed", 0)} | +| **Failed** | ❌ {summary.get("failed", 0)} | | **Pass Rate** | {_calc_pass_rate(summary)}% | """ @@ -72,7 +72,9 @@ def generate_results_by_phase(results: list[dict]) -> str: status = "✅ PASS" if r["passed"] else "❌ FAIL" details = "; ".join(r.get("details", [])[:2]) # First 2 details details = details[:50] + "..." if len(details) > 50 else details - command = r["command"][:30] + "..." if len(r["command"]) > 30 else r["command"] + command = ( + r["command"][:30] + "..." if len(r["command"]) > 30 else r["command"] + ) output.append(f"| {r['test_id']} | `{command}` | {status} | {details} |") output.append("") @@ -149,28 +151,36 @@ def generate_recommendations(results: list[dict]) -> str: recommendations.append("The following tests showed count discrepancies:") for tid in error_types["count_mismatch"]: recommendations.append(f"- {tid}") - recommendations.append("\n**Fix**: Check if shell is filtering resources differently than AWS CLI.\n") + recommendations.append( + "\n**Fix**: Check if shell is filtering resources differently than AWS CLI.\n" + ) if error_types["missing_ids"]: recommendations.append("### Missing Resource IDs") recommendations.append("The following tests had missing resource IDs:") for tid in error_types["missing_ids"]: recommendations.append(f"- {tid}") - recommendations.append("\n**Fix**: Verify shell is querying all regions/resources.\n") + recommendations.append( + "\n**Fix**: Verify shell is querying all regions/resources.\n" + ) if error_types["execution_error"]: recommendations.append("### Execution Errors") recommendations.append("The following tests had execution errors:") for tid in error_types["execution_error"]: recommendations.append(f"- {tid}") - recommendations.append("\n**Fix**: Check command syntax and handler implementation.\n") + recommendations.append( + "\n**Fix**: Check command syntax and handler implementation.\n" + ) if error_types["output_error"]: recommendations.append("### Output Errors") recommendations.append("The following tests had errors in output:") for tid in error_types["output_error"]: recommendations.append(f"- {tid}") - recommendations.append("\n**Fix**: Review handler error handling and edge cases.\n") + recommendations.append( + "\n**Fix**: Review handler error handling and edge cases.\n" + ) return "\n".join(recommendations) @@ -219,7 +229,9 @@ def generate_markdown_report(report: dict) -> str: def main(): - parser = argparse.ArgumentParser(description="Generate test report from JSON results") + parser = argparse.ArgumentParser( + description="Generate test report from JSON results" + ) parser.add_argument( "--input", default="/tmp/test_results.json", diff --git a/tests/integration/test_github_issues.py b/tests/integration/test_github_issues.py index 57423d6..f615048 100644 --- a/tests/integration/test_github_issues.py +++ b/tests/integration/test_github_issues.py @@ -17,11 +17,11 @@ """ import pytest -import os # Try importing pexpect, skip all tests if not available try: import pexpect + PEXPECT_AVAILABLE = True except ImportError: PEXPECT_AVAILABLE = False @@ -46,34 +46,34 @@ class TestIssue10_ELB_NoOutput: def test_issue_10_show_listeners(self): """Binary: show listeners should return listener data, not 'No listeners'.""" # Spawn real shell - child = pexpect.spawn('aws-net-shell', timeout=10) + child = pexpect.spawn("aws-net-shell", timeout=10) try: # Wait for prompt - child.expect('aws-net>') + child.expect("aws-net>") # Replicate exact user workflow from issue - child.sendline('set elb Github-ALB') - child.expect('aws-net/el:Github-ALB>') + child.sendline("set elb Github-ALB") + child.expect("aws-net/el:Github-ALB>") # Verify detail works (baseline) - child.sendline('show detail') - child.expect('aws-net/el:Github-ALB>') - detail_output = child.before.decode('utf-8') - assert 'Load Balancer: Github-ALB' in detail_output + child.sendline("show detail") + child.expect("aws-net/el:Github-ALB>") + detail_output = child.before.decode("utf-8") + assert "Load Balancer: Github-ALB" in detail_output # CRITICAL TEST: show listeners - child.sendline('show listeners') - child.expect('aws-net/el:Github-ALB>') - listeners_output = child.before.decode('utf-8') + child.sendline("show listeners") + child.expect("aws-net/el:Github-ALB>") + listeners_output = child.before.decode("utf-8") # Binary FAIL condition from issue - assert 'No listeners' not in listeners_output, ( - f"Issue #10 REPRODUCED: show listeners returned 'No listeners'" + assert "No listeners" not in listeners_output, ( + "Issue #10 REPRODUCED: show listeners returned 'No listeners'" ) # Binary PASS condition - assert ('Listener' in listeners_output or 'Port' in listeners_output), ( + assert "Listener" in listeners_output or "Port" in listeners_output, ( f"Expected listener data in output:\n{listeners_output}" ) @@ -96,33 +96,33 @@ class TestIssue9_EC2_AllENIs: @pytest.mark.issue_9 def test_issue_9_show_enis_filtered(self): """Binary: show enis should return ONLY instance ENI, not all account ENIs.""" - child = pexpect.spawn('aws-net-shell', timeout=10) + child = pexpect.spawn("aws-net-shell", timeout=10) try: - child.expect('aws-net>') + child.expect("aws-net>") # Replicate exact user workflow - child.sendline('set ec2-instance i-011280e2844a5f00d') - child.expect('aws-net/ec:AWS-Github>') + child.sendline("set ec2-instance i-011280e2844a5f00d") + child.expect("aws-net/ec:AWS-Github>") # Verify detail works - child.sendline('show detail') - child.expect('aws-net/ec:AWS-Github>') - detail_output = child.before.decode('utf-8') - assert 'i-011280e2844a5f00d' in detail_output + child.sendline("show detail") + child.expect("aws-net/ec:AWS-Github>") + detail_output = child.before.decode("utf-8") + assert "i-011280e2844a5f00d" in detail_output # CRITICAL TEST: show enis should be filtered - child.sendline('show enis') - child.expect('aws-net/ec:AWS-Github>') - enis_output = child.before.decode('utf-8') + child.sendline("show enis") + child.expect("aws-net/ec:AWS-Github>") + enis_output = child.before.decode("utf-8") # Binary PASS: Should show instance's ENI - assert 'eni-0989f6e6ce4dfc707' in enis_output, ( + assert "eni-0989f6e6ce4dfc707" in enis_output, ( "Instance ENI not found in output" ) # Binary FAIL: Count ENIs (should be 1, not 150+) - eni_count = enis_output.count('eni-') + eni_count = enis_output.count("eni-") # From issue, user sees 150+ ENIs when there should be 1 assert eni_count <= 5, ( diff --git a/tests/integration/test_issue_9_10_simple.py b/tests/integration/test_issue_9_10_simple.py index dbdd356..8bb69da 100644 --- a/tests/integration/test_issue_9_10_simple.py +++ b/tests/integration/test_issue_9_10_simple.py @@ -8,7 +8,6 @@ import pytest import os -import sys # Import pexpect (required dependency) import pexpect @@ -23,29 +22,30 @@ def shell_process(): # Check if aws-net-shell command exists try: import subprocess - subprocess.run(['which', 'aws-net-shell'], check=True, capture_output=True) - except: + + subprocess.run(["which", "aws-net-shell"], check=True, capture_output=True) + except (subprocess.CalledProcessError, FileNotFoundError): pytest.skip("aws-net-shell CLI not installed - run: pip install -e .") # Set AWS profile if provided env = os.environ.copy() - if 'AWS_PROFILE' not in env: - env['AWS_PROFILE'] = 'taylaand+net-dev-Admin' # Default test profile + if "AWS_PROFILE" not in env: + env["AWS_PROFILE"] = "taylaand+net-dev-Admin" # Default test profile # Spawn interactive shell process - child = pexpect.spawn('aws-net-shell', timeout=15, encoding='utf-8', env=env) - + child = pexpect.spawn("aws-net-shell", timeout=15, encoding="utf-8", env=env) + try: # Wait for initial prompt - child.expect('aws-net>', timeout=5) + child.expect("aws-net>", timeout=5) yield child finally: # Cleanup if child.isalive(): - child.sendline('exit') + child.sendline("exit") try: child.expect(pexpect.EOF, timeout=2) - except: + except (pexpect.TIMEOUT, pexpect.EOF): child.terminate(force=True) @@ -53,69 +53,69 @@ def shell_process(): @pytest.mark.issue_10 class TestIssue10_ELB_NoOutput: """GitHub Issue #10: ELB show commands return no data. - + User workflow: 1. set elb Github-ALB 2. show detail (works) 3. show listeners (FAILS - returns "No listeners") - 4. show targets (FAILS - returns "No target groups") + 4. show targets (FAILS - returns "No target groups") 5. show health (FAILS - returns "No health data") """ - + def test_elb_show_listeners_has_data(self, shell_process): """Binary: show listeners should return listener data.""" child = shell_process - + # Set ELB context - child.sendline('set elb Github-ALB') - child.expect(['aws-net.*>', pexpect.TIMEOUT], timeout=10) - + child.sendline("set elb Github-ALB") + child.expect(["aws-net.*>", pexpect.TIMEOUT], timeout=10) + # Verify we're in ELB context (may show error if ELB doesn't exist) - output = child.before - + _output = child.before # Not asserted, just ensuring command completes + # Show listeners - CRITICAL TEST - child.sendline('show listeners') - child.expect(['aws-net.*>', pexpect.TIMEOUT], timeout=10) + child.sendline("show listeners") + child.expect(["aws-net.*>", pexpect.TIMEOUT], timeout=10) listeners_output = child.before - + print(f"\n[Issue #10 Test] Listeners output:\n{listeners_output}") - + # Binary assertion from issue - assert 'No listeners' not in listeners_output, ( + assert "No listeners" not in listeners_output, ( "Issue #10 CONFIRMED: show listeners returned 'No listeners'" ) - + def test_elb_show_targets_has_data(self, shell_process): """Binary: show targets should return target group data.""" child = shell_process - - child.sendline('set elb Github-ALB') - child.expect(['aws-net.*>', pexpect.TIMEOUT], timeout=10) - - child.sendline('show targets') - child.expect(['aws-net.*>', pexpect.TIMEOUT], timeout=10) + + child.sendline("set elb Github-ALB") + child.expect(["aws-net.*>", pexpect.TIMEOUT], timeout=10) + + child.sendline("show targets") + child.expect(["aws-net.*>", pexpect.TIMEOUT], timeout=10) targets_output = child.before - + print(f"\n[Issue #10 Test] Targets output:\n{targets_output}") - - assert 'No target groups' not in targets_output, ( + + assert "No target groups" not in targets_output, ( "Issue #10 CONFIRMED: show targets returned 'No target groups'" ) - + def test_elb_show_health_has_data(self, shell_process): """Binary: show health should return health check data.""" child = shell_process - - child.sendline('set elb Github-ALB') - child.expect(['aws-net.*>', pexpect.TIMEOUT], timeout=10) - - child.sendline('show health') - child.expect(['aws-net.*>', pexpect.TIMEOUT], timeout=10) + + child.sendline("set elb Github-ALB") + child.expect(["aws-net.*>", pexpect.TIMEOUT], timeout=10) + + child.sendline("show health") + child.expect(["aws-net.*>", pexpect.TIMEOUT], timeout=10) health_output = child.before - + print(f"\n[Issue #10 Test] Health output:\n{health_output}") - - assert 'No health data' not in health_output, ( + + assert "No health data" not in health_output, ( "Issue #10 CONFIRMED: show health returned 'No health data'" ) @@ -124,41 +124,41 @@ def test_elb_show_health_has_data(self, shell_process): @pytest.mark.issue_9 class TestIssue9_EC2_ReturnsAllENIs: """GitHub Issue #9: EC2 context returns ALL ENIs, not instance-specific. - + User workflow: 1. set ec2-instance i-011280e2844a5f00d 2. show detail (works) 3. show enis (FAILS - shows 150+ ENIs instead of 1) """ - + def test_ec2_show_enis_filtered_to_instance(self, shell_process): """Binary: show enis should return ONLY instance ENIs.""" child = shell_process - + # Set EC2 context - child.sendline('set ec2-instance i-011280e2844a5f00d') - child.expect(['aws-net.*>', pexpect.TIMEOUT], timeout=10) - + child.sendline("set ec2-instance i-011280e2844a5f00d") + child.expect(["aws-net.*>", pexpect.TIMEOUT], timeout=10) + # Show ENIs - CRITICAL TEST - child.sendline('show enis') - child.expect(['aws-net.*>', pexpect.TIMEOUT], timeout=15) + child.sendline("show enis") + child.expect(["aws-net.*>", pexpect.TIMEOUT], timeout=15) enis_output = child.before - + print(f"\n[Issue #9 Test] ENIs output (first 500 chars):\n{enis_output[:500]}") - + # Count ENIs in output - eni_count = enis_output.count('eni-') + eni_count = enis_output.count("eni-") print(f"\n[Issue #9 Test] Total ENI count: {eni_count}") - + # Binary assertion: Instance has 1 ENI, should NOT see 50+ ENIs assert eni_count <= 5, ( f"Issue #9 CONFIRMED: show enis returned {eni_count} ENIs. " f"Expected 1-2 (instance-specific), got {eni_count} (all account ENIs)" ) - + # Positive assertion: Should show the instance's actual ENI # eni-0989f6e6ce4dfc707 from issue description - assert 'eni-0989f6e6ce4dfc707' in enis_output or eni_count <= 2, ( + assert "eni-0989f6e6ce4dfc707" in enis_output or eni_count <= 2, ( "Expected instance-specific ENI not found" ) diff --git a/tests/integration/test_workflows.py b/tests/integration/test_workflows.py index c0d1340..a55fc80 100644 --- a/tests/integration/test_workflows.py +++ b/tests/integration/test_workflows.py @@ -31,7 +31,7 @@ def load_workflows(): for yaml_file in workflow_dir.glob("*.yaml"): with open(yaml_file) as f: workflow = yaml.safe_load(f) - workflow['_source_file'] = yaml_file.name + workflow["_source_file"] = yaml_file.name workflows.append(workflow) return workflows @@ -59,18 +59,18 @@ def test_workflow(self, workflow): runner.start() # Execute each step in workflow - for step in workflow.get('workflow', []): - command = step['command'] + for step in workflow.get("workflow", []): + command = step["command"] output = runner.run(command) # Check expect_contains if specified - if 'expect_contains' in step: - assert step['expect_contains'] in output, ( + if "expect_contains" in step: + assert step["expect_contains"] in output, ( f"Step '{command}' expected '{step['expect_contains']}' in output" ) # Run assertions if specified - for assertion in step.get('assertions', []): + for assertion in step.get("assertions", []): self._evaluate_assertion(assertion, output, command) finally: @@ -85,42 +85,42 @@ def _evaluate_assertion(self, assertion: dict, output: str, command: str): - contains_any: At least one value must be in output - eni_count: Count ENIs with operator """ - assertion_type = assertion['type'] + assertion_type = assertion["type"] - if assertion_type == 'contains': - assert assertion['value'] in output, ( + if assertion_type == "contains": + assert assertion["value"] in output, ( f"Command '{command}': Expected '{assertion['value']}' in output.\n" f"Message: {assertion.get('message', '')}" ) - elif assertion_type == 'not_contains': - assert assertion['value'] not in output, ( + elif assertion_type == "not_contains": + assert assertion["value"] not in output, ( f"Command '{command}': Did NOT expect '{assertion['value']}' in output.\n" f"Message: {assertion.get('message', '')}\n" f"This assertion has severity: {assertion.get('severity', 'normal')}" ) - elif assertion_type == 'contains_any': - values = assertion['values'] + elif assertion_type == "contains_any": + values = assertion["values"] found = any(v in output for v in values) assert found, ( f"Command '{command}': Expected at least one of {values} in output.\n" f"Message: {assertion.get('message', '')}" ) - elif assertion_type == 'eni_count': - eni_count = output.count('eni-') - operator = assertion['operator'] - expected = assertion['value'] + elif assertion_type == "eni_count": + eni_count = output.count("eni-") + operator = assertion["operator"] + expected = assertion["value"] - if operator == '<=': + if operator == "<=": assert eni_count <= expected, ( f"Command '{command}': ENI count {eni_count} not <= {expected}.\n" f"Message: {assertion.get('message', '')}" ) - elif operator == '==': + elif operator == "==": assert eni_count == expected, f"ENI count {eni_count} != {expected}" - elif operator == '>=': + elif operator == ">=": assert eni_count >= expected, f"ENI count {eni_count} not >= {expected}" else: diff --git a/tests/interactive_routing_cache_test.py b/tests/interactive_routing_cache_test.py index 71ac13a..2bc59e3 100755 --- a/tests/interactive_routing_cache_test.py +++ b/tests/interactive_routing_cache_test.py @@ -27,7 +27,7 @@ def main(): # Test 1: Create routing cache print("\n1️⃣ Creating routing cache...") print("-" * 70) - shell.onecmd('create_routing_cache') + shell.onecmd("create_routing_cache") # Validate cache created cache = shell._cache.get("routing-cache", {}) @@ -36,7 +36,7 @@ def main(): cloudwan_count = len(cache.get("cloudwan", {}).get("routes", [])) total = vpc_count + tgw_count + cloudwan_count - print(f"\n✓ Cache created:") + print("\n✓ Cache created:") print(f" VPC: {vpc_count} routes") print(f" Transit Gateway: {tgw_count} routes") print(f" Cloud WAN: {cloudwan_count} routes") @@ -50,22 +50,22 @@ def main(): # Test 2: Show summary print("\n2️⃣ Showing cache summary...") print("-" * 70) - shell.onecmd('show routing-cache') + shell.onecmd("show routing-cache") # Test 3: Show Transit Gateway routes print("\n3️⃣ Showing Transit Gateway routes...") print("-" * 70) - shell.onecmd('show routing-cache transit-gateway') + shell.onecmd("show routing-cache transit-gateway") # Test 4: Show Cloud WAN routes print("\n4️⃣ Showing Cloud WAN routes...") print("-" * 70) - shell.onecmd('show routing-cache cloud-wan') + shell.onecmd("show routing-cache cloud-wan") # Test 5: Save to SQLite print("\n5️⃣ Saving to SQLite database...") print("-" * 70) - shell.onecmd('save_routing_cache') + shell.onecmd("save_routing_cache") # Validate DB file exists db = CacheDB() @@ -76,7 +76,7 @@ def main(): print("\n6️⃣ Loading from SQLite database...") print("-" * 70) shell._cache.clear() # Clear memory cache - shell.onecmd('load_routing_cache') + shell.onecmd("load_routing_cache") # Validate loaded correctly loaded_cache = shell._cache.get("routing-cache", {}) @@ -89,7 +89,7 @@ def main(): # Test 7: Show all routes print("\n7️⃣ Showing all routes...") print("-" * 70) - shell.onecmd('show routing-cache all') + shell.onecmd("show routing-cache all") # Final summary print("\n" + "=" * 70) @@ -116,5 +116,6 @@ def main(): except Exception as e: print(f"\n❌ ERROR: {e}") import traceback + traceback.print_exc() sys.exit(1) diff --git a/tests/test_cloudwan_branch.py b/tests/test_cloudwan_branch.py index cf51cd4..b84229d 100644 --- a/tests/test_cloudwan_branch.py +++ b/tests/test_cloudwan_branch.py @@ -3,7 +3,7 @@ Tests the nested context chain: 1. Root: show global-networks 2. set global-network 1 → global-network context -3. show core-networks +3. show core-networks 4. set core-network 1 → core-network context 5. show route-tables 6. set route-table 1 → route-table context @@ -45,21 +45,56 @@ "CreatedAt": "2024-01-10T09:00:00+00:00", "State": "AVAILABLE", "Segments": [ - {"Name": "production", "EdgeLocations": ["eu-west-1", "us-east-1"], "SharedSegments": []}, - {"Name": "shared-services", "EdgeLocations": ["eu-west-1", "us-east-1"], "SharedSegments": ["production"]}, + { + "Name": "production", + "EdgeLocations": ["eu-west-1", "us-east-1"], + "SharedSegments": [], + }, + { + "Name": "shared-services", + "EdgeLocations": ["eu-west-1", "us-east-1"], + "SharedSegments": ["production"], + }, ], "Edges": [ - {"EdgeLocation": "eu-west-1", "Asn": 64520, "InsideCidrBlocks": ["169.254.0.0/24"]}, - {"EdgeLocation": "us-east-1", "Asn": 64521, "InsideCidrBlocks": ["169.254.1.0/24"]}, + { + "EdgeLocation": "eu-west-1", + "Asn": 64520, + "InsideCidrBlocks": ["169.254.0.0/24"], + }, + { + "EdgeLocation": "us-east-1", + "Asn": 64521, + "InsideCidrBlocks": ["169.254.1.0/24"], + }, ], "Tags": [{"Key": "Name", "Value": "prod-core-network"}], }, ] CORE_NETWORK_ROUTES = [ - {"DestinationCidrBlock": "10.0.0.0/16", "Destinations": [{"CoreNetworkAttachmentId": "attachment-001", "SegmentName": "production"}], "Type": "PROPAGATED", "State": "ACTIVE"}, - {"DestinationCidrBlock": "10.1.0.0/16", "Destinations": [{"CoreNetworkAttachmentId": "attachment-002", "SegmentName": "production"}], "Type": "PROPAGATED", "State": "ACTIVE"}, - {"DestinationCidrBlock": "10.2.0.0/16", "Destinations": [], "Type": "PROPAGATED", "State": "BLACKHOLE"}, + { + "DestinationCidrBlock": "10.0.0.0/16", + "Destinations": [ + {"CoreNetworkAttachmentId": "attachment-001", "SegmentName": "production"} + ], + "Type": "PROPAGATED", + "State": "ACTIVE", + }, + { + "DestinationCidrBlock": "10.1.0.0/16", + "Destinations": [ + {"CoreNetworkAttachmentId": "attachment-002", "SegmentName": "production"} + ], + "Type": "PROPAGATED", + "State": "ACTIVE", + }, + { + "DestinationCidrBlock": "10.2.0.0/16", + "Destinations": [], + "Type": "PROPAGATED", + "State": "BLACKHOLE", + }, ] @@ -67,6 +102,7 @@ # FIXTURES # ============================================================================= + @pytest.fixture def mock_shell(): """Create mock shell with required attributes.""" @@ -81,6 +117,7 @@ def mock_shell(): # TEST: ROOT - show global-networks # ============================================================================= + class TestRootShowGlobalNetworks: """Test show global-networks at root level.""" @@ -88,13 +125,16 @@ def test_show_global_networks_calls_api(self): """show global-networks should call describe_global_networks.""" with patch("aws_network_tools.modules.cloudwan.CloudWANClient") as MockClient: mock_nm = MagicMock() - mock_nm.describe_global_networks.return_value = {"GlobalNetworks": GLOBAL_NETWORKS} + mock_nm.describe_global_networks.return_value = { + "GlobalNetworks": GLOBAL_NETWORKS + } MockClient.return_value.nm = mock_nm - + from aws_network_tools.modules.cloudwan import CloudWANClient + client = CloudWANClient("default") result = client.nm.describe_global_networks() - + assert "GlobalNetworks" in result assert len(result["GlobalNetworks"]) == 2 @@ -110,6 +150,7 @@ def test_global_networks_have_required_fields(self): # TEST: GLOBAL-NETWORK CONTEXT - show core-networks # ============================================================================= + class TestGlobalNetworkContext: """Test commands in global-network context.""" @@ -119,24 +160,31 @@ def test_enter_global_network_context(self, mock_shell): mock_shell.current_context = "global-network" mock_shell.context_data["global_network"] = gn mock_shell.context_data["global_network_id"] = gn["GlobalNetworkId"] - + assert mock_shell.current_context == "global-network" - assert mock_shell.context_data["global_network_id"] == "global-network-0abc123def456" + assert ( + mock_shell.context_data["global_network_id"] + == "global-network-0abc123def456" + ) def test_show_core_networks_in_context(self): """show core-networks should list core networks for global network.""" with patch("aws_network_tools.modules.cloudwan.CloudWANClient") as MockClient: mock_nm = MagicMock() mock_nm.list_core_networks.return_value = { - "CoreNetworks": [{"CoreNetworkId": cn["CoreNetworkId"], "State": cn["State"]} for cn in CORE_NETWORKS] + "CoreNetworks": [ + {"CoreNetworkId": cn["CoreNetworkId"], "State": cn["State"]} + for cn in CORE_NETWORKS + ] } mock_nm.get_core_network.return_value = {"CoreNetwork": CORE_NETWORKS[0]} MockClient.return_value.nm = mock_nm - + from aws_network_tools.modules.cloudwan import CloudWANClient + client = CloudWANClient("default") result = client.nm.list_core_networks() - + assert len(result["CoreNetworks"]) >= 1 @@ -144,6 +192,7 @@ def test_show_core_networks_in_context(self): # TEST: CORE-NETWORK CONTEXT # ============================================================================= + class TestCoreNetworkContext: """Test commands in core-network context.""" @@ -153,7 +202,7 @@ def test_enter_core_network_context(self, mock_shell): mock_shell.current_context = "core-network" mock_shell.context_data["core_network"] = cn mock_shell.context_data["core_network_id"] = cn["CoreNetworkId"] - + assert mock_shell.current_context == "core-network" assert mock_shell.context_data["core_network_id"] == "core-network-0prod123456" @@ -161,7 +210,7 @@ def test_core_network_has_segments(self, mock_shell): """Core network should have segments (become route tables).""" cn = CORE_NETWORKS[0] mock_shell.context_data["core_network"] = cn - + segments = cn.get("Segments", []) assert len(segments) >= 1 assert segments[0]["Name"] == "production" @@ -170,30 +219,33 @@ def test_show_route_tables_lists_segment_edge_combos(self, mock_shell): """show route-tables should list segment/edge combinations.""" cn = CORE_NETWORKS[0] mock_shell.context_data["core_network"] = cn - + # Route tables = segment × edge combinations route_tables = [] for segment in cn.get("Segments", []): for edge in segment.get("EdgeLocations", []): route_tables.append({"segment": segment["Name"], "edge": edge}) - + assert len(route_tables) >= 2 # production has 2 edges def test_show_routes_calls_api(self): """show routes should call get_core_network_routes.""" with patch("aws_network_tools.modules.cloudwan.CloudWANClient") as MockClient: mock_nm = MagicMock() - mock_nm.get_core_network_routes.return_value = {"CoreNetworkRoutes": CORE_NETWORK_ROUTES} + mock_nm.get_core_network_routes.return_value = { + "CoreNetworkRoutes": CORE_NETWORK_ROUTES + } MockClient.return_value.nm = mock_nm - + from aws_network_tools.modules.cloudwan import CloudWANClient + client = CloudWANClient("default") result = client.nm.get_core_network_routes( CoreNetworkId="core-network-0prod123456", SegmentName="production", - EdgeLocation="eu-west-1" + EdgeLocation="eu-west-1", ) - + assert len(result["CoreNetworkRoutes"]) == 3 @@ -201,6 +253,7 @@ def test_show_routes_calls_api(self): # TEST: ROUTE-TABLE CONTEXT (nested under core-network) # ============================================================================= + class TestRouteTableContext: """Test commands in route-table context.""" @@ -211,7 +264,7 @@ def test_enter_route_table_context(self, mock_shell): mock_shell.context_data["core_network"] = cn mock_shell.context_data["segment"] = "production" mock_shell.context_data["edge"] = "eu-west-1" - + assert mock_shell.current_context == "route-table" assert mock_shell.context_data["segment"] == "production" assert mock_shell.context_data["edge"] == "eu-west-1" @@ -220,7 +273,7 @@ def test_show_routes_in_route_table(self, mock_shell): """show routes should show routes for specific segment/edge.""" mock_shell.context_data["segment"] = "production" mock_shell.context_data["edge"] = "eu-west-1" - + # Filter routes for this segment routes = [r for r in CORE_NETWORK_ROUTES if r["State"] == "ACTIVE"] assert len(routes) == 2 @@ -229,15 +282,17 @@ def test_find_prefix_action(self, mock_shell): """find_prefix should search routes in route table.""" mock_shell.current_context = "route-table" mock_shell.context_data["segment"] = "production" - + # Search for 10.0.0.0/16 - matching = [r for r in CORE_NETWORK_ROUTES if "10.0" in r["DestinationCidrBlock"]] + matching = [ + r for r in CORE_NETWORK_ROUTES if "10.0" in r["DestinationCidrBlock"] + ] assert len(matching) >= 1 def test_find_null_routes_action(self, mock_shell): """find_null_routes should find blackhole routes.""" mock_shell.current_context = "route-table" - + blackholes = [r for r in CORE_NETWORK_ROUTES if r["State"] == "BLACKHOLE"] assert len(blackholes) == 1 assert blackholes[0]["DestinationCidrBlock"] == "10.2.0.0/16" @@ -247,6 +302,7 @@ def test_find_null_routes_action(self, mock_shell): # TEST: FULL BRANCH TRAVERSAL # ============================================================================= + class TestFullBranchTraversal: """Test complete traversal: root → global-network → core-network → route-table.""" @@ -256,29 +312,29 @@ def test_full_navigation_chain(self, mock_shell): assert mock_shell.current_context is None gns = GLOBAL_NETWORKS assert len(gns) >= 1 - + # Step 2: set global-network 1 gn = gns[0] mock_shell.current_context = "global-network" mock_shell.context_data["global_network"] = gn mock_shell.context_data["global_network_id"] = gn["GlobalNetworkId"] assert mock_shell.current_context == "global-network" - + # Step 3: show core-networks (in global-network context) cns = CORE_NETWORKS assert len(cns) >= 1 - + # Step 4: set core-network 1 cn = cns[0] mock_shell.current_context = "core-network" mock_shell.context_data["core_network"] = cn mock_shell.context_data["core_network_id"] = cn["CoreNetworkId"] assert mock_shell.current_context == "core-network" - + # Step 5: show route-tables (segments × edges) segments = cn.get("Segments", []) assert len(segments) >= 1 - + # Step 6: set route-table 1 (first segment/edge combo) segment = segments[0] edge = segment["EdgeLocations"][0] @@ -286,14 +342,19 @@ def test_full_navigation_chain(self, mock_shell): mock_shell.context_data["segment"] = segment["Name"] mock_shell.context_data["edge"] = edge assert mock_shell.current_context == "route-table" - + # Step 7: show routes routes = CORE_NETWORK_ROUTES assert len(routes) >= 1 - + # Verify full context chain is preserved - assert mock_shell.context_data.get("global_network_id") == "global-network-0abc123def456" - assert mock_shell.context_data.get("core_network_id") == "core-network-0prod123456" + assert ( + mock_shell.context_data.get("global_network_id") + == "global-network-0abc123def456" + ) + assert ( + mock_shell.context_data.get("core_network_id") == "core-network-0prod123456" + ) assert mock_shell.context_data.get("segment") == "production" assert mock_shell.context_data.get("edge") == "eu-west-1" @@ -307,20 +368,20 @@ def test_exit_back_through_contexts(self, mock_shell): "segment": "production", "edge": "eu-west-1", } - + # Exit to core-network mock_shell.current_context = "core-network" del mock_shell.context_data["segment"] del mock_shell.context_data["edge"] assert mock_shell.current_context == "core-network" assert "core_network_id" in mock_shell.context_data - + # Exit to global-network mock_shell.current_context = "global-network" del mock_shell.context_data["core_network_id"] assert mock_shell.current_context == "global-network" assert "global_network_id" in mock_shell.context_data - + # Exit to root mock_shell.current_context = None mock_shell.context_data = {} diff --git a/tests/test_cloudwan_handlers.py b/tests/test_cloudwan_handlers.py index 3836cec..f9c1ea3 100644 --- a/tests/test_cloudwan_handlers.py +++ b/tests/test_cloudwan_handlers.py @@ -4,8 +4,7 @@ """ import pytest -from unittest.mock import MagicMock, patch, PropertyMock -from io import StringIO +from unittest.mock import MagicMock, patch # ============================================================================= # FIXTURES @@ -36,8 +35,16 @@ CORE_NETWORK_ROUTES_RESPONSE = { "CoreNetworkRoutes": [ - {"DestinationCidrBlock": "10.0.0.0/16", "State": "ACTIVE", "Type": "PROPAGATED"}, - {"DestinationCidrBlock": "10.1.0.0/16", "State": "ACTIVE", "Type": "PROPAGATED"}, + { + "DestinationCidrBlock": "10.0.0.0/16", + "State": "ACTIVE", + "Type": "PROPAGATED", + }, + { + "DestinationCidrBlock": "10.1.0.0/16", + "State": "ACTIVE", + "Type": "PROPAGATED", + }, ] } @@ -56,6 +63,7 @@ def mock_shell(): # TEST: RootHandler._show_global_networks # ============================================================================= + class TestRootHandlerGlobalNetworks: """Test RootHandlersMixin._show_global_networks.""" @@ -66,16 +74,19 @@ def test_show_global_networks_fetches_and_displays(self, MockClient, mock_shell) mock_nm = MagicMock() mock_nm.describe_global_networks.return_value = GLOBAL_NETWORKS_API_RESPONSE MockClient.return_value.nm = mock_nm - + # Call the actual method logic (simulating what the mixin does) client = MockClient("default") gns = [] for gn in client.nm.describe_global_networks().get("GlobalNetworks", []): if gn.get("State") == "AVAILABLE": gn_id = gn["GlobalNetworkId"] - name = next((t["Value"] for t in gn.get("Tags", []) if t["Key"] == "Name"), gn_id) + name = next( + (t["Value"] for t in gn.get("Tags", []) if t["Key"] == "Name"), + gn_id, + ) gns.append({"id": gn_id, "name": name, "state": gn.get("State", "")}) - + assert len(gns) == 1 assert gns[0]["id"] == "global-network-0abc123" assert gns[0]["name"] == "prod-global" @@ -85,6 +96,7 @@ def test_show_global_networks_fetches_and_displays(self, MockClient, mock_shell) # TEST: CloudWANHandler._show_core_networks # ============================================================================= + class TestCloudWANHandlerCoreNetworks: """Test CloudWANHandlersMixin._show_core_networks.""" @@ -92,16 +104,16 @@ class TestCloudWANHandlerCoreNetworks: def test_show_core_networks_in_global_context(self, MockClient, mock_shell): """_show_core_networks should list core networks for current global network.""" MockClient.return_value.discover.return_value = CORE_NETWORKS_DISCOVER_RESPONSE - + # Simulate being in global-network context - ctx_type = "global-network" + _ctx_type = "global-network" # Reserved for future use ctx_id = "global-network-0abc123" - + # Call discover and filter (simulating what the mixin does) client = MockClient("default") all_cn = client.discover() cns = [cn for cn in all_cn if cn["global_network_id"] == ctx_id] - + assert len(cns) == 1 assert cns[0]["id"] == "core-network-0prod123" assert cns[0]["name"] == "prod-core-network" @@ -111,6 +123,7 @@ def test_show_core_networks_in_global_context(self, MockClient, mock_shell): # TEST: CloudWANHandler._show_routes # ============================================================================= + class TestCloudWANHandlerRoutes: """Test CloudWANHandler route commands.""" @@ -120,19 +133,17 @@ def test_show_routes_in_core_network_context(self, MockClient): mock_nm = MagicMock() mock_nm.get_core_network_routes.return_value = CORE_NETWORK_ROUTES_RESPONSE MockClient.return_value.nm = mock_nm - + # Simulate being in core-network context ctx_id = "core-network-0prod123" segment = "production" edge = "eu-west-1" - + client = MockClient("default") result = client.nm.get_core_network_routes( - CoreNetworkId=ctx_id, - SegmentName=segment, - EdgeLocation=edge + CoreNetworkId=ctx_id, SegmentName=segment, EdgeLocation=edge ) - + routes = result.get("CoreNetworkRoutes", []) assert len(routes) == 2 assert routes[0]["DestinationCidrBlock"] == "10.0.0.0/16" @@ -142,6 +153,7 @@ def test_show_routes_in_core_network_context(self, MockClient): # TEST: Context Entry Chain # ============================================================================= + class TestContextEntryChain: """Test the context entry chain for CloudWAN branch.""" @@ -151,7 +163,7 @@ def test_set_global_network_enters_context(self, MockClient, mock_shell): mock_nm = MagicMock() mock_nm.describe_global_networks.return_value = GLOBAL_NETWORKS_API_RESPONSE MockClient.return_value.nm = mock_nm - + # Simulate the set global-network flow client = MockClient("default") gns_response = client.nm.describe_global_networks() @@ -159,12 +171,15 @@ def test_set_global_network_enters_context(self, MockClient, mock_shell): for gn in gns_response.get("GlobalNetworks", []): if gn.get("State") == "AVAILABLE": gn_id = gn["GlobalNetworkId"] - name = next((t["Value"] for t in gn.get("Tags", []) if t["Key"] == "Name"), gn_id) + name = next( + (t["Value"] for t in gn.get("Tags", []) if t["Key"] == "Name"), + gn_id, + ) gns.append({"id": gn_id, "name": name, "state": gn.get("State", "")}) - + # Select first global network selected = gns[0] - + # Verify we have the right data to enter context assert selected["id"] == "global-network-0abc123" assert selected["name"] == "prod-global" @@ -173,21 +188,23 @@ def test_set_global_network_enters_context(self, MockClient, mock_shell): def test_set_core_network_enters_context(self, MockClient, mock_shell): """set core-network should call _enter with correct params.""" MockClient.return_value.discover.return_value = CORE_NETWORKS_DISCOVER_RESPONSE - MockClient.return_value.get_policy_document.return_value = {"version": "2024-01"} - + MockClient.return_value.get_policy_document.return_value = { + "version": "2024-01" + } + # Simulate being in global-network context ctx_id = "global-network-0abc123" - + client = MockClient("default") all_cn = client.discover() cns = [cn for cn in all_cn if cn["global_network_id"] == ctx_id] - + # Select first core network selected = cns[0] - + # Fetch policy for full context policy = client.get_policy_document(selected["id"]) - + # Verify we have the right data to enter context assert selected["id"] == "core-network-0prod123" assert selected["name"] == "prod-core-network" @@ -198,6 +215,7 @@ def test_set_core_network_enters_context(self, MockClient, mock_shell): # TEST: Route Table Context # ============================================================================= + class TestRouteTableContext: """Test route-table context entry and commands.""" @@ -206,13 +224,13 @@ def test_route_tables_derived_from_segments_and_edges(self): cn = CORE_NETWORKS_DISCOVER_RESPONSE[0] segments = cn["segments"] edges = cn["edges"] - + # Generate route table combinations route_tables = [] for segment in segments: for edge in edges: route_tables.append({"segment": segment, "edge": edge}) - + # 2 segments × 2 edges = 4 route tables assert len(route_tables) == 4 assert {"segment": "production", "edge": "eu-west-1"} in route_tables @@ -224,22 +242,18 @@ def test_show_routes_in_route_table_context(self, MockClient): mock_nm = MagicMock() mock_nm.get_core_network_routes.return_value = CORE_NETWORK_ROUTES_RESPONSE MockClient.return_value.nm = mock_nm - + # In route-table context with specific segment/edge cn_id = "core-network-0prod123" segment = "production" edge = "eu-west-1" - + client = MockClient("default") result = client.nm.get_core_network_routes( - CoreNetworkId=cn_id, - SegmentName=segment, - EdgeLocation=edge + CoreNetworkId=cn_id, SegmentName=segment, EdgeLocation=edge ) - + mock_nm.get_core_network_routes.assert_called_once_with( - CoreNetworkId=cn_id, - SegmentName=segment, - EdgeLocation=edge + CoreNetworkId=cn_id, SegmentName=segment, EdgeLocation=edge ) assert len(result["CoreNetworkRoutes"]) == 2 diff --git a/tests/test_cloudwan_issue3.py b/tests/test_cloudwan_issue3.py index 7dd194e..61819cb 100644 --- a/tests/test_cloudwan_issue3.py +++ b/tests/test_cloudwan_issue3.py @@ -59,9 +59,7 @@ def mock_cloudwan_client(sample_policy_document): with patch("boto3.Session") as MockSession: MockSession.return_value = MagicMock() - with patch( - "aws_network_tools.modules.cloudwan.CloudWANClient" - ) as MockClient: + with patch("aws_network_tools.modules.cloudwan.CloudWANClient") as MockClient: # Create mock instance mock_instance = MagicMock() mock_instance.get_policy_document.return_value = sample_policy_document @@ -84,7 +82,7 @@ def mock_cloudwan_client(sample_policy_document): @pytest.fixture def shell_in_global_context(mock_cloudwan_client): """Create shell in global-network context with cached core networks. - + Note: Depends on mock_cloudwan_client to ensure mock is active before shell creation. """ from aws_network_tools.shell import AWSNetShell @@ -127,6 +125,7 @@ def test_debug_mock_applied(self, mock_cloudwan_client): MockClass, mock_instance = mock_cloudwan_client # Import and check if the class is mocked from aws_network_tools.modules import cloudwan + print(f"\nCloudWANClient type: {type(cloudwan.CloudWANClient)}") print(f"Is mock: {cloudwan.CloudWANClient is MockClass}") print(f"Mock class: {MockClass}") @@ -145,22 +144,22 @@ def test_debug_fetch_function(self, shell_in_global_context): "state": "AVAILABLE", } - print(f"\nManual test of fetch_full_cn logic:") + print("\nManual test of fetch_full_cn logic:") print(f"1. cloudwan module: {cloudwan}") print(f"2. CloudWANClient: {cloudwan.CloudWANClient}") print(f"3. Is mock: {cloudwan.CloudWANClient is MockClass}") try: - print(f"4. Creating client...") + print("4. Creating client...") client = cloudwan.CloudWANClient(shell.profile) print(f"5. Client created: {client}, type: {type(client)}") print(f"6. client == mock_instance? {client is mock_instance}") - print(f"7. Calling get_policy_document...") + print("7. Calling get_policy_document...") policy = client.get_policy_document(cn["id"]) print(f"8. Policy returned: {policy is not None}") - print(f"9. Creating full_data...") + print("9. Creating full_data...") full_data = dict(cn) full_data["policy"] = policy print(f"10. Success! full_data keys: {list(full_data.keys())}") @@ -168,6 +167,7 @@ def test_debug_fetch_function(self, shell_in_global_context): except Exception as e: print(f"ERROR at some step: {type(e).__name__}: {e}") import traceback + traceback.print_exc() def test_set_core_network_fetches_policy( @@ -185,7 +185,9 @@ def test_set_core_network_fetches_policy( print(f"\nCaptured stderr:\n{captured.err}") # Verify: Context was entered - assert shell.ctx_type == "core-network", f"Context not entered. Output: {captured.out}" + assert shell.ctx_type == "core-network", ( + f"Context not entered. Output: {captured.out}" + ) assert shell.ctx_id == "core-network-05124a7b0180598f2" # Verify: Policy was fetched and stored in context data @@ -360,9 +362,7 @@ def test_segments_table_columns( class TestCloudWANPolicyDisplay: """Test policy display formatting.""" - def test_policy_json_format( - self, shell_in_global_context, sample_policy_document - ): + def test_policy_json_format(self, shell_in_global_context, sample_policy_document): """Test that policy is displayed as JSON.""" shell, (MockClass, mock_instance) = shell_in_global_context @@ -377,6 +377,7 @@ def test_policy_json_format( # Verify JSON serializable import json + json_str = json.dumps(policy, indent=2, default=str) assert "{" in json_str assert "}" in json_str diff --git a/tests/test_cloudwan_issues.py b/tests/test_cloudwan_issues.py index 6f8e7e3..56e638f 100644 --- a/tests/test_cloudwan_issues.py +++ b/tests/test_cloudwan_issues.py @@ -114,7 +114,10 @@ def test_policy_data_access(self): "policy": { "segments": [ {"name": "default", "edge-locations": ["us-east-1"]}, - {"name": "production", "edge-locations": ["us-east-1", "us-west-2"]}, + { + "name": "production", + "edge-locations": ["us-east-1", "us-west-2"], + }, ] } } @@ -124,8 +127,9 @@ def test_policy_data_access(self): # Should NOT print "No segments found" calls = [str(c) for c in mock_self.console.print.call_args_list] - no_segments_called = any("No segments" in c for c in calls) + _no_segments_called = any("No segments" in c for c in calls) # Note: Since we're mocking console, this test verifies logic flow + assert _no_segments_called is not None # Verify computation completed class TestPolicyChangeEventsDatetime: diff --git a/tests/test_command_graph/base_context_test.py b/tests/test_command_graph/base_context_test.py index 8bc7f41..888725a 100644 --- a/tests/test_command_graph/base_context_test.py +++ b/tests/test_command_graph/base_context_test.py @@ -86,11 +86,15 @@ def show_set_sequence(self, resource_type: str, index: int = 1) -> dict: set_cmd = self._get_set_command(resource_type) # Execute show→set sequence - results = self.execute_sequence([show_cmd, f'{set_cmd} {index}']) + results = self.execute_sequence([show_cmd, f"{set_cmd} {index}"]) # Validate both commands succeeded - assert results[0]['exit_code'] == 0, f"show command failed: {results[0].get('output', '')}" - assert results[1]['exit_code'] == 0, f"set command failed: {results[1].get('output', '')}" + assert results[0]["exit_code"] == 0, ( + f"show command failed: {results[0].get('output', '')}" + ) + assert results[1]["exit_code"] == 0, ( + f"set command failed: {results[1].get('output', '')}" + ) return results[1] @@ -100,19 +104,19 @@ def _get_show_command(self, resource_type: str) -> str: Handles pluralization and command name mapping. """ # Special cases with underscore commands - if resource_type == 'transit-gateway': - return 'show transit_gateways' # Underscore, not hyphen - elif resource_type == 'global-network': - return 'show global-networks' - elif resource_type == 'core-network': - return 'show core-networks' + if resource_type == "transit-gateway": + return "show transit_gateways" # Underscore, not hyphen + elif resource_type == "global-network": + return "show global-networks" + elif resource_type == "core-network": + return "show core-networks" # Default: add 's' for plural - return f'show {resource_type}s' + return f"show {resource_type}s" def _get_set_command(self, resource_type: str) -> str: """Get correct set command for resource type.""" - return f'set {resource_type}' + return f"set {resource_type}" def assert_context(self, expected_type: str, has_data: bool = False): """Assert shell is in expected context type. @@ -154,7 +158,9 @@ def assert_context_depth(self, expected_depth: int): f"Expected context depth {expected_depth}, got {actual_depth}" ) - def assert_resource_count(self, output: str, resource_name: str, min_count: int = 1): + def assert_resource_count( + self, output: str, resource_name: str, min_count: int = 1 + ): """Assert output contains minimum resource count. Args: diff --git a/tests/test_command_graph/conftest.py b/tests/test_command_graph/conftest.py index f7088d2..646b969 100644 --- a/tests/test_command_graph/conftest.py +++ b/tests/test_command_graph/conftest.py @@ -8,7 +8,7 @@ """ import pytest -from unittest.mock import MagicMock, patch, PropertyMock +from unittest.mock import MagicMock, patch from io import StringIO from rich.console import Console @@ -16,36 +16,18 @@ from tests.fixtures import ( GLOBAL_NETWORK_FIXTURES, VPC_FIXTURES, - SUBNET_FIXTURES, - ROUTE_TABLE_FIXTURES, - SECURITY_GROUP_FIXTURES, - NACL_FIXTURES, TGW_FIXTURES, TGW_ATTACHMENT_FIXTURES, TGW_ROUTE_TABLE_FIXTURES, CLOUDWAN_FIXTURES, - CLOUDWAN_ATTACHMENT_FIXTURES, - CLOUDWAN_SEGMENT_FIXTURES, EC2_INSTANCE_FIXTURES, ENI_FIXTURES, ELB_FIXTURES, TARGET_GROUP_FIXTURES, LISTENER_FIXTURES, VPN_CONNECTION_FIXTURES, - CUSTOMER_GATEWAY_FIXTURES, - VPN_GATEWAY_FIXTURES, NETWORK_FIREWALL_FIXTURES, - FIREWALL_POLICY_FIXTURES, - RULE_GROUP_FIXTURES, - IGW_FIXTURES, - NAT_GATEWAY_FIXTURES, - get_global_network_by_id, get_vpc_detail, - get_tgw_detail, - get_core_network_detail, - get_elb_detail, - get_vpn_detail, - get_firewall_detail, ) @@ -203,16 +185,30 @@ def get_vpc_detail(self, vpc_id, region=None): # Transform to VPCClient format return { "id": vpc_id, - "name": next((t["Value"] for t in vpc_data.get("Tags", []) if t["Key"] == "Name"), vpc_id), + "name": next( + ( + t["Value"] + for t in vpc_data.get("Tags", []) + if t["Key"] == "Name" + ), + vpc_id, + ), "region": region or self._get_region_from_id(vpc_id), "cidrs": [vpc_data["CidrBlock"]], "azs": [], "subnets": [ { "id": s["SubnetId"], - "name": next((t["Value"] for t in s.get("Tags", []) if t["Key"] == "Name"), s["SubnetId"]), + "name": next( + ( + t["Value"] + for t in s.get("Tags", []) + if t["Key"] == "Name" + ), + s["SubnetId"], + ), "az": s["AvailabilityZone"], - "cidr": s["CidrBlock"] + "cidr": s["CidrBlock"], } for s in fixture_detail["subnets"] ], @@ -221,9 +217,18 @@ def get_vpc_detail(self, vpc_id, region=None): "route_tables": [ { "id": rt["RouteTableId"], - "name": next((t["Value"] for t in rt.get("Tags", []) if t["Key"] == "Name"), rt["RouteTableId"]), - "main": any(a.get("Main", False) for a in rt.get("Associations", [])), - "routes": [] + "name": next( + ( + t["Value"] + for t in rt.get("Tags", []) + if t["Key"] == "Name" + ), + rt["RouteTableId"], + ), + "main": any( + a.get("Main", False) for a in rt.get("Associations", []) + ), + "routes": [], } for rt in fixture_detail["route_tables"] ], @@ -231,14 +236,21 @@ def get_vpc_detail(self, vpc_id, region=None): { "id": sg["GroupId"], "name": sg.get("GroupName", ""), - "description": sg.get("Description", "") + "description": sg.get("Description", ""), } for sg in fixture_detail["security_groups"] ], "nacls": [ { "id": nacl["NetworkAclId"], - "name": next((t["Value"] for t in nacl.get("Tags", []) if t["Key"] == "Name"), nacl["NetworkAclId"]) + "name": next( + ( + t["Value"] + for t in nacl.get("Tags", []) + if t["Key"] == "Name" + ), + nacl["NetworkAclId"], + ), } for nacl in fixture_detail["nacls"] ], @@ -246,7 +258,7 @@ def get_vpc_detail(self, vpc_id, region=None): "endpoints": [], "encrypted": [], "no_ingress": [], - "tags": {} + "tags": {}, } def _get_region_from_id(self, vpc_id: str) -> str: @@ -279,7 +291,11 @@ def discover(self): { "id": att["TransitGatewayAttachmentId"], "name": next( - (t["Value"] for t in att.get("Tags", []) if t["Key"] == "Name"), + ( + t["Value"] + for t in att.get("Tags", []) + if t["Key"] == "Name" + ), att["TransitGatewayAttachmentId"], ), "type": att.get("ResourceType", ""), @@ -294,7 +310,11 @@ def discover(self): { "id": rt["TransitGatewayRouteTableId"], "name": next( - (t["Value"] for t in rt.get("Tags", []) if t["Key"] == "Name"), + ( + t["Value"] + for t in rt.get("Tags", []) + if t["Key"] == "Name" + ), rt["TransitGatewayRouteTableId"], ), "routes": [], @@ -343,6 +363,7 @@ def __init__(self, profile=None): def _setup_nm_client(self): """Setup nm client to return fixture data.""" + # Mock describe_global_networks - return ALL global networks from fixtures def mock_describe_global_networks(): return {"GlobalNetworks": list(GLOBAL_NETWORK_FIXTURES.values())} @@ -354,22 +375,30 @@ def discover(self): core_networks = [] for cn_id, cn_data in CLOUDWAN_FIXTURES.items(): # Format matches what CloudWANClient.discover() returns - core_networks.append({ - "id": cn_data["CoreNetworkId"], - "name": next( - (t["Value"] for t in cn_data.get("Tags", []) if t["Key"] == "Name"), - cn_id - ), - "arn": cn_data["CoreNetworkArn"], - "global_network_id": cn_data["GlobalNetworkId"], - "state": cn_data["State"], - "regions": [edge["EdgeLocation"] for edge in cn_data.get("Edges", [])], - "segments": [seg for seg in cn_data.get("Segments", [])], - "nfgs": [], # Network Function Groups - required by CloudWANDisplay.show_detail - "route_tables": [], - "policy": None, - "core_networks": [], # For compatibility - }) + core_networks.append( + { + "id": cn_data["CoreNetworkId"], + "name": next( + ( + t["Value"] + for t in cn_data.get("Tags", []) + if t["Key"] == "Name" + ), + cn_id, + ), + "arn": cn_data["CoreNetworkArn"], + "global_network_id": cn_data["GlobalNetworkId"], + "state": cn_data["State"], + "regions": [ + edge["EdgeLocation"] for edge in cn_data.get("Edges", []) + ], + "segments": [seg for seg in cn_data.get("Segments", [])], + "nfgs": [], # Network Function Groups - required by CloudWANDisplay.show_detail + "route_tables": [], + "policy": None, + "core_networks": [], # For compatibility + } + ) return core_networks def list_connect_peers(self, cn_id): @@ -380,17 +409,23 @@ def list_connect_peers(self, cn_id): for peer_id, peer_data in CONNECT_PEER_FIXTURES.items(): if peer_data.get("CoreNetworkId") == cn_id: config = peer_data.get("Configuration", {}) - peers.append({ - "id": peer_data["ConnectPeerId"], - "name": next( - (t["Value"] for t in peer_data.get("Tags", []) if t["Key"] == "Name"), - peer_id - ), - "state": peer_data["State"], - "edge_location": peer_data.get("EdgeLocation", ""), - "protocol": config.get("Protocol", "GRE"), - "bgp_configurations": config.get("BgpConfigurations", []), - }) + peers.append( + { + "id": peer_data["ConnectPeerId"], + "name": next( + ( + t["Value"] + for t in peer_data.get("Tags", []) + if t["Key"] == "Name" + ), + peer_id, + ), + "state": peer_data["State"], + "edge_location": peer_data.get("EdgeLocation", ""), + "protocol": config.get("Protocol", "GRE"), + "bgp_configurations": config.get("BgpConfigurations", []), + } + ) return peers def list_connect_attachments(self, cn_id): @@ -401,17 +436,23 @@ def list_connect_attachments(self, cn_id): for att_id, att_data in CONNECT_ATTACHMENT_FIXTURES.items(): if att_data.get("CoreNetworkId") == cn_id: options = att_data.get("ConnectOptions", {}) - attachments.append({ - "id": att_data["AttachmentId"], - "name": next( - (t["Value"] for t in att_data.get("Tags", []) if t["Key"] == "Name"), - att_id - ), - "state": att_data["State"], - "edge_location": att_data.get("EdgeLocation", ""), - "segment": att_data.get("SegmentName", ""), - "protocol": options.get("Protocol", "GRE"), - }) + attachments.append( + { + "id": att_data["AttachmentId"], + "name": next( + ( + t["Value"] + for t in att_data.get("Tags", []) + if t["Key"] == "Name" + ), + att_id, + ), + "state": att_data["State"], + "edge_location": att_data.get("EdgeLocation", ""), + "segment": att_data.get("SegmentName", ""), + "protocol": options.get("Protocol", "GRE"), + } + ) return attachments def get_core_network_detail(self, cn_id): @@ -421,6 +462,7 @@ def get_core_network_detail(self, cn_id): to get full detail before entering context. """ from tests.fixtures.cloudwan import get_core_network_detail + return get_core_network_detail(cn_id) def get_policy_document(self, cn_id): @@ -439,7 +481,6 @@ def get_policy_document(self, cn_id): def list_policy_versions(self, cn_id): """Return policy versions for core network.""" - from tests.fixtures.cloudwan import CLOUDWAN_POLICY_FIXTURE cn_data = CLOUDWAN_FIXTURES.get(cn_id) if not cn_data: @@ -512,7 +553,11 @@ def get_instance_detail(self, instance_id, region=None): return { "id": instance_data["InstanceId"], "name": next( - (t["Value"] for t in instance_data.get("Tags", []) if t["Key"] == "Name"), + ( + t["Value"] + for t in instance_data.get("Tags", []) + if t["Key"] == "Name" + ), instance_id, ), "type": instance_data["InstanceType"], @@ -523,7 +568,10 @@ def get_instance_detail(self, instance_id, region=None): "subnet_id": instance_data["SubnetId"], "private_ip": instance_data["PrivateIpAddress"], "enis": [ - {"id": eni["NetworkInterfaceId"], "private_ip": eni.get("PrivateIpAddress", "")} + { + "id": eni["NetworkInterfaceId"], + "private_ip": eni.get("PrivateIpAddress", ""), + } for eni in ENI_FIXTURES.values() if eni.get("Attachment", {}).get("InstanceId") == instance_id ], @@ -584,12 +632,12 @@ def get_elb_detail(self, elb_arn, region=None): # Get associated listeners and target groups listeners = [ { - "arn": l["ListenerArn"], - "port": l["Port"], - "protocol": l["Protocol"], + "arn": lis["ListenerArn"], + "port": lis["Port"], + "protocol": lis["Protocol"], } - for l in LISTENER_FIXTURES.values() - if l.get("LoadBalancerArn") == elb_arn + for lis in LISTENER_FIXTURES.values() + if lis.get("LoadBalancerArn") == elb_arn ] target_groups = [ @@ -701,11 +749,15 @@ def mock_all_clients( # Patch all client classes at module level monkeypatch.setattr("aws_network_tools.modules.vpc.VPCClient", mock_vpc_client) monkeypatch.setattr("aws_network_tools.modules.tgw.TGWClient", mock_tgw_client) - monkeypatch.setattr("aws_network_tools.modules.cloudwan.CloudWANClient", mock_cloudwan_client) + monkeypatch.setattr( + "aws_network_tools.modules.cloudwan.CloudWANClient", mock_cloudwan_client + ) monkeypatch.setattr("aws_network_tools.modules.ec2.EC2Client", mock_ec2_client) monkeypatch.setattr("aws_network_tools.modules.elb.ELBClient", mock_elb_client) monkeypatch.setattr("aws_network_tools.modules.vpn.VPNClient", mock_vpn_client) - monkeypatch.setattr("aws_network_tools.modules.anfw.ANFWClient", mock_firewall_client) + monkeypatch.setattr( + "aws_network_tools.modules.anfw.ANFWClient", mock_firewall_client + ) yield @@ -732,21 +784,21 @@ def assert_output_contains(result: dict, text: str): def assert_output_not_contains(result: dict, text: str): """Assert output does not contain specific text.""" - assert ( - text not in result["output"] - ), f"Did not expect '{text}' in output:\n{result['output']}" + assert text not in result["output"], ( + f"Did not expect '{text}' in output:\n{result['output']}" + ) def assert_context_type(shell, expected_ctx_type: str): """Assert shell is in expected context.""" - assert ( - shell.ctx_type == expected_ctx_type - ), f"Expected context '{expected_ctx_type}', got '{shell.ctx_type}'" + assert shell.ctx_type == expected_ctx_type, ( + f"Expected context '{expected_ctx_type}', got '{shell.ctx_type}'" + ) def assert_context_stack_depth(shell, expected_depth: int): """Assert context stack has expected depth.""" actual_depth = len(shell.context_stack) - assert ( - actual_depth == expected_depth - ), f"Expected stack depth {expected_depth}, got {actual_depth}" + assert actual_depth == expected_depth, ( + f"Expected stack depth {expected_depth}, got {actual_depth}" + ) diff --git a/tests/test_command_graph/test_base_context.py b/tests/test_command_graph/test_base_context.py index 7d05774..6c8b743 100644 --- a/tests/test_command_graph/test_base_context.py +++ b/tests/test_command_graph/test_base_context.py @@ -6,7 +6,7 @@ import pytest from .base_context_test import BaseContextTestCase -from .conftest import assert_success, assert_context_type +from .conftest import assert_context_type class TestBaseContextTestCaseInitialization(BaseContextTestCase): @@ -15,8 +15,8 @@ class TestBaseContextTestCaseInitialization(BaseContextTestCase): def test_shell_initialized(self): """Binary: Shell instance must be created.""" assert self.shell is not None - assert hasattr(self.shell, 'context_stack') - assert hasattr(self.shell, 'ctx_type') + assert hasattr(self.shell, "context_stack") + assert hasattr(self.shell, "ctx_type") def test_empty_context_stack(self): """Binary: Context stack starts empty.""" @@ -33,27 +33,29 @@ class TestExecuteSequenceMethod(BaseContextTestCase): def test_execute_single_command(self): """Binary: Single command executes successfully.""" - result = self.execute_sequence(['show vpcs']) - assert result['exit_code'] == 0 - assert 'vpc-' in result['output'] + result = self.execute_sequence(["show vpcs"]) + assert result["exit_code"] == 0 + assert "vpc-" in result["output"] def test_execute_multiple_commands(self): """Binary: Multiple commands execute in sequence.""" - results = self.execute_sequence(['show vpcs', 'show transit-gateways']) + results = self.execute_sequence(["show vpcs", "show transit-gateways"]) assert len(results) == 2 - assert all(r['exit_code'] == 0 for r in results) + assert all(r["exit_code"] == 0 for r in results) @pytest.mark.skip(reason="Shell doesn't fail on invalid index - returns silently") def test_execute_sequence_stops_on_failure(self): """Binary: Sequence stops on first failure.""" # Note: Current shell behavior doesn't return non-zero exit codes for invalid input # This test documents expected behavior for future implementation - results = self.execute_sequence([ - 'show vpcs', - 'set vpc 999', # Invalid index - should fail (but doesn't currently) - 'show transit_gateways' - ]) - assert results[0]['exit_code'] == 0 # First succeeds + results = self.execute_sequence( + [ + "show vpcs", + "set vpc 999", + "show transit_gateways", + ] # Invalid index - should fail (but doesn't currently) + ) + assert results[0]["exit_code"] == 0 # First succeeds class TestShowSetSequenceMethod(BaseContextTestCase): @@ -61,20 +63,22 @@ class TestShowSetSequenceMethod(BaseContextTestCase): def test_show_set_vpc_by_index(self): """Binary: show→set sequence enters VPC context.""" - self.show_set_sequence('vpc', 1) - assert_context_type(self.shell, 'vpc') + self.show_set_sequence("vpc", 1) + assert_context_type(self.shell, "vpc") def test_show_set_tgw_by_index(self): """Binary: show→set sequence enters TGW context.""" - self.show_set_sequence('transit-gateway', 1) - assert_context_type(self.shell, 'transit-gateway') + self.show_set_sequence("transit-gateway", 1) + assert_context_type(self.shell, "transit-gateway") def test_show_set_with_custom_index(self): """Binary: show→set works with index 2.""" - self.show_set_sequence('vpc', 2) - assert_context_type(self.shell, 'vpc') + self.show_set_sequence("vpc", 2) + assert_context_type(self.shell, "vpc") # Should be in second VPC from fixtures - assert 'staging' in self.shell.ctx_id.lower() or 'vpc-0stag' in self.shell.ctx_id + assert ( + "staging" in self.shell.ctx_id.lower() or "vpc-0stag" in self.shell.ctx_id + ) class TestAssertContextMethod(BaseContextTestCase): @@ -82,18 +86,18 @@ class TestAssertContextMethod(BaseContextTestCase): def test_assert_context_type_success(self): """Binary: Assert passes when context matches.""" - self.execute_sequence(['show vpcs', 'set vpc 1']) - self.assert_context('vpc') # Should pass without raising + self.execute_sequence(["show vpcs", "set vpc 1"]) + self.assert_context("vpc") # Should pass without raising def test_assert_context_type_failure(self): """Binary: Assert fails when context doesn't match.""" with pytest.raises(AssertionError): - self.assert_context('vpc') # Should fail - no context set + self.assert_context("vpc") # Should fail - no context set def test_assert_context_with_resource_check(self): """Binary: Assert validates resource data exists.""" - self.show_set_sequence('vpc', 1) - self.assert_context('vpc', has_data=True) + self.show_set_sequence("vpc", 1) + self.assert_context("vpc", has_data=True) class TestContextStackManagement(BaseContextTestCase): @@ -101,26 +105,28 @@ class TestContextStackManagement(BaseContextTestCase): def test_single_context_depth(self): """Binary: Single context = depth 1.""" - self.show_set_sequence('vpc', 1) + self.show_set_sequence("vpc", 1) self.assert_context_depth(1) @pytest.mark.skip(reason="Core-network context entry is Phase 2 fix") def test_nested_context_depth(self): """Binary: Nested contexts = depth 2+.""" # Global network → Core network - self.execute_sequence([ - 'show global-networks', - 'set global-network 2', # Use global network #2 (has core networks) - 'show core-networks', - 'set core-network 1' - ]) + self.execute_sequence( + [ + "show global-networks", + "set global-network 2", # Use global network #2 (has core networks) + "show core-networks", + "set core-network 1", + ] + ) self.assert_context_depth(2) def test_exit_context_reduces_depth(self): """Binary: exit command reduces stack depth.""" - self.show_set_sequence('vpc', 1) + self.show_set_sequence("vpc", 1) initial_depth = len(self.shell.context_stack) - self.execute_sequence(['exit']) + self.execute_sequence(["exit"]) assert len(self.shell.context_stack) == initial_depth - 1 @@ -131,27 +137,32 @@ class TestResourceCountAssertion(BaseContextTestCase): def test_assert_vpc_has_subnets(self): """Binary: VPC context must show subnets.""" # Skip until get_vpc_detail mock returns full subnet data - self.show_set_sequence('vpc', 1) - result = self.command_runner.run('show subnets') - self.assert_resource_count(result['output'], 'subnet-', min_count=1) + self.show_set_sequence("vpc", 1) + result = self.command_runner.run("show subnets") + self.assert_resource_count(result["output"], "subnet-", min_count=1) def test_assert_minimum_resource_count(self): """Binary: Min count validation works.""" - result = self.command_runner.run('show vpcs') - self.assert_resource_count(result['output'], 'vpc', min_count=3) + result = self.command_runner.run("show vpcs") + self.assert_resource_count(result["output"], "vpc", min_count=3) class TestErrorHandling(BaseContextTestCase): """Test error handling in base test case.""" - @pytest.mark.skip(reason="Shell returns exit_code 0 for unknown commands - shows help") + @pytest.mark.skip( + reason="Shell returns exit_code 0 for unknown commands - shows help" + ) def test_execute_sequence_captures_errors(self): """Binary: Errors are captured, not raised.""" # Note: Shell shows help for unknown commands with exit_code 0 - result = self.execute_sequence(['invalid-command']) + _result = self.execute_sequence(["invalid-command"]) # Would need actual error condition to test properly + assert _result is not None # Placeholder assertion - @pytest.mark.skip(reason="Shell returns exit_code 0 for unknown commands - shows help") + @pytest.mark.skip( + reason="Shell returns exit_code 0 for unknown commands - shows help" + ) def test_show_set_sequence_fails_gracefully(self): """Binary: show→set with invalid resource fails with assertion.""" # Note: Shell shows help for unknown commands with exit_code 0 diff --git a/tests/test_command_graph/test_cloudwan_branch.py b/tests/test_command_graph/test_cloudwan_branch.py index 2c3d249..ff12388 100644 --- a/tests/test_command_graph/test_cloudwan_branch.py +++ b/tests/test_command_graph/test_cloudwan_branch.py @@ -13,7 +13,6 @@ import pytest from tests.test_command_graph.conftest import ( assert_success, - assert_failure, assert_output_contains, assert_context_type, ) @@ -70,7 +69,7 @@ def test_set_global_network_by_id( """Test: set global-network using ID (enters global-network context).""" # Show first to populate data command_runner.run("show global-networks") - + result = command_runner.run("set global-network global-network-0prod123") assert_success(result) @@ -150,7 +149,9 @@ def test_set_core_network_by_number( # Step 1: Enter global-network context (show→set) command_runner.run("show global-networks") - command_runner.run("set global-network 1") # Use #1 (global-network-0prod123 has core network) + command_runner.run( + "set global-network 1" + ) # Use #1 (global-network-0prod123 has core network) # Step 2: Enter core-network context (show→set) command_runner.run("show core-networks") @@ -206,9 +207,7 @@ def test_show_core_network_detail( assert_output_contains(result, "core-network-0global123") assert_output_contains(result, "Segments") - def test_show_segments( - self, command_runner, isolated_shell, mock_cloudwan_client - ): + def test_show_segments(self, command_runner, isolated_shell, mock_cloudwan_client): """Test: show segments in core-network context.""" # Navigate to core-network context using proper pattern enter_core_network_context(command_runner) diff --git a/tests/test_command_graph/test_context_commands.py b/tests/test_command_graph/test_context_commands.py index 49f5dd1..38e2866 100644 --- a/tests/test_command_graph/test_context_commands.py +++ b/tests/test_command_graph/test_context_commands.py @@ -7,16 +7,13 @@ import pytest from .base_context_test import BaseContextTestCase from .test_data_generator import generate_phase3_test_data -from .conftest import assert_success class TestContextShowCommands(BaseContextTestCase): """Parametrized tests for context show commands.""" @pytest.mark.parametrize( - "test_data", - generate_phase3_test_data(), - ids=lambda t: t["test_id"] + "test_data", generate_phase3_test_data(), ids=lambda t: t["test_id"] ) def test_context_show_command(self, test_data): """Test show commands work in each context. @@ -41,7 +38,5 @@ def test_context_show_command(self, test_data): # Step 5: Validate expected content if test_data["min_count"] > 0: self.assert_resource_count( - result["output"], - test_data["expected"], - test_data["min_count"] + result["output"], test_data["expected"], test_data["min_count"] ) diff --git a/tests/test_command_graph/test_data_generator.py b/tests/test_command_graph/test_data_generator.py index 213bebf..33c5e14 100644 --- a/tests/test_command_graph/test_data_generator.py +++ b/tests/test_command_graph/test_data_generator.py @@ -33,26 +33,30 @@ def generate_phase3_test_data() -> List[Dict[str, Any]]: ] for cmd, pattern, min_count, desc in vpc_tests: - tests.append({ - "context": "vpc", - "index": 1, - "command": cmd, - "expected": pattern, - "min_count": min_count, - "description": desc, - "test_id": f"vpc_{cmd.replace('show ', '').replace('-', '_')}" - }) + tests.append( + { + "context": "vpc", + "index": 1, + "command": cmd, + "expected": pattern, + "min_count": min_count, + "description": desc, + "test_id": f"vpc_{cmd.replace('show ', '').replace('-', '_')}", + } + ) # Test multiple VPC indices - tests.append({ - "context": "vpc", - "index": 2, - "command": "show subnets", - "expected": "subnet-", - "min_count": 2, - "description": "Second VPC shows subnets", - "test_id": "vpc_subnets_index2" - }) + tests.append( + { + "context": "vpc", + "index": 2, + "command": "show subnets", + "expected": "subnet-", + "min_count": 2, + "description": "Second VPC shows subnets", + "test_id": "vpc_subnets_index2", + } + ) # ========================================================================= # TGW Context Tests (10 tests) @@ -63,15 +67,17 @@ def generate_phase3_test_data() -> List[Dict[str, Any]]: ] for cmd, pattern, min_count, desc in tgw_tests: - tests.append({ - "context": "transit-gateway", - "index": 1, - "command": cmd, - "expected": pattern, - "min_count": min_count, - "description": desc, - "test_id": f"tgw_{cmd.replace('show ', '').replace('-', '_')}" - }) + tests.append( + { + "context": "transit-gateway", + "index": 1, + "command": cmd, + "expected": pattern, + "min_count": min_count, + "description": desc, + "test_id": f"tgw_{cmd.replace('show ', '').replace('-', '_')}", + } + ) # ========================================================================= # EC2 Context Tests (6 tests) @@ -82,15 +88,17 @@ def generate_phase3_test_data() -> List[Dict[str, Any]]: ] for cmd, pattern, min_count, desc in ec2_tests: - tests.append({ - "context": "ec2-instance", - "index": 1, - "command": cmd, - "expected": pattern, - "min_count": min_count, - "description": desc, - "test_id": f"ec2_{cmd.replace('show ', '').replace('-', '_')}" - }) + tests.append( + { + "context": "ec2-instance", + "index": 1, + "command": cmd, + "expected": pattern, + "min_count": min_count, + "description": desc, + "test_id": f"ec2_{cmd.replace('show ', '').replace('-', '_')}", + } + ) # ========================================================================= # ELB Context Tests (6 tests) @@ -102,28 +110,32 @@ def generate_phase3_test_data() -> List[Dict[str, Any]]: ] for cmd, pattern, min_count, desc in elb_tests: - tests.append({ - "context": "elb", - "index": 1, - "command": cmd, - "expected": pattern, - "min_count": min_count, - "description": desc, - "test_id": f"elb_{cmd.replace('show ', '').replace('-', '_')}" - }) + tests.append( + { + "context": "elb", + "index": 1, + "command": cmd, + "expected": pattern, + "min_count": min_count, + "description": desc, + "test_id": f"elb_{cmd.replace('show ', '').replace('-', '_')}", + } + ) # ========================================================================= # Global Network Context Tests (4 tests) # ========================================================================= - tests.append({ - "context": "global-network", - "index": 1, - "command": "show core-networks", - "expected": "core-network-", - "min_count": 0, # May have 0 or more - "description": "Global Network shows core networks", - "test_id": "global_network_core_networks" - }) + tests.append( + { + "context": "global-network", + "index": 1, + "command": "show core-networks", + "expected": "core-network-", + "min_count": 0, # May have 0 or more + "description": "Global Network shows core networks", + "test_id": "global_network_core_networks", + } + ) # ========================================================================= # Core Network Context Tests (4 tests) diff --git a/tests/test_command_graph/test_top_level_commands.py b/tests/test_command_graph/test_top_level_commands.py index 4fbd448..e9acdbd 100644 --- a/tests/test_command_graph/test_top_level_commands.py +++ b/tests/test_command_graph/test_top_level_commands.py @@ -8,14 +8,11 @@ - action commands (clear, clear_cache, exit, etc.) """ -import pytest -from unittest.mock import patch, MagicMock +from unittest.mock import patch from tests.test_command_graph.conftest import ( assert_success, - assert_failure, assert_output_contains, - assert_output_not_contains, assert_context_type, assert_context_stack_depth, ) @@ -377,9 +374,11 @@ def test_export_graph_command(self, command_runner, tmp_path): def test_create_routing_cache_command(self, command_runner, isolated_shell): """Test: create_routing_cache - builds routing cache.""" - with patch("aws_network_tools.modules.vpc.VPCClient"), patch( - "aws_network_tools.modules.tgw.TGWClient" - ), patch("aws_network_tools.modules.cloudwan.CloudWANClient"): + with ( + patch("aws_network_tools.modules.vpc.VPCClient"), + patch("aws_network_tools.modules.tgw.TGWClient"), + patch("aws_network_tools.modules.cloudwan.CloudWANClient"), + ): # Command uses underscore: create_routing_cache result = command_runner.run("create_routing_cache") diff --git a/tests/test_ec2_context.py b/tests/test_ec2_context.py index deadbef..4de18e1 100644 --- a/tests/test_ec2_context.py +++ b/tests/test_ec2_context.py @@ -6,19 +6,19 @@ """ import sys -import pytest from unittest.mock import MagicMock, patch from dataclasses import dataclass # Mock cmd2 BEFORE importing shell modules mock_cmd2 = MagicMock() mock_cmd2.Cmd = MagicMock -sys.modules['cmd2'] = mock_cmd2 +sys.modules["cmd2"] = mock_cmd2 @dataclass class MockContext: """Mock Context object for testing.""" + ctx_type: str ref: str label: str @@ -47,11 +47,11 @@ def test_root_show_enis_skips_when_in_ec2_context(self): "enis": [ {"id": "eni-instance-001", "private_ip": "10.0.1.100"}, ], - } + }, ) # Track if discover() gets called - with patch('aws_network_tools.modules.eni.ENIClient') as mock_eni: + with patch("aws_network_tools.modules.eni.ENIClient") as mock_eni: mock_eni.return_value.discover.return_value = [ {"id": "eni-all-001"}, {"id": "eni-all-002"}, @@ -64,8 +64,9 @@ def test_root_show_enis_skips_when_in_ec2_context(self): # The fix: discover() should NOT be called when in ec2-instance context # Before fix: This assertion fails (discover IS called via _cached) # After fix: This assertion passes (returns early) - assert not mock_self._cached.called, \ + assert not mock_self._cached.called, ( "Root _show_enis should NOT call _cached() when in ec2-instance context" + ) def test_root_show_enis_runs_when_at_root_level(self): """Root _show_enis should run normally when at root level (no context).""" @@ -74,18 +75,21 @@ def test_root_show_enis_runs_when_at_root_level(self): mock_self = MagicMock() mock_self.ctx_type = None # Root level mock_self.ctx = None - mock_self._cached = MagicMock(return_value=[ - {"id": "eni-001"}, - {"id": "eni-002"}, - ]) + mock_self._cached = MagicMock( + return_value=[ + {"id": "eni-001"}, + {"id": "eni-002"}, + ] + ) - with patch('aws_network_tools.modules.eni.ENIDisplay'): + with patch("aws_network_tools.modules.eni.ENIDisplay"): # At root level, the handler should run and use _cached RootHandlersMixin._show_enis(mock_self, None) # Should call _cached to fetch ENIs at root level - assert mock_self._cached.called, \ + assert mock_self._cached.called, ( "Root _show_enis should call _cached() at root level" + ) def test_root_show_security_groups_skips_when_in_ec2_context(self): """Bug: _show_security_groups in root.py should skip when ctx_type == 'ec2-instance'.""" @@ -99,14 +103,15 @@ def test_root_show_security_groups_skips_when_in_ec2_context(self): label="my-instance", data={ "security_groups": [{"id": "sg-instance-001"}], - } + }, ) # The fix: _cached should NOT be called in ec2-instance context RootHandlersMixin._show_security_groups(mock_self, None) - assert not mock_self._cached.called, \ + assert not mock_self._cached.called, ( "Root _show_security_groups should NOT call _cached() in ec2-instance context" + ) def test_root_show_security_groups_skips_when_in_vpc_context(self): """_show_security_groups should also skip when ctx_type == 'vpc'.""" @@ -120,13 +125,14 @@ def test_root_show_security_groups_skips_when_in_vpc_context(self): label="my-vpc", data={ "security_groups": [{"id": "sg-vpc-001"}], - } + }, ) RootHandlersMixin._show_security_groups(mock_self, None) - assert not mock_self._cached.called, \ + assert not mock_self._cached.called, ( "Root _show_security_groups should NOT call _cached() in vpc context" + ) def test_root_show_security_groups_runs_when_at_root_level(self): """Root _show_security_groups should run normally at root level.""" @@ -135,18 +141,21 @@ def test_root_show_security_groups_runs_when_at_root_level(self): mock_self = MagicMock() mock_self.ctx_type = None # Root level mock_self.ctx = None - mock_self._cached = MagicMock(return_value={ - "unused_groups": [], - "risky_rules": [], - "nacl_issues": [], - }) + mock_self._cached = MagicMock( + return_value={ + "unused_groups": [], + "risky_rules": [], + "nacl_issues": [], + } + ) - with patch('aws_network_tools.modules.security.SecurityDisplay'): + with patch("aws_network_tools.modules.security.SecurityDisplay"): RootHandlersMixin._show_security_groups(mock_self, None) # Should call _cached at root level - assert mock_self._cached.called, \ + assert mock_self._cached.called, ( "Root _show_security_groups should call _cached() at root level" + ) class TestMRODocumentation: @@ -158,7 +167,7 @@ def test_document_mro_conflict(self): from aws_network_tools.shell.handlers.ec2 import EC2HandlersMixin # Both define _show_enis - this is the source of the conflict - assert hasattr(RootHandlersMixin, '_show_enis'), "Root has _show_enis" - assert hasattr(EC2HandlersMixin, '_show_enis'), "EC2 has _show_enis" + assert hasattr(RootHandlersMixin, "_show_enis"), "Root has _show_enis" + assert hasattr(EC2HandlersMixin, "_show_enis"), "EC2 has _show_enis" # The fix makes root handler context-aware, so MRO conflict is resolved diff --git a/tests/test_elb_handler.py b/tests/test_elb_handler.py index e2125ca..5492695 100644 --- a/tests/test_elb_handler.py +++ b/tests/test_elb_handler.py @@ -1,8 +1,6 @@ """Tests for ELB handler - Issue #10: Verify handler correctly displays data.""" import pytest -from unittest.mock import MagicMock, patch -from io import StringIO @pytest.fixture @@ -23,7 +21,11 @@ def mock_elb_detail(): "arn": "arn:aws:elasticloadbalancing:eu-west-1:123456789:listener/app/test-alb/abc123/def456", "port": 443, "protocol": "HTTPS", - "ssl_certs": [{"CertificateArn": "arn:aws:acm:eu-west-1:123456789:certificate/cert123"}], + "ssl_certs": [ + { + "CertificateArn": "arn:aws:acm:eu-west-1:123456789:certificate/cert123" + } + ], "default_actions": [ { "type": "forward", @@ -36,8 +38,16 @@ def mock_elb_detail(): "vpc_id": "vpc-123", "target_type": "instance", "targets": [ - {"id": "i-1234567890abcdef0", "port": 8080, "state": "healthy"}, - {"id": "i-0987654321fedcba0", "port": 8080, "state": "healthy"}, + { + "id": "i-1234567890abcdef0", + "port": 8080, + "state": "healthy", + }, + { + "id": "i-0987654321fedcba0", + "port": 8080, + "state": "healthy", + }, ], }, } @@ -102,16 +112,18 @@ def __init__(self): self.output_format = "table" shell = MockShell() - + # Verify listeners exist in context data listeners = shell.ctx.data.get("listeners", []) assert len(listeners) == 2, f"Expected 2 listeners, got {len(listeners)}" - + # Verify listener fields are accessible for listener in listeners: assert "port" in listener, "Listener should have 'port'" assert "protocol" in listener, "Listener should have 'protocol'" - assert "default_actions" in listener, "Listener should have 'default_actions'" + assert "default_actions" in listener, ( + "Listener should have 'default_actions'" + ) def test_show_targets_displays_data(self, mock_elb_detail): """Test that show targets displays target group data correctly.""" @@ -125,11 +137,11 @@ def __init__(self): self.output_format = "table" shell = MockShell() - + # Verify target_groups exist at top level targets = shell.ctx.data.get("target_groups", []) assert len(targets) > 0, "target_groups should not be empty" - + # Verify target group fields for tg in targets: assert "name" in tg, "Target group should have 'name'" @@ -149,11 +161,11 @@ def __init__(self): self.output_format = "table" shell = MockShell() - + # Verify target_health exists at top level health = shell.ctx.data.get("target_health", []) assert len(health) > 0, "target_health should not be empty" - + # Verify health fields for h in health: assert "id" in h, "Health should have 'id'" @@ -163,12 +175,12 @@ def __init__(self): def test_handler_processes_default_actions_list(self, mock_elb_detail): """Test that handler correctly processes default_actions as a list.""" listeners = mock_elb_detail["listeners"] - + # HTTPS listener should have forward action - https_listener = [l for l in listeners if l["port"] == 443][0] + https_listener = [lis for lis in listeners if lis["port"] == 443][0] actions = https_listener.get("default_actions", []) assert len(actions) > 0, "HTTPS listener should have default_actions" - + forward_action = actions[0] assert forward_action["type"] == "forward" assert forward_action.get("target_group") is not None @@ -177,7 +189,7 @@ def test_handler_processes_default_actions_list(self, mock_elb_detail): def test_handler_uses_id_field_for_health(self, mock_elb_detail): """Test that handler uses 'id' field for target health display.""" health = mock_elb_detail["target_health"] - + for h in health: # Handler should use 'id' field assert "id" in h, "Health entry should have 'id' field" @@ -200,7 +212,7 @@ def __init__(self): shell = MockShell() shell._show_listeners(None) - + # Rich console output goes to stdout # We can't easily capture Rich output, but we can verify no exception diff --git a/tests/test_elb_module.py b/tests/test_elb_module.py index 773d7f5..71c8ba7 100644 --- a/tests/test_elb_module.py +++ b/tests/test_elb_module.py @@ -5,7 +5,7 @@ """ import pytest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock @pytest.fixture @@ -24,7 +24,10 @@ def mock_elbv2_client(): "Scheme": "internet-facing", "VpcId": "vpc-123", "State": {"Code": "active"}, - "AvailabilityZones": [{"ZoneName": "eu-west-1a"}, {"ZoneName": "eu-west-1b"}], + "AvailabilityZones": [ + {"ZoneName": "eu-west-1a"}, + {"ZoneName": "eu-west-1b"}, + ], } ] } @@ -37,7 +40,11 @@ def mock_elbv2_client(): "LoadBalancerArn": "arn:aws:elasticloadbalancing:eu-west-1:123456789:loadbalancer/app/test-alb/abc123", "Port": 443, "Protocol": "HTTPS", - "Certificates": [{"CertificateArn": "arn:aws:acm:eu-west-1:123456789:certificate/cert123"}], + "Certificates": [ + { + "CertificateArn": "arn:aws:acm:eu-west-1:123456789:certificate/cert123" + } + ], "DefaultActions": [ { "Type": "forward", @@ -55,7 +62,11 @@ def mock_elbv2_client(): "DefaultActions": [ { "Type": "redirect", - "RedirectConfig": {"Protocol": "HTTPS", "Port": "443", "StatusCode": "HTTP_301"}, + "RedirectConfig": { + "Protocol": "HTTPS", + "Port": "443", + "StatusCode": "HTTP_301", + }, "Order": 1, } ], @@ -113,7 +124,9 @@ def mock_boto_session(mock_elbv2_client): class TestELBModuleIssue10: """Tests for Issue #10: ELB commands return no output.""" - def test_listener_loop_processes_all_items(self, mock_boto_session, mock_elbv2_client): + def test_listener_loop_processes_all_items( + self, mock_boto_session, mock_elbv2_client + ): """Bug: Variable shadowing causes loop to lose original listener data. The loop variable 'listener' is reassigned inside the loop (line 163), @@ -124,15 +137,17 @@ def test_listener_loop_processes_all_items(self, mock_boto_session, mock_elbv2_c elb_client = ELBClient(session=mock_boto_session) result = elb_client.get_elb_detail( "arn:aws:elasticloadbalancing:eu-west-1:123456789:loadbalancer/app/test-alb/abc123", - "eu-west-1" + "eu-west-1", ) # Should have processed BOTH listeners (port 443 and 80) assert "listeners" in result, "Result should contain 'listeners' key" - assert len(result["listeners"]) == 2, f"Expected 2 listeners, got {len(result.get('listeners', []))}" + assert len(result["listeners"]) == 2, ( + f"Expected 2 listeners, got {len(result.get('listeners', []))}" + ) # Verify listener data was extracted correctly (not shadowed) - ports = [l.get("port") for l in result["listeners"]] + ports = [lis.get("port") for lis in result["listeners"]] assert 443 in ports, "Port 443 listener should be present" assert 80 in ports, "Port 80 listener should be present" @@ -147,13 +162,17 @@ def test_target_groups_at_top_level(self, mock_boto_session, mock_elbv2_client): elb_client = ELBClient(session=mock_boto_session) result = elb_client.get_elb_detail( "arn:aws:elasticloadbalancing:eu-west-1:123456789:loadbalancer/app/test-alb/abc123", - "eu-west-1" + "eu-west-1", ) # target_groups MUST be at top level - assert "target_groups" in result, "Result must have 'target_groups' at top level" + assert "target_groups" in result, ( + "Result must have 'target_groups' at top level" + ) assert isinstance(result["target_groups"], list), "target_groups must be a list" - assert len(result["target_groups"]) > 0, "target_groups should not be empty for ALB with target groups" + assert len(result["target_groups"]) > 0, ( + "target_groups should not be empty for ALB with target groups" + ) def test_target_health_at_top_level(self, mock_boto_session, mock_elbv2_client): """Bug: Handler expects target_health at top level but module doesn't provide it. @@ -166,14 +185,18 @@ def test_target_health_at_top_level(self, mock_boto_session, mock_elbv2_client): elb_client = ELBClient(session=mock_boto_session) result = elb_client.get_elb_detail( "arn:aws:elasticloadbalancing:eu-west-1:123456789:loadbalancer/app/test-alb/abc123", - "eu-west-1" + "eu-west-1", ) # target_health MUST be at top level - assert "target_health" in result, "Result must have 'target_health' at top level" + assert "target_health" in result, ( + "Result must have 'target_health' at top level" + ) assert isinstance(result["target_health"], list), "target_health must be a list" - def test_empty_listeners_returns_empty_structures(self, mock_boto_session, mock_elbv2_client): + def test_empty_listeners_returns_empty_structures( + self, mock_boto_session, mock_elbv2_client + ): """Edge case: ELB with no listeners should return empty lists, not None.""" from aws_network_tools.modules.elb import ELBClient @@ -183,26 +206,34 @@ def test_empty_listeners_returns_empty_structures(self, mock_boto_session, mock_ elb_client = ELBClient(session=mock_boto_session) result = elb_client.get_elb_detail( "arn:aws:elasticloadbalancing:eu-west-1:123456789:loadbalancer/app/test-alb/abc123", - "eu-west-1" + "eu-west-1", ) # Should return empty lists, not None or missing keys assert result.get("listeners") == [], "Empty listeners should return []" - assert result.get("target_groups") == [], "No listeners means no target_groups, should return []" - assert result.get("target_health") == [], "No listeners means no target_health, should return []" + assert result.get("target_groups") == [], ( + "No listeners means no target_groups, should return []" + ) + assert result.get("target_health") == [], ( + "No listeners means no target_health, should return []" + ) - def test_listener_with_forward_action_extracts_target_group(self, mock_boto_session, mock_elbv2_client): + def test_listener_with_forward_action_extracts_target_group( + self, mock_boto_session, mock_elbv2_client + ): """Test that forward actions properly extract target group ARNs.""" from aws_network_tools.modules.elb import ELBClient elb_client = ELBClient(session=mock_boto_session) result = elb_client.get_elb_detail( "arn:aws:elasticloadbalancing:eu-west-1:123456789:loadbalancer/app/test-alb/abc123", - "eu-west-1" + "eu-west-1", ) # Find the HTTPS listener (port 443) which has a forward action - https_listeners = [l for l in result.get("listeners", []) if l.get("port") == 443] + https_listeners = [ + lis for lis in result.get("listeners", []) if lis.get("port") == 443 + ] assert len(https_listeners) == 1, "Should have one HTTPS listener" https_listener = https_listeners[0] @@ -210,4 +241,6 @@ def test_listener_with_forward_action_extracts_target_group(self, mock_boto_sess actions = https_listener.get("default_actions", []) forward_actions = [a for a in actions if a.get("type") == "forward"] assert len(forward_actions) > 0, "HTTPS listener should have forward action" - assert forward_actions[0].get("target_group_arn") is not None, "Forward action should have target_group_arn" + assert forward_actions[0].get("target_group_arn") is not None, ( + "Forward action should have target_group_arn" + ) diff --git a/tests/test_graph_commands.py b/tests/test_graph_commands.py index 5260841..785967f 100644 --- a/tests/test_graph_commands.py +++ b/tests/test_graph_commands.py @@ -2,46 +2,59 @@ Uses command graph + fixtures to ensure complete coverage. """ + import pytest -from unittest.mock import patch, MagicMock +from unittest.mock import patch from aws_network_tools.shell.base import HIERARCHY from tests.fixtures.command_fixtures import ( - COMMAND_MOCKS, COMMAND_DEPENDENCIES, get_mock_for_command, get_dependencies, - _vpcs_list, _tgws_list, _firewalls_list, _ec2_instances_list, _elbs_list, _vpns_list, + COMMAND_MOCKS, + get_mock_for_command, + _vpcs_list, + _tgws_list, + _firewalls_list, + _ec2_instances_list, + _elbs_list, + _vpns_list, ) -from tests.fixtures import get_vpc_detail, get_tgw_detail, get_elb_detail, get_vpn_detail +from tests.fixtures import get_vpc_detail def build_command_graph() -> dict: """Build complete command graph from HIERARCHY.""" graph = {} - + for context, commands in HIERARCHY.items(): prefix = f"{context}." if context else "" - + for show_cmd in commands.get("show", []): cmd_key = f"{prefix}show {show_cmd}" graph[cmd_key] = { - "type": "show", "context": context, - "command": f"show {show_cmd}", "full_key": cmd_key, + "type": "show", + "context": context, + "command": f"show {show_cmd}", + "full_key": cmd_key, } - + for set_cmd in commands.get("set", []): cmd_key = f"{prefix}set {set_cmd}" graph[cmd_key] = { - "type": "set", "context": context, - "command": f"set {set_cmd}", "full_key": cmd_key, + "type": "set", + "context": context, + "command": f"set {set_cmd}", + "full_key": cmd_key, "enters_context": set_cmd, } - + for action in commands.get("commands", []): if action not in ("show", "set", "exit", "end", "clear"): cmd_key = f"{prefix}{action}" graph[cmd_key] = { - "type": "action", "context": context, - "command": action, "full_key": cmd_key, + "type": "action", + "context": context, + "command": action, + "full_key": cmd_key, } - + return graph @@ -51,43 +64,51 @@ def build_command_graph() -> dict: class TestCommandGraph: """Test command graph structure.""" - + def test_graph_has_all_contexts(self): contexts = {info["context"] for info in COMMAND_GRAPH.values()} assert contexts == set(HIERARCHY.keys()) - + def test_graph_command_count(self): show_count = sum(1 for v in COMMAND_GRAPH.values() if v["type"] == "show") set_count = sum(1 for v in COMMAND_GRAPH.values() if v["type"] == "set") assert show_count >= 50 assert set_count >= 10 - + def test_fixture_coverage(self): total = len(COMMAND_GRAPH) covered = len(TESTABLE_COMMANDS) - print(f"\nFixture coverage: {covered}/{total} ({(covered/total)*100:.1f}%)") + print(f"\nFixture coverage: {covered}/{total} ({(covered / total) * 100:.1f}%)") class TestRootShowCommands: """Test root-level show commands.""" - + @pytest.fixture def shell(self): from aws_network_tools.shell import AWSNetShell + shell = AWSNetShell() shell.no_cache = True yield shell shell._cache.clear() - - @pytest.mark.parametrize("command", [ - "show vpcs", "show transit_gateways", "show firewalls", - "show ec2-instances", "show elbs", "show vpns", - ]) + + @pytest.mark.parametrize( + "command", + [ + "show vpcs", + "show transit_gateways", + "show firewalls", + "show ec2-instances", + "show elbs", + "show vpns", + ], + ) def test_root_show_command(self, shell, command): mock_config = get_mock_for_command(command) if not mock_config: pytest.skip(f"No fixture for {command}") - + with patch(mock_config["target"], return_value=mock_config["return_value"]): result = shell.onecmd(command) assert result in (None, False) @@ -95,79 +116,111 @@ def test_root_show_command(self, shell, command): class TestContextEntry: """Test context entry (set) commands with full mock chain.""" - + @pytest.fixture def shell(self): from aws_network_tools.shell import AWSNetShell + shell = AWSNetShell() shell.no_cache = True yield shell shell._cache.clear() shell.context_stack.clear() - + def test_vpc_context_entry(self, shell): """Test entering VPC context.""" vpcs = _vpcs_list() vpc_id = vpcs[0]["id"] if vpcs else "vpc-test123" vpc_detail = get_vpc_detail(vpc_id) or { - "id": vpc_id, "name": "test", "region": "eu-west-1", - "cidr": "10.0.0.0/16", "cidrs": ["10.0.0.0/16"], "state": "available", - "route_tables": [], "security_groups": [], "nacls": [] + "id": vpc_id, + "name": "test", + "region": "eu-west-1", + "cidr": "10.0.0.0/16", + "cidrs": ["10.0.0.0/16"], + "state": "available", + "route_tables": [], + "security_groups": [], + "nacls": [], } - - with patch("aws_network_tools.modules.vpc.VPCClient.discover", return_value=vpcs): - with patch("aws_network_tools.modules.vpc.VPCClient.get_vpc_detail", return_value=vpc_detail): + + with patch( + "aws_network_tools.modules.vpc.VPCClient.discover", return_value=vpcs + ): + with patch( + "aws_network_tools.modules.vpc.VPCClient.get_vpc_detail", + return_value=vpc_detail, + ): shell.onecmd("show vpcs") result = shell.onecmd("set vpc 1") assert result in (None, False) assert shell.context_stack # Context was entered - + def test_tgw_context_entry(self, shell): """Test entering TGW context.""" tgws = _tgws_list() - - with patch("aws_network_tools.modules.tgw.TGWClient.discover", return_value=tgws): + + with patch( + "aws_network_tools.modules.tgw.TGWClient.discover", return_value=tgws + ): shell.onecmd("show transit_gateways") result = shell.onecmd("set transit-gateway 1") assert result in (None, False) - + def test_firewall_context_entry(self, shell): """Test entering firewall context.""" fws = _firewalls_list() - - with patch("aws_network_tools.modules.anfw.ANFWClient.discover", return_value=fws): + + with patch( + "aws_network_tools.modules.anfw.ANFWClient.discover", return_value=fws + ): shell.onecmd("show firewalls") result = shell.onecmd("set firewall 1") assert result in (None, False) - + def test_ec2_context_entry(self, shell): """Test entering EC2 instance context.""" instances = _ec2_instances_list() - - with patch("aws_network_tools.modules.ec2.EC2Client.discover", return_value=instances): + + with patch( + "aws_network_tools.modules.ec2.EC2Client.discover", return_value=instances + ): shell.onecmd("show ec2-instances") result = shell.onecmd("set ec2-instance 1") assert result in (None, False) - + def test_elb_context_entry(self, shell): """Test entering ELB context.""" elbs = _elbs_list() - elb_detail = { - "arn": elbs[0]["arn"], "name": elbs[0]["name"], "type": elbs[0]["type"], - "listeners": [], "target_groups": [] - } if elbs else {} - - with patch("aws_network_tools.modules.elb.ELBClient.discover", return_value=elbs): - with patch("aws_network_tools.modules.elb.ELBClient.get_elb_detail", return_value=elb_detail): + elb_detail = ( + { + "arn": elbs[0]["arn"], + "name": elbs[0]["name"], + "type": elbs[0]["type"], + "listeners": [], + "target_groups": [], + } + if elbs + else {} + ) + + with patch( + "aws_network_tools.modules.elb.ELBClient.discover", return_value=elbs + ): + with patch( + "aws_network_tools.modules.elb.ELBClient.get_elb_detail", + return_value=elb_detail, + ): shell.onecmd("show elbs") result = shell.onecmd("set elb 1") assert result in (None, False) - + def test_vpn_context_entry(self, shell): """Test entering VPN context.""" vpns = _vpns_list() - - with patch("aws_network_tools.modules.vpn.VPNClient.discover", return_value=vpns): + + with patch( + "aws_network_tools.modules.vpn.VPNClient.discover", return_value=vpns + ): shell.onecmd("show vpns") result = shell.onecmd("set vpn 1") assert result in (None, False) @@ -175,45 +228,87 @@ def test_vpn_context_entry(self, shell): class TestVPCContextCommands: """Test commands within VPC context.""" - + @pytest.fixture def shell_in_vpc(self): """Shell with VPC context entered.""" from aws_network_tools.shell import AWSNetShell + shell = AWSNetShell() shell.no_cache = True - + vpcs = _vpcs_list() vpc_id = vpcs[0]["id"] if vpcs else "vpc-test" vpc_detail = get_vpc_detail(vpc_id) or { - "id": vpc_id, "name": "test", "region": "eu-west-1", - "cidr": "10.0.0.0/16", "cidrs": ["10.0.0.0/16"], "state": "available", - "route_tables": [{"id": "rtb-1", "name": "main", "is_main": True, "routes": [], "subnets": []}], - "security_groups": [{"id": "sg-1", "name": "default", "description": "", "ingress": [], "egress": []}], - "nacls": [{"id": "acl-1", "name": "default", "is_default": True, "entries": []}], - "subnets": [{"id": "subnet-1", "name": "public", "cidr": "10.0.1.0/24", "az": "eu-west-1a", "state": "available"}], + "id": vpc_id, + "name": "test", + "region": "eu-west-1", + "cidr": "10.0.0.0/16", + "cidrs": ["10.0.0.0/16"], + "state": "available", + "route_tables": [ + { + "id": "rtb-1", + "name": "main", + "is_main": True, + "routes": [], + "subnets": [], + } + ], + "security_groups": [ + { + "id": "sg-1", + "name": "default", + "description": "", + "ingress": [], + "egress": [], + } + ], + "nacls": [ + {"id": "acl-1", "name": "default", "is_default": True, "entries": []} + ], + "subnets": [ + { + "id": "subnet-1", + "name": "public", + "cidr": "10.0.1.0/24", + "az": "eu-west-1a", + "state": "available", + } + ], } - + # Start patches - p1 = patch("aws_network_tools.modules.vpc.VPCClient.discover", return_value=vpcs) - p2 = patch("aws_network_tools.modules.vpc.VPCClient.get_vpc_detail", return_value=vpc_detail) + p1 = patch( + "aws_network_tools.modules.vpc.VPCClient.discover", return_value=vpcs + ) + p2 = patch( + "aws_network_tools.modules.vpc.VPCClient.get_vpc_detail", + return_value=vpc_detail, + ) p1.start() p2.start() - + shell.onecmd("show vpcs") shell.onecmd("set vpc 1") - + yield shell, vpc_detail - + p1.stop() p2.stop() shell._cache.clear() shell.context_stack.clear() - - @pytest.mark.parametrize("command", [ - "show detail", "show subnets", "show route-tables", - "show security-groups", "show nacls", - ]) + + @pytest.mark.parametrize( + "command", + [ + "show detail", + "show subnets", + "show route-tables", + "show security-groups", + "show nacls", + ], + ) def test_vpc_show_commands(self, shell_in_vpc, command): """Test VPC context show commands.""" shell, _ = shell_in_vpc @@ -223,13 +318,14 @@ def test_vpc_show_commands(self, shell_in_vpc, command): class TestTGWContextCommands: """Test commands within TGW context.""" - + @pytest.fixture def shell_in_tgw(self): from aws_network_tools.shell import AWSNetShell + shell = AWSNetShell() shell.no_cache = True - + tgws = _tgws_list() # Add route_tables to first TGW if tgws: @@ -237,22 +333,31 @@ def shell_in_tgw(self): {"id": "tgw-rtb-1", "name": "main", "routes": [], "state": "available"} ] tgws[0]["attachments"] = [ - {"id": "tgw-attach-1", "type": "vpc", "resource_id": "vpc-123", "state": "available"} + { + "id": "tgw-attach-1", + "type": "vpc", + "resource_id": "vpc-123", + "state": "available", + } ] - - p1 = patch("aws_network_tools.modules.tgw.TGWClient.discover", return_value=tgws) + + p1 = patch( + "aws_network_tools.modules.tgw.TGWClient.discover", return_value=tgws + ) p1.start() - + shell.onecmd("show transit_gateways") shell.onecmd("set transit-gateway 1") - + yield shell - + p1.stop() shell._cache.clear() shell.context_stack.clear() - - @pytest.mark.parametrize("command", ["show detail", "show route-tables", "show attachments"]) + + @pytest.mark.parametrize( + "command", ["show detail", "show route-tables", "show attachments"] + ) def test_tgw_show_commands(self, shell_in_tgw, command): shell = shell_in_tgw result = shell.onecmd(command) @@ -261,30 +366,37 @@ def test_tgw_show_commands(self, shell_in_tgw, command): class TestFirewallContextCommands: """Test commands within Firewall context.""" - + @pytest.fixture def shell_in_firewall(self): from aws_network_tools.shell import AWSNetShell + shell = AWSNetShell() shell.no_cache = True - + fws = _firewalls_list() if fws: - fws[0]["rule_groups"] = [{"name": "test-rg", "arn": "arn:...", "type": "STATEFUL"}] - - p1 = patch("aws_network_tools.modules.anfw.ANFWClient.discover", return_value=fws) + fws[0]["rule_groups"] = [ + {"name": "test-rg", "arn": "arn:...", "type": "STATEFUL"} + ] + + p1 = patch( + "aws_network_tools.modules.anfw.ANFWClient.discover", return_value=fws + ) p1.start() - + shell.onecmd("show firewalls") shell.onecmd("set firewall 1") - + yield shell - + p1.stop() shell._cache.clear() shell.context_stack.clear() - - @pytest.mark.parametrize("command", ["show detail", "show rule-groups", "show policy"]) + + @pytest.mark.parametrize( + "command", ["show detail", "show rule-groups", "show policy"] + ) def test_firewall_show_commands(self, shell_in_firewall, command): result = shell_in_firewall.onecmd(command) assert result in (None, False) @@ -292,31 +404,36 @@ def test_firewall_show_commands(self, shell_in_firewall, command): class TestEC2ContextCommands: """Test commands within EC2 instance context.""" - + @pytest.fixture def shell_in_ec2(self): from aws_network_tools.shell import AWSNetShell + shell = AWSNetShell() shell.no_cache = True - + instances = _ec2_instances_list() if instances: instances[0]["security_groups"] = [{"id": "sg-1", "name": "default"}] instances[0]["enis"] = [{"id": "eni-1", "private_ip": "10.0.1.10"}] - - p1 = patch("aws_network_tools.modules.ec2.EC2Client.discover", return_value=instances) + + p1 = patch( + "aws_network_tools.modules.ec2.EC2Client.discover", return_value=instances + ) p1.start() - + shell.onecmd("show ec2-instances") shell.onecmd("set ec2-instance 1") - + yield shell - + p1.stop() shell._cache.clear() shell.context_stack.clear() - - @pytest.mark.parametrize("command", ["show detail", "show security-groups", "show enis"]) + + @pytest.mark.parametrize( + "command", ["show detail", "show security-groups", "show enis"] + ) def test_ec2_show_commands(self, shell_in_ec2, command): result = shell_in_ec2.onecmd(command) assert result in (None, False) @@ -324,13 +441,14 @@ def test_ec2_show_commands(self, shell_in_ec2, command): class TestELBContextCommands: """Test commands within ELB context.""" - + @pytest.fixture def shell_in_elb(self): from aws_network_tools.shell import AWSNetShell + shell = AWSNetShell() shell.no_cache = True - + elbs = _elbs_list() elb_detail = { "arn": elbs[0]["arn"] if elbs else "arn:...", @@ -339,23 +457,30 @@ def shell_in_elb(self): "listeners": [{"port": 443, "protocol": "HTTPS"}], "target_groups": [{"name": "tg-1", "targets": []}], } - - p1 = patch("aws_network_tools.modules.elb.ELBClient.discover", return_value=elbs) - p2 = patch("aws_network_tools.modules.elb.ELBClient.get_elb_detail", return_value=elb_detail) + + p1 = patch( + "aws_network_tools.modules.elb.ELBClient.discover", return_value=elbs + ) + p2 = patch( + "aws_network_tools.modules.elb.ELBClient.get_elb_detail", + return_value=elb_detail, + ) p1.start() p2.start() - + shell.onecmd("show elbs") shell.onecmd("set elb 1") - + yield shell - + p1.stop() p2.stop() shell._cache.clear() shell.context_stack.clear() - - @pytest.mark.parametrize("command", ["show detail", "show listeners", "show targets"]) + + @pytest.mark.parametrize( + "command", ["show detail", "show listeners", "show targets"] + ) def test_elb_show_commands(self, shell_in_elb, command): result = shell_in_elb.onecmd(command) assert result in (None, False) @@ -363,32 +488,35 @@ def test_elb_show_commands(self, shell_in_elb, command): class TestVPNContextCommands: """Test commands within VPN context.""" - + @pytest.fixture def shell_in_vpn(self): from aws_network_tools.shell import AWSNetShell + shell = AWSNetShell() shell.no_cache = True - + vpns = _vpns_list() if vpns: vpns[0]["tunnels"] = [ {"outside_ip": "203.0.113.1", "status": "UP"}, {"outside_ip": "203.0.113.2", "status": "UP"}, ] - - p1 = patch("aws_network_tools.modules.vpn.VPNClient.discover", return_value=vpns) + + p1 = patch( + "aws_network_tools.modules.vpn.VPNClient.discover", return_value=vpns + ) p1.start() - + shell.onecmd("show vpns") shell.onecmd("set vpn 1") - + yield shell - + p1.stop() shell._cache.clear() shell.context_stack.clear() - + @pytest.mark.parametrize("command", ["show detail", "show tunnels"]) def test_vpn_show_commands(self, shell_in_vpn, command): result = shell_in_vpn.onecmd(command) diff --git a/tests/test_issue_2_show_detail.py b/tests/test_issue_2_show_detail.py index 9760c60..e82b394 100644 --- a/tests/test_issue_2_show_detail.py +++ b/tests/test_issue_2_show_detail.py @@ -7,8 +7,7 @@ """ import pytest -from unittest.mock import MagicMock, patch -from io import StringIO +from unittest.mock import MagicMock from aws_network_tools.modules.cloudwan import CloudWANDisplay @@ -36,9 +35,9 @@ def test_show_detail_with_complete_data(self): "name": "production", "region": "us-east-1", "type": "segment", - "routes": [] + "routes": [], } - ] + ], } # Should NOT raise KeyError @@ -59,7 +58,7 @@ def test_show_detail_missing_global_network_name(self): "regions": ["us-west-2"], "segments": ["shared"], "nfgs": [], - "route_tables": [] + "route_tables": [], } # Should NOT raise KeyError - this is the bug fix verification @@ -82,7 +81,7 @@ def test_show_detail_empty_global_network_name(self): "regions": [], "segments": [], "nfgs": [], - "route_tables": [] + "route_tables": [], } # Should NOT raise any errors @@ -98,7 +97,7 @@ def test_show_detail_none_global_network_name(self): "regions": ["ap-southeast-1"], "segments": ["transit"], "nfgs": [], - "route_tables": [] + "route_tables": [], } # Should NOT raise any errors @@ -161,7 +160,7 @@ def test_show_list_with_missing_global_network_name(self): # Missing global_network_name "regions": ["us-east-1"], "segments": ["seg1"], - "route_tables": [] + "route_tables": [], } ] diff --git a/tests/test_issue_5_tgw_rt_details.py b/tests/test_issue_5_tgw_rt_details.py index d74c226..7301f03 100644 --- a/tests/test_issue_5_tgw_rt_details.py +++ b/tests/test_issue_5_tgw_rt_details.py @@ -1,4 +1,3 @@ -import pytest from unittest.mock import MagicMock from aws_network_tools.modules.tgw import TGWClient @@ -11,55 +10,72 @@ def test_discover_fetches_associations_and_propagations(self): mock_session.region_name = "us-east-1" mock_client.describe_transit_gateways.return_value = { - "TransitGateways": [{ - "TransitGatewayId": "tgw-123", - "State": "available", - "Tags": [{"Key": "Name", "Value": "test-tgw"}], - }] + "TransitGateways": [ + { + "TransitGatewayId": "tgw-123", + "State": "available", + "Tags": [{"Key": "Name", "Value": "test-tgw"}], + } + ] } mock_client.describe_transit_gateway_attachments.return_value = { - "TransitGatewayAttachments": [{ - "TransitGatewayAttachmentId": "tgw-att-abc", - "ResourceId": "vpc-123", - "ResourceType": "vpc", - "State": "available", - "Tags": [{"Key": "Name", "Value": "vpc-att"}], - }] + "TransitGatewayAttachments": [ + { + "TransitGatewayAttachmentId": "tgw-att-abc", + "ResourceId": "vpc-123", + "ResourceType": "vpc", + "State": "available", + "Tags": [{"Key": "Name", "Value": "vpc-att"}], + } + ] } mock_client.describe_transit_gateway_route_tables.return_value = { - "TransitGatewayRouteTables": [{ - "TransitGatewayRouteTableId": "tgw-rtb-456", - "Tags": [{"Key": "Name", "Value": "test-rtb"}], - }] + "TransitGatewayRouteTables": [ + { + "TransitGatewayRouteTableId": "tgw-rtb-456", + "Tags": [{"Key": "Name", "Value": "test-rtb"}], + } + ] } mock_client.search_transit_gateway_routes.return_value = { - "Routes": [{ - "DestinationCidrBlock": "10.0.0.0/16", - "State": "active", - "Type": "propagated", - "TransitGatewayAttachments": [{"TransitGatewayAttachmentId": "tgw-att-abc", "ResourceType": "vpc"}] - }] + "Routes": [ + { + "DestinationCidrBlock": "10.0.0.0/16", + "State": "active", + "Type": "propagated", + "TransitGatewayAttachments": [ + { + "TransitGatewayAttachmentId": "tgw-att-abc", + "ResourceType": "vpc", + } + ], + } + ] } mock_client.get_transit_gateway_route_table_associations.return_value = { - "Associations": [{ - "TransitGatewayAttachmentId": "tgw-att-abc", - "ResourceId": "vpc-123", - "ResourceType": "vpc", - "State": "associated" - }] + "Associations": [ + { + "TransitGatewayAttachmentId": "tgw-att-abc", + "ResourceId": "vpc-123", + "ResourceType": "vpc", + "State": "associated", + } + ] } mock_client.get_transit_gateway_route_table_propagations.return_value = { - "TransitGatewayRouteTablePropagations": [{ - "TransitGatewayAttachmentId": "tgw-att-def", - "ResourceId": "vpn-456", - "ResourceType": "vpn", - "State": "enabled" - }] + "TransitGatewayRouteTablePropagations": [ + { + "TransitGatewayAttachmentId": "tgw-att-def", + "ResourceId": "vpn-456", + "ResourceType": "vpn", + "State": "enabled", + } + ] } client = TGWClient(session=mock_session) diff --git a/tests/test_issue_8_vpc_set.py b/tests/test_issue_8_vpc_set.py index 87909a1..7db40b8 100644 --- a/tests/test_issue_8_vpc_set.py +++ b/tests/test_issue_8_vpc_set.py @@ -19,40 +19,48 @@ class TestIssue8VpcSet: def test_resolve_with_none_name_by_id(self): """Test _resolve handles items where name is None when searching by ID. - + This is the exact scenario from issue #8: - VPC has no Name tag, so name=None - User tries to set vpc by VPC ID - Previously failed with: AttributeError: 'NoneType' object has no attribute 'lower' """ shell = AWSNetShellBase() - + # Simulate VPCs where some have name=None (no Name tag) items = [ {"id": "vpc-0ed63bf689aae45cf", "name": None, "region": "ap-northeast-1"}, - {"id": "vpc-019dd8e2dee602c9c", "name": "DC-AB2-vpc", "region": "us-east-2"}, + { + "id": "vpc-019dd8e2dee602c9c", + "name": "DC-AB2-vpc", + "region": "us-east-2", + }, {"id": "vpc-0f0e196d10a1854c8", "name": None, "region": "us-east-1"}, ] - + # This was the failing case - resolving by VPC ID when name is None result = shell._resolve(items, "vpc-0ed63bf689aae45cf") assert result is not None, "Should resolve VPC by ID even when name is None" assert result["id"] == "vpc-0ed63bf689aae45cf" - + def test_resolve_with_none_name_by_name_search(self): """Test _resolve doesn't crash when searching by name with None names in list.""" shell = AWSNetShellBase() - + items = [ {"id": "vpc-0ed63bf689aae45cf", "name": None, "region": "ap-northeast-1"}, - {"id": "vpc-019dd8e2dee602c9c", "name": "DC-AB2-vpc", "region": "us-east-2"}, + { + "id": "vpc-019dd8e2dee602c9c", + "name": "DC-AB2-vpc", + "region": "us-east-2", + }, ] - + # Search by name - should not crash even with None names in list result = shell._resolve(items, "DC-AB2-vpc") assert result is not None assert result["id"] == "vpc-019dd8e2dee602c9c" - + def test_resolve_by_id(self): """Test _resolve can find item by ID.""" shell = AWSNetShellBase() @@ -60,11 +68,11 @@ def test_resolve_by_id(self): {"id": "vpc-abc123", "name": "test-vpc"}, {"id": "vpc-def456", "name": "other-vpc"}, ] - + result = shell._resolve(items, "vpc-def456") assert result is not None assert result["id"] == "vpc-def456" - + def test_resolve_by_name_case_insensitive(self): """Test _resolve can find item by name (case-insensitive).""" shell = AWSNetShellBase() @@ -72,11 +80,11 @@ def test_resolve_by_name_case_insensitive(self): {"id": "vpc-abc123", "name": "Test-VPC"}, {"id": "vpc-def456", "name": "Other-VPC"}, ] - + result = shell._resolve(items, "test-vpc") assert result is not None assert result["id"] == "vpc-abc123" - + def test_resolve_by_index(self): """Test _resolve can find item by index.""" shell = AWSNetShellBase() @@ -84,21 +92,21 @@ def test_resolve_by_index(self): {"id": "vpc-abc123", "name": "test-vpc"}, {"id": "vpc-def456", "name": "other-vpc"}, ] - + result = shell._resolve(items, "2") assert result is not None assert result["id"] == "vpc-def456" - + def test_resolve_not_found(self): """Test _resolve returns None when not found.""" shell = AWSNetShellBase() items = [ {"id": "vpc-abc123", "name": "test-vpc"}, ] - + result = shell._resolve(items, "vpc-nonexistent") assert result is None - + def test_resolve_with_empty_string_name(self): """Test _resolve handles items where name is empty string.""" shell = AWSNetShellBase() @@ -106,11 +114,11 @@ def test_resolve_with_empty_string_name(self): {"id": "vpc-abc123", "name": ""}, {"id": "vpc-def456", "name": "other-vpc"}, ] - + result = shell._resolve(items, "vpc-abc123") assert result is not None assert result["id"] == "vpc-abc123" - + def test_resolve_all_none_names(self): """Test _resolve works when ALL items have name=None.""" shell = AWSNetShellBase() @@ -119,12 +127,12 @@ def test_resolve_all_none_names(self): {"id": "vpc-def456", "name": None}, {"id": "vpc-ghi789", "name": None}, ] - + # Should still find by ID result = shell._resolve(items, "vpc-def456") assert result is not None assert result["id"] == "vpc-def456" - + # Should still find by index result = shell._resolve(items, "3") assert result is not None diff --git a/tests/test_policy_change_events.py b/tests/test_policy_change_events.py index 97dff56..fd8c1b6 100644 --- a/tests/test_policy_change_events.py +++ b/tests/test_policy_change_events.py @@ -6,8 +6,8 @@ Root cause: The sorting function used `x.get("created_at") or ""` which mixed datetime objects (from AWS API) with empty strings (for None values). """ + from datetime import datetime -import pytest class TestPolicyChangeEventsSorting: @@ -31,7 +31,7 @@ def sort_key(x): # This should not raise TypeError result = sorted(events, key=sort_key, reverse=True) - + # Newest first, None last assert [e["version"] for e in result] == [3, 1, 2] @@ -78,10 +78,10 @@ def test_original_bug_reproduction(self): {"version": 1, "created_at": datetime(2024, 1, 1)}, {"version": 2, "created_at": None}, # This becomes "" with `or ""` ] - + # Old buggy code: sorted(events, key=lambda x: (x.get("created_at") or ""), reverse=True) # This raises: TypeError: '<' not supported between instances of 'datetime.datetime' and 'str' - + # New fixed code: def sort_key(x): val = x.get("created_at") @@ -90,7 +90,7 @@ def sort_key(x): if isinstance(val, datetime): return val return datetime.min - + # Should not raise TypeError result = sorted(events, key=sort_key, reverse=True) assert len(result) == 2 diff --git a/tests/test_refresh_command.py b/tests/test_refresh_command.py index b4c18e6..8236556 100644 --- a/tests/test_refresh_command.py +++ b/tests/test_refresh_command.py @@ -95,6 +95,7 @@ def test_refresh_current_context_elb(self, shell): # Set up ELB context shell._cache = {"elbs": [{"id": "elb-1"}]} from aws_network_tools.shell.base import Context + shell.context_stack = [Context("elb", "arn:aws:...", "my-elb", {}, 1)] old_stdout = sys.stdout @@ -112,6 +113,7 @@ def test_refresh_current_context_vpc(self, shell): """Test refresh in VPC context clears VPC cache.""" shell._cache = {"vpcs": [{"id": "vpc-1"}]} from aws_network_tools.shell.base import Context + shell.context_stack = [Context("vpc", "vpc-123", "my-vpc", {}, 1)] old_stdout = sys.stdout @@ -153,7 +155,7 @@ def test_refresh_multiple_cache_keys(self, shell): shell.onecmd("refresh elbs") shell.onecmd("refresh vpcs") - output = sys.stdout.getvalue() + _output = sys.stdout.getvalue() # Captured but not asserted sys.stdout = old_stdout @@ -165,6 +167,7 @@ def test_refresh_transit_gateway_context(self, shell): """Test refresh in transit-gateway context.""" shell._cache = {"transit_gateways": [{"id": "tgw-1"}]} from aws_network_tools.shell.base import Context + shell.context_stack = [Context("transit-gateway", "tgw-123", "my-tgw", {}, 1)] old_stdout = sys.stdout @@ -182,6 +185,7 @@ def test_refresh_firewall_context(self, shell): """Test refresh in firewall context.""" shell._cache = {"firewalls": [{"id": "fw-1"}]} from aws_network_tools.shell.base import Context + shell.context_stack = [Context("firewall", "fw-123", "my-fw", {}, 1)] old_stdout = sys.stdout @@ -231,5 +235,6 @@ def test_refresh_command_available_in_all_contexts(self, shell): ] for ctx_type in context_types: - assert "refresh" in HIERARCHY[ctx_type]["commands"], \ + assert "refresh" in HIERARCHY[ctx_type]["commands"], ( f"refresh not in {ctx_type} commands" + ) diff --git a/tests/test_shell_hierarchy.py b/tests/test_shell_hierarchy.py index ca86f20..113d1d1 100644 --- a/tests/test_shell_hierarchy.py +++ b/tests/test_shell_hierarchy.py @@ -67,7 +67,11 @@ def test_tgw_context_has_required_show_options(self): def test_firewall_context_has_required_show_options(self): """Firewall context must have all required show options.""" - required = {"detail", "rule-groups", "policy"} # Issue #7: policy must be available + required = { + "detail", + "rule-groups", + "policy", + } # Issue #7: policy must be available actual = set(HIERARCHY["firewall"]["show"]) assert required.issubset(actual), f"Missing: {required - actual}" @@ -393,14 +397,26 @@ def test_firewall_policy_data_structure(self, shell): assert "arn" in policy, "Policy must have 'arn' field" # Stateless default actions must be a dict with expected keys stateless_defaults = policy.get("stateless_default_actions", {}) - assert isinstance(stateless_defaults, dict), "stateless_default_actions must be a dict" - assert "full_packets" in stateless_defaults, "stateless_default_actions must have 'full_packets'" - assert "fragmented" in stateless_defaults, "stateless_default_actions must have 'fragmented'" + assert isinstance(stateless_defaults, dict), ( + "stateless_default_actions must be a dict" + ) + assert "full_packets" in stateless_defaults, ( + "stateless_default_actions must have 'full_packets'" + ) + assert "fragmented" in stateless_defaults, ( + "stateless_default_actions must have 'fragmented'" + ) # Stateful engine options must exist - assert "stateful_engine_options" in policy, "Policy must have 'stateful_engine_options'" + assert "stateful_engine_options" in policy, ( + "Policy must have 'stateful_engine_options'" + ) engine_opts = policy.get("stateful_engine_options", {}) - assert "rule_order" in engine_opts, "stateful_engine_options must have 'rule_order'" - assert "stream_exception_policy" in engine_opts, "stateful_engine_options must have 'stream_exception_policy'" + assert "rule_order" in engine_opts, ( + "stateful_engine_options must have 'rule_order'" + ) + assert "stream_exception_policy" in engine_opts, ( + "stateful_engine_options must have 'stream_exception_policy'" + ) class TestGraphValidation: diff --git a/tests/test_show_regions.py b/tests/test_show_regions.py index 0a2bfff..d09f403 100644 --- a/tests/test_show_regions.py +++ b/tests/test_show_regions.py @@ -100,10 +100,14 @@ def test_show_regions_includes_major_regions(self, shell): sys.stdout = old_stdout major_regions = [ - "us-east-1", "us-west-2", - "eu-west-1", "eu-central-1", - "ap-southeast-1", "ap-northeast-1", - "ca-central-1", "sa-east-1" + "us-east-1", + "us-west-2", + "eu-west-1", + "eu-central-1", + "ap-southeast-1", + "ap-northeast-1", + "ca-central-1", + "sa-east-1", ] for region in major_regions: diff --git a/tests/test_tgw_issues.py b/tests/test_tgw_issues.py index a4e3b92..36aca38 100644 --- a/tests/test_tgw_issues.py +++ b/tests/test_tgw_issues.py @@ -2,12 +2,11 @@ import sys from unittest.mock import MagicMock -import pytest # Mock cmd2 before imports mock_cmd2 = MagicMock() mock_cmd2.Cmd = MagicMock -sys.modules['cmd2'] = mock_cmd2 +sys.modules["cmd2"] = mock_cmd2 class TestSetRouteTableCacheKey: @@ -51,7 +50,11 @@ def test_cache_key_matches_set_route_table_lookup(self): mock_self.ctx = MagicMock() # Explicitly create ctx mock mock_self.ctx.data = { "route_tables": [ - {"id": "rtb-test", "name": "TestRT", "routes": [{"dest": "10.0.0.0/8"}]}, + { + "id": "rtb-test", + "name": "TestRT", + "routes": [{"dest": "10.0.0.0/8"}], + }, ] } @@ -63,5 +66,7 @@ def test_cache_key_matches_set_route_table_lookup(self): cached_rts = mock_self._cache.get(lookup_key, []) # Should find the route tables, not empty list - assert cached_rts, f"Cache lookup with '{lookup_key}' returned empty - key mismatch!" + assert cached_rts, ( + f"Cache lookup with '{lookup_key}' returned empty - key mismatch!" + ) assert cached_rts[0]["id"] == "rtb-test" diff --git a/tests/test_utils/context_state_manager.py b/tests/test_utils/context_state_manager.py index 7a88a1a..be16e1a 100644 --- a/tests/test_utils/context_state_manager.py +++ b/tests/test_utils/context_state_manager.py @@ -77,7 +77,9 @@ def validate_current_state(self): # Validate each level for i, (expected_type, _) in enumerate(self.expected_stack): actual_context = self.shell.context_stack[i] - actual_type = actual_context.type # Context dataclass field is 'type', not 'ctx_type' + actual_type = ( + actual_context.type + ) # Context dataclass field is 'type', not 'ctx_type' assert actual_type == expected_type, ( f"Context type mismatch at level {i}: expected '{expected_type}', got '{actual_type}'" @@ -99,4 +101,6 @@ def reset(self): def __repr__(self) -> str: """String representation for debugging.""" - return f"ContextStateManager(expected={[ctx[0] for ctx in self.expected_stack]})" + return ( + f"ContextStateManager(expected={[ctx[0] for ctx in self.expected_stack]})" + ) diff --git a/tests/test_utils/data_format_adapter.py b/tests/test_utils/data_format_adapter.py index f0d92b5..7fd66e9 100644 --- a/tests/test_utils/data_format_adapter.py +++ b/tests/test_utils/data_format_adapter.py @@ -26,14 +26,16 @@ class DataFormatAdapter: def __init__(self): """Initialize adapter with format registry.""" self.FORMAT_REGISTRY: dict[str, Callable] = { - 'vpc': self._transform_vpc, - 'tgw': self._transform_tgw, - 'cloudwan': self._transform_cloudwan, - 'ec2': self._transform_ec2, - 'elb': self._transform_elb, + "vpc": self._transform_vpc, + "tgw": self._transform_tgw, + "cloudwan": self._transform_cloudwan, + "ec2": self._transform_ec2, + "elb": self._transform_elb, } - def transform(self, resource_type: str, fixture_data: dict[str, Any]) -> dict[str, Any]: + def transform( + self, resource_type: str, fixture_data: dict[str, Any] + ) -> dict[str, Any]: """Transform fixture data to module format. Args: @@ -75,26 +77,28 @@ def _transform_vpc(self, fixture: dict[str, Any]) -> dict[str, Any]: Fixture has: {VpcId, CidrBlock, CidrBlockAssociationSet, Tags, State} """ return { - 'id': fixture['VpcId'], - 'name': self._get_tag_value(fixture, 'Name', default=fixture['VpcId']), - 'cidr': fixture['CidrBlock'], - 'cidrs': [ - assoc['CidrBlock'] - for assoc in fixture.get('CidrBlockAssociationSet', []) + "id": fixture["VpcId"], + "name": self._get_tag_value(fixture, "Name", default=fixture["VpcId"]), + "cidr": fixture["CidrBlock"], + "cidrs": [ + assoc["CidrBlock"] + for assoc in fixture.get("CidrBlockAssociationSet", []) ], - 'region': self._derive_region_from_id(fixture['VpcId']), - 'state': fixture['State'], + "region": self._derive_region_from_id(fixture["VpcId"]), + "state": fixture["State"], } def _transform_tgw(self, fixture: dict[str, Any]) -> dict[str, Any]: """Transform Transit Gateway fixture to module format.""" return { - 'id': fixture['TransitGatewayId'], - 'name': self._get_tag_value(fixture, 'Name', default=fixture['TransitGatewayId']), - 'region': self._derive_region_from_id(fixture['TransitGatewayId']), - 'state': fixture['State'], - 'route_tables': [], - 'attachments': [], + "id": fixture["TransitGatewayId"], + "name": self._get_tag_value( + fixture, "Name", default=fixture["TransitGatewayId"] + ), + "region": self._derive_region_from_id(fixture["TransitGatewayId"]), + "state": fixture["State"], + "route_tables": [], + "attachments": [], } def _transform_cloudwan(self, fixture: dict[str, Any]) -> dict[str, Any]: @@ -103,57 +107,57 @@ def _transform_cloudwan(self, fixture: dict[str, Any]) -> dict[str, Any]: Module expects: {id, name, arn, global_network_id, state, regions, segments} """ return { - 'id': fixture['CoreNetworkId'], - 'name': self._get_tag_value(fixture, 'Name', default=fixture['CoreNetworkId']), - 'arn': fixture['CoreNetworkArn'], - 'global_network_id': fixture['GlobalNetworkId'], - 'state': fixture['State'], - 'regions': [edge['EdgeLocation'] for edge in fixture.get('Edges', [])], - 'segments': fixture.get('Segments', []), - 'route_tables': [], - 'policy': None, - 'core_networks': [], + "id": fixture["CoreNetworkId"], + "name": self._get_tag_value( + fixture, "Name", default=fixture["CoreNetworkId"] + ), + "arn": fixture["CoreNetworkArn"], + "global_network_id": fixture["GlobalNetworkId"], + "state": fixture["State"], + "regions": [edge["EdgeLocation"] for edge in fixture.get("Edges", [])], + "segments": fixture.get("Segments", []), + "route_tables": [], + "policy": None, + "core_networks": [], } def _transform_ec2(self, fixture: dict[str, Any]) -> dict[str, Any]: """Transform EC2 instance fixture to module format.""" return { - 'id': fixture['InstanceId'], - 'name': self._get_tag_value(fixture, 'Name', default=fixture['InstanceId']), - 'type': fixture['InstanceType'], - 'state': fixture['State']['Name'], - 'az': fixture['Placement']['AvailabilityZone'], - 'region': fixture['Placement']['AvailabilityZone'][:-1], # Remove AZ letter - 'vpc_id': fixture['VpcId'], - 'subnet_id': fixture['SubnetId'], - 'private_ip': fixture['PrivateIpAddress'], + "id": fixture["InstanceId"], + "name": self._get_tag_value(fixture, "Name", default=fixture["InstanceId"]), + "type": fixture["InstanceType"], + "state": fixture["State"]["Name"], + "az": fixture["Placement"]["AvailabilityZone"], + "region": fixture["Placement"]["AvailabilityZone"][:-1], # Remove AZ letter + "vpc_id": fixture["VpcId"], + "subnet_id": fixture["SubnetId"], + "private_ip": fixture["PrivateIpAddress"], } def _transform_elb(self, fixture: dict[str, Any]) -> dict[str, Any]: """Transform ELB fixture to module format (per Nova Premier).""" return { - 'arn': fixture['LoadBalancerArn'], - 'name': fixture['LoadBalancerName'], - 'type': fixture['Type'], - 'scheme': fixture['Scheme'], - 'state': fixture['State']['Code'], - 'vpc_id': fixture['VpcId'], - 'dns_name': fixture['DNSName'], - 'region': self._extract_region_from_arn(fixture['LoadBalancerArn']), + "arn": fixture["LoadBalancerArn"], + "name": fixture["LoadBalancerName"], + "type": fixture["Type"], + "scheme": fixture["Scheme"], + "state": fixture["State"]["Code"], + "vpc_id": fixture["VpcId"], + "dns_name": fixture["DNSName"], + "region": self._extract_region_from_arn(fixture["LoadBalancerArn"]), } # ========================================================================= # Helper Functions # ========================================================================= - def _get_tag_value( - self, resource: dict, key: str, default: str = '' - ) -> str: + def _get_tag_value(self, resource: dict, key: str, default: str = "") -> str: """Extract tag value from AWS Tags list.""" - tags = resource.get('Tags', []) + tags = resource.get("Tags", []) for tag in tags: - if tag.get('Key') == key: - return tag.get('Value', default) + if tag.get("Key") == key: + return tag.get("Value", default) return default def _derive_region_from_id(self, resource_id: str) -> str: @@ -165,21 +169,21 @@ def _derive_region_from_id(self, resource_id: str) -> str: NOTE (Nova Premier feedback): Phase 2 should use explicit region field in fixtures instead of this pattern matching approach. """ - if 'prod' in resource_id: - return 'eu-west-1' - elif 'stag' in resource_id: - return 'us-east-1' - elif 'dev' in resource_id: - return 'ap-southeast-2' - return 'us-east-1' # Default + if "prod" in resource_id: + return "eu-west-1" + elif "stag" in resource_id: + return "us-east-1" + elif "dev" in resource_id: + return "ap-southeast-2" + return "us-east-1" # Default def _extract_region_from_arn(self, arn: str) -> str: """Extract region from AWS ARN. ARN format: arn:aws:service:region:account:resource """ - parts = arn.split(':') - return parts[3] if len(parts) > 3 else 'us-east-1' + parts = arn.split(":") + return parts[3] if len(parts) > 3 else "us-east-1" # ========================================================================= # Validation Functions (per Nova Premier) @@ -187,23 +191,32 @@ def _extract_region_from_arn(self, arn: str) -> str: def validate_vpc_id(self, vpc_id: str) -> bool: """Validate VPC ID format: vpc-[0-9a-f]{17}.""" - return bool(re.match(r'^vpc-[0-9a-f]{17}$', vpc_id)) + return bool(re.match(r"^vpc-[0-9a-f]{17}$", vpc_id)) def validate_arn(self, arn: str) -> bool: """Validate AWS ARN format (supports standard, GovCloud, China). Per Nova Premier: Comprehensive ARN validation including all partitions. """ - return bool(re.match( - r'^arn:(aws|aws-us-gov|aws-cn):[a-zA-Z0-9-]+:[a-zA-Z0-9-]*:\d{12}:[a-zA-Z0-9-_/:.]+$', - arn - )) + return bool( + re.match( + r"^arn:(aws|aws-us-gov|aws-cn):[a-zA-Z0-9-]+:[a-zA-Z0-9-]*:\d{12}:[a-zA-Z0-9-_/:.]+$", + arn, + ) + ) def validate_region(self, region: str) -> bool: """Validate region is valid AWS region.""" AWS_REGIONS = [ - 'us-east-1', 'us-east-2', 'us-west-1', 'us-west-2', - 'eu-west-1', 'eu-west-2', 'eu-central-1', - 'ap-southeast-1', 'ap-southeast-2', 'ap-northeast-1', + "us-east-1", + "us-east-2", + "us-west-1", + "us-west-2", + "eu-west-1", + "eu-west-2", + "eu-central-1", + "ap-southeast-1", + "ap-southeast-2", + "ap-northeast-1", ] return region in AWS_REGIONS diff --git a/tests/test_utils/test_context_state_manager.py b/tests/test_utils/test_context_state_manager.py index 753b068..b7666ee 100644 --- a/tests/test_utils/test_context_state_manager.py +++ b/tests/test_utils/test_context_state_manager.py @@ -42,19 +42,19 @@ def manager(self): def test_push_single_context(self, manager): """Binary: Push single context to expected stack.""" - manager.push_context('vpc', {'vpc_id': 'vpc-123'}) + manager.push_context("vpc", {"vpc_id": "vpc-123"}) assert len(manager.expected_stack) == 1 - assert manager.expected_stack[0][0] == 'vpc' + assert manager.expected_stack[0][0] == "vpc" def test_push_multiple_contexts(self, manager): """Binary: Push multiple contexts maintains order.""" - manager.push_context('global-network', {'gn_id': 'gn-123'}) - manager.push_context('core-network', {'cn_id': 'cn-456'}) + manager.push_context("global-network", {"gn_id": "gn-123"}) + manager.push_context("core-network", {"cn_id": "cn-456"}) assert len(manager.expected_stack) == 2 - assert manager.expected_stack[0][0] == 'global-network' - assert manager.expected_stack[1][0] == 'core-network' + assert manager.expected_stack[0][0] == "global-network" + assert manager.expected_stack[1][0] == "core-network" class TestValidateCurrentState: @@ -75,12 +75,13 @@ def test_validate_single_context_match(self, manager): """Binary: Single context validates successfully.""" # Manually set shell context from aws_network_tools.core import Context + manager.shell.context_stack = [ - Context(type='vpc', ref='1', name='production-vpc', data={}) + Context(type="vpc", ref="1", name="production-vpc", data={}) ] # Set expected - manager.push_context('vpc', {}) + manager.push_context("vpc", {}) # Validate manager.validate_current_state() # Should not raise @@ -89,8 +90,9 @@ def test_validate_depth_mismatch(self, manager): """Binary: Depth mismatch raises AssertionError.""" # Shell has context, expected is empty from aws_network_tools.core import Context + manager.shell.context_stack = [ - Context(type='vpc', ref='1', name='test', data={}) + Context(type="vpc", ref="1", name="test", data={}) ] with pytest.raises(AssertionError, match="stack depth"): @@ -99,10 +101,11 @@ def test_validate_depth_mismatch(self, manager): def test_validate_type_mismatch(self, manager): """Binary: Context type mismatch raises AssertionError.""" from aws_network_tools.core import Context + manager.shell.context_stack = [ - Context(type='vpc', ref='1', name='test', data={}) + Context(type="vpc", ref="1", name="test", data={}) ] - manager.push_context('tgw', {}) # Expected tgw, got vpc + manager.push_context("tgw", {}) # Expected tgw, got vpc with pytest.raises(AssertionError, match="[Cc]ontext type"): manager.validate_current_state() @@ -118,12 +121,12 @@ def manager(self): def test_pop_single_context(self, manager): """Binary: Pop removes last context from expected stack.""" - manager.push_context('vpc', {}) - manager.push_context('subnet', {}) + manager.push_context("vpc", {}) + manager.push_context("subnet", {}) manager.pop_context() assert len(manager.expected_stack) == 1 - assert manager.expected_stack[0][0] == 'vpc' + assert manager.expected_stack[0][0] == "vpc" def test_pop_empty_stack_raises(self, manager): """Binary: Pop on empty stack raises error.""" @@ -141,8 +144,8 @@ def manager(self): def test_reset_clears_expected_stack(self, manager): """Binary: Reset clears expected stack.""" - manager.push_context('vpc', {}) - manager.push_context('subnet', {}) + manager.push_context("vpc", {}) + manager.push_context("subnet", {}) manager.reset() assert len(manager.expected_stack) == 0 @@ -162,8 +165,8 @@ def test_get_empty_context_type(self, manager): def test_get_current_context_type(self, manager): """Binary: Returns top of expected stack.""" - manager.push_context('vpc', {}) - assert manager.get_current_context_type() == 'vpc' + manager.push_context("vpc", {}) + assert manager.get_current_context_type() == "vpc" - manager.push_context('subnet', {}) - assert manager.get_current_context_type() == 'subnet' + manager.push_context("subnet", {}) + assert manager.get_current_context_type() == "subnet" diff --git a/tests/test_utils/test_data_format_adapter.py b/tests/test_utils/test_data_format_adapter.py index 0630d2c..6c26f08 100644 --- a/tests/test_utils/test_data_format_adapter.py +++ b/tests/test_utils/test_data_format_adapter.py @@ -10,7 +10,6 @@ """ import pytest -import re from tests.test_utils.data_format_adapter import DataFormatAdapter from tests.fixtures import ( VPC_FIXTURES, @@ -28,14 +27,14 @@ def test_adapter_initialization(self): """Binary: Adapter initializes with format registry.""" adapter = DataFormatAdapter() assert adapter is not None - assert hasattr(adapter, 'transform') + assert hasattr(adapter, "transform") def test_format_registry_exists(self): """Binary: Format registry contains AWS resource types.""" adapter = DataFormatAdapter() - assert hasattr(adapter, 'FORMAT_REGISTRY') - assert 'vpc' in adapter.FORMAT_REGISTRY - assert 'tgw' in adapter.FORMAT_REGISTRY + assert hasattr(adapter, "FORMAT_REGISTRY") + assert "vpc" in adapter.FORMAT_REGISTRY + assert "tgw" in adapter.FORMAT_REGISTRY class TestVPCTransformation: @@ -48,44 +47,44 @@ def adapter(self): def test_vpc_id_mapping(self, adapter): """Binary: VPC ID transforms from VpcId to id.""" vpc_fixture = VPC_FIXTURES["vpc-0prod1234567890ab"] - transformed = adapter.transform('vpc', vpc_fixture) + transformed = adapter.transform("vpc", vpc_fixture) - assert transformed['id'] == vpc_fixture['VpcId'] - assert transformed['id'] == "vpc-0prod1234567890ab" + assert transformed["id"] == vpc_fixture["VpcId"] + assert transformed["id"] == "vpc-0prod1234567890ab" def test_vpc_cidr_extraction(self, adapter): """Binary: VPC CIDR blocks extracted correctly.""" vpc_fixture = VPC_FIXTURES["vpc-0prod1234567890ab"] - transformed = adapter.transform('vpc', vpc_fixture) + transformed = adapter.transform("vpc", vpc_fixture) - assert transformed['cidr'] == vpc_fixture['CidrBlock'] - assert 'cidrs' in transformed - assert len(transformed['cidrs']) >= 1 + assert transformed["cidr"] == vpc_fixture["CidrBlock"] + assert "cidrs" in transformed + assert len(transformed["cidrs"]) >= 1 def test_vpc_name_from_tags(self, adapter): """Binary: VPC name extracted from Tags.""" vpc_fixture = VPC_FIXTURES["vpc-0prod1234567890ab"] - transformed = adapter.transform('vpc', vpc_fixture) + transformed = adapter.transform("vpc", vpc_fixture) expected_name = next( - t['Value'] for t in vpc_fixture['Tags'] if t['Key'] == 'Name' + t["Value"] for t in vpc_fixture["Tags"] if t["Key"] == "Name" ) - assert transformed['name'] == expected_name + assert transformed["name"] == expected_name def test_vpc_region_derived(self, adapter): """Binary: VPC region derived from ID pattern or metadata.""" vpc_fixture = VPC_FIXTURES["vpc-0prod1234567890ab"] - transformed = adapter.transform('vpc', vpc_fixture) + transformed = adapter.transform("vpc", vpc_fixture) - assert 'region' in transformed - assert transformed['region'] in ['eu-west-1', 'us-east-1', 'ap-southeast-2'] + assert "region" in transformed + assert transformed["region"] in ["eu-west-1", "us-east-1", "ap-southeast-2"] def test_vpc_state_preserved(self, adapter): """Binary: VPC state copied from fixture.""" vpc_fixture = VPC_FIXTURES["vpc-0prod1234567890ab"] - transformed = adapter.transform('vpc', vpc_fixture) + transformed = adapter.transform("vpc", vpc_fixture) - assert transformed['state'] == vpc_fixture['State'] + assert transformed["state"] == vpc_fixture["State"] class TestTGWTransformation: @@ -98,12 +97,12 @@ def adapter(self): def test_tgw_basic_fields(self, adapter): """Binary: TGW ID, name, state transformed.""" tgw_fixture = TGW_FIXTURES["tgw-0prod12345678901"] - transformed = adapter.transform('tgw', tgw_fixture) + transformed = adapter.transform("tgw", tgw_fixture) - assert transformed['id'] == tgw_fixture['TransitGatewayId'] - assert 'name' in transformed - assert transformed['state'] == tgw_fixture['State'] - assert 'region' in transformed + assert transformed["id"] == tgw_fixture["TransitGatewayId"] + assert "name" in transformed + assert transformed["state"] == tgw_fixture["State"] + assert "region" in transformed class TestCloudWANTransformation: @@ -116,19 +115,19 @@ def adapter(self): def test_core_network_fields(self, adapter): """Binary: Core network has all required fields.""" cn_fixture = CLOUDWAN_FIXTURES["core-network-0global123"] - transformed = adapter.transform('cloudwan', cn_fixture) + transformed = adapter.transform("cloudwan", cn_fixture) - assert transformed['id'] == cn_fixture['CoreNetworkId'] - assert transformed['global_network_id'] == cn_fixture['GlobalNetworkId'] - assert 'segments' in transformed - assert 'arn' in transformed + assert transformed["id"] == cn_fixture["CoreNetworkId"] + assert transformed["global_network_id"] == cn_fixture["GlobalNetworkId"] + assert "segments" in transformed + assert "arn" in transformed def test_core_network_segments_preserved(self, adapter): """Binary: Segments list preserved in transformation.""" cn_fixture = CLOUDWAN_FIXTURES["core-network-0global123"] - transformed = adapter.transform('cloudwan', cn_fixture) + transformed = adapter.transform("cloudwan", cn_fixture) - assert len(transformed['segments']) == len(cn_fixture['Segments']) + assert len(transformed["segments"]) == len(cn_fixture["Segments"]) class TestAWSValidation: @@ -141,33 +140,38 @@ def adapter(self): def test_vpc_id_format_validation(self, adapter): """Binary: VPC ID matches AWS pattern vpc-XXXXXXXXX.""" vpc_fixture = VPC_FIXTURES["vpc-0prod1234567890ab"] - transformed = adapter.transform('vpc', vpc_fixture) + transformed = adapter.transform("vpc", vpc_fixture) # Fixture IDs use descriptive names (prod/stag/dev), not pure hex # Real AWS validation would be: r'^vpc-[0-9a-f]{17}$' # For fixtures, validate prefix and length - assert transformed['id'].startswith('vpc-') - assert len(transformed['id']) == 21 # vpc- + 17 chars + assert transformed["id"].startswith("vpc-") + assert len(transformed["id"]) == 21 # vpc- + 17 chars def test_arn_format_validation(self, adapter): """Binary: ARNs match AWS pattern (standard, GovCloud, China).""" cn_fixture = CLOUDWAN_FIXTURES["core-network-0global123"] - transformed = adapter.transform('cloudwan', cn_fixture) + transformed = adapter.transform("cloudwan", cn_fixture) - if 'arn' in transformed: + if "arn" in transformed: # Use adapter's validate_arn with comprehensive pattern - assert adapter.validate_arn(transformed['arn']) + assert adapter.validate_arn(transformed["arn"]) def test_region_consistency_validation(self, adapter): """Binary: Region values are valid AWS regions.""" vpc_fixture = VPC_FIXTURES["vpc-0prod1234567890ab"] - transformed = adapter.transform('vpc', vpc_fixture) + transformed = adapter.transform("vpc", vpc_fixture) AWS_REGIONS = [ - 'us-east-1', 'us-west-2', 'eu-west-1', 'ap-southeast-2', - 'us-east-2', 'eu-central-1', 'ap-northeast-1' + "us-east-1", + "us-west-2", + "eu-west-1", + "ap-southeast-2", + "us-east-2", + "eu-central-1", + "ap-northeast-1", ] - assert transformed.get('region') in AWS_REGIONS + assert transformed.get("region") in AWS_REGIONS class TestEC2Transformation: @@ -180,11 +184,11 @@ def adapter(self): def test_ec2_instance_basic_fields(self, adapter): """Binary: EC2 instance has ID, type, state.""" ec2_fixture = list(EC2_INSTANCE_FIXTURES.values())[0] - transformed = adapter.transform('ec2', ec2_fixture) + transformed = adapter.transform("ec2", ec2_fixture) - assert transformed['id'] == ec2_fixture['InstanceId'] - assert transformed['type'] == ec2_fixture['InstanceType'] - assert transformed['state'] == ec2_fixture['State']['Name'] + assert transformed["id"] == ec2_fixture["InstanceId"] + assert transformed["type"] == ec2_fixture["InstanceType"] + assert transformed["state"] == ec2_fixture["State"]["Name"] class TestELBTransformation: @@ -197,18 +201,18 @@ def adapter(self): def test_elb_arn_and_name(self, adapter): """Binary: ELB ARN and name extracted.""" elb_fixture = list(ELB_FIXTURES.values())[0] - transformed = adapter.transform('elb', elb_fixture) + transformed = adapter.transform("elb", elb_fixture) - assert transformed['arn'] == elb_fixture['LoadBalancerArn'] - assert transformed['name'] == elb_fixture['LoadBalancerName'] + assert transformed["arn"] == elb_fixture["LoadBalancerArn"] + assert transformed["name"] == elb_fixture["LoadBalancerName"] def test_elb_type_and_scheme(self, adapter): """Binary: ELB type and scheme preserved.""" elb_fixture = list(ELB_FIXTURES.values())[0] - transformed = adapter.transform('elb', elb_fixture) + transformed = adapter.transform("elb", elb_fixture) - assert transformed['type'] == elb_fixture['Type'] - assert transformed['scheme'] == elb_fixture['Scheme'] + assert transformed["type"] == elb_fixture["Type"] + assert transformed["scheme"] == elb_fixture["Scheme"] class TestTransformBatch: @@ -220,15 +224,15 @@ def adapter(self): def test_transform_all_vpcs(self, adapter): """Binary: All VPCs transform without errors.""" - results = adapter.transform_batch('vpc', list(VPC_FIXTURES.values())) + results = adapter.transform_batch("vpc", list(VPC_FIXTURES.values())) assert len(results) == len(VPC_FIXTURES) - assert all('id' in vpc for vpc in results) - assert all('name' in vpc for vpc in results) + assert all("id" in vpc for vpc in results) + assert all("name" in vpc for vpc in results) def test_transform_all_tgws(self, adapter): """Binary: All TGWs transform without errors.""" - results = adapter.transform_batch('tgw', list(TGW_FIXTURES.values())) + results = adapter.transform_batch("tgw", list(TGW_FIXTURES.values())) assert len(results) == len(TGW_FIXTURES) - assert all('id' in tgw for tgw in results) + assert all("id" in tgw for tgw in results) diff --git a/tests/test_validators.py b/tests/test_validators.py index 37751cd..2bd6e3a 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -1,6 +1,5 @@ """Tests for input validators.""" -import pytest from aws_network_tools.core.validators import ( validate_regions, validate_profile, @@ -20,14 +19,18 @@ def test_valid_single_region(self): def test_valid_multiple_regions_comma_separated(self): """Test multiple regions with comma separation.""" - is_valid, regions, error = validate_regions("us-east-1,eu-west-1,ap-southeast-1") + is_valid, regions, error = validate_regions( + "us-east-1,eu-west-1,ap-southeast-1" + ) assert is_valid assert regions == ["us-east-1", "eu-west-1", "ap-southeast-1"] assert error is None def test_valid_regions_with_spaces_after_commas(self): """Test regions with spaces after commas (should be trimmed).""" - is_valid, regions, error = validate_regions("us-east-1, eu-west-1, ap-southeast-1") + is_valid, regions, error = validate_regions( + "us-east-1, eu-west-1, ap-southeast-1" + ) assert is_valid assert regions == ["us-east-1", "eu-west-1", "ap-southeast-1"] assert error is None @@ -81,8 +84,14 @@ def test_suggestions_for_typos(self): def test_all_major_regions_valid(self): """Test all major AWS regions are recognized.""" major_regions = [ - "us-east-1", "us-west-2", "eu-west-1", "eu-central-1", - "ap-southeast-1", "ap-northeast-1", "ca-central-1", "sa-east-1" + "us-east-1", + "us-west-2", + "eu-west-1", + "eu-central-1", + "ap-southeast-1", + "ap-northeast-1", + "ca-central-1", + "sa-east-1", ] for region in major_regions: is_valid, regions, error = validate_regions(region) diff --git a/tests/test_vpn_tunnels.py b/tests/test_vpn_tunnels.py index aa551bd..c5fc04e 100644 --- a/tests/test_vpn_tunnels.py +++ b/tests/test_vpn_tunnels.py @@ -148,9 +148,9 @@ def test_show_detail_includes_tunnels(): total = len(results) passed = sum(results) - print(f"\n{'='*50}") + print(f"\n{'=' * 50}") print(f"Results: {passed}/{total} tests passed") print(f"BINARY: {'✅ PASS' if all_passed else '❌ FAIL'}") - print(f"{'='*50}") + print(f"{'=' * 50}") sys.exit(0 if all_passed else 1) diff --git a/tests/unit/test_ec2_eni_filtering.py b/tests/unit/test_ec2_eni_filtering.py index 731b314..49dd850 100644 --- a/tests/unit/test_ec2_eni_filtering.py +++ b/tests/unit/test_ec2_eni_filtering.py @@ -4,8 +4,7 @@ Following TDD: This test FAILS initially, proving bug exists. """ -import pytest -from unittest.mock import Mock, patch, MagicMock +from unittest.mock import patch, MagicMock from aws_network_tools.modules.ec2 import EC2Client from tests.fixtures.ec2 import EC2_INSTANCE_FIXTURES, ENI_FIXTURES @@ -18,7 +17,7 @@ def test_get_instance_detail_filters_enis(self): instance_id = "i-0prodweb1a123456789" region = "eu-west-1" - with patch('boto3.Session') as mock_session_class: + with patch("boto3.Session") as mock_session_class: mock_session = MagicMock() mock_client = MagicMock() mock_session.client.return_value = mock_client @@ -26,23 +25,25 @@ def test_get_instance_detail_filters_enis(self): # Mock describe_instances mock_client.describe_instances.return_value = { - 'Reservations': [{'Instances': [EC2_INSTANCE_FIXTURES[instance_id]]}] + "Reservations": [{"Instances": [EC2_INSTANCE_FIXTURES[instance_id]]}] } # Mock describe_network_interfaces with filter matching def mock_describe_enis(Filters=None, **kwargs): if Filters: for f in Filters: - if f['Name'] == 'attachment.instance-id': - target = f['Values'][0] + if f["Name"] == "attachment.instance-id": + target = f["Values"][0] return { - 'NetworkInterfaces': [ - eni for eni in ENI_FIXTURES.values() - if eni.get('Attachment', {}).get('InstanceId') == target + "NetworkInterfaces": [ + eni + for eni in ENI_FIXTURES.values() + if eni.get("Attachment", {}).get("InstanceId") + == target ] } # Bug: Returns all ENIs - return {'NetworkInterfaces': list(ENI_FIXTURES.values())} + return {"NetworkInterfaces": list(ENI_FIXTURES.values())} mock_client.describe_network_interfaces = mock_describe_enis @@ -51,9 +52,9 @@ def mock_describe_enis(Filters=None, **kwargs): result = client.get_instance_detail(instance_id, region) # Binary assertion: Should be ≤2 ENIs - enis = result.get('enis', []) + enis = result.get("enis", []) assert len(enis) <= 2, f"Expected 1-2 ENIs, got {len(enis)}" # Verify correct ENI if enis: - assert enis[0]['id'] == 'eni-0prodweb1a1234567' + assert enis[0]["id"] == "eni-0prodweb1a1234567" diff --git a/tests/unit/test_elb_commands.py b/tests/unit/test_elb_commands.py index 3f90f8d..43f2046 100644 --- a/tests/unit/test_elb_commands.py +++ b/tests/unit/test_elb_commands.py @@ -8,11 +8,10 @@ Following TDD: These tests FAIL initially, proving commands don't exist. """ -import pytest -from unittest.mock import Mock, patch, MagicMock +from unittest.mock import patch, MagicMock from aws_network_tools.shell.handlers.elb import ELBHandlersMixin from aws_network_tools.modules.elb import ELBClient -from tests.fixtures.elb import ELB_FIXTURES, LISTENER_FIXTURES, TARGET_GROUP_FIXTURES +from tests.fixtures.elb import ELB_FIXTURES, LISTENER_FIXTURES class TestELBHandlerMethods: @@ -21,17 +20,17 @@ class TestELBHandlerMethods: def test_show_listeners_method_exists(self): """Binary: _show_listeners handler method must exist.""" mixin = ELBHandlersMixin() - assert hasattr(mixin, '_show_listeners'), "Missing _show_listeners handler" + assert hasattr(mixin, "_show_listeners"), "Missing _show_listeners handler" def test_show_targets_method_exists(self): """Binary: _show_targets handler method must exist.""" mixin = ELBHandlersMixin() - assert hasattr(mixin, '_show_targets'), "Missing _show_targets handler" + assert hasattr(mixin, "_show_targets"), "Missing _show_targets handler" def test_show_health_method_exists(self): """Binary: _show_health handler method must exist.""" mixin = ELBHandlersMixin() - assert hasattr(mixin, '_show_health'), "Missing _show_health handler" + assert hasattr(mixin, "_show_health"), "Missing _show_health handler" class TestELBClientMethods: @@ -40,17 +39,17 @@ class TestELBClientMethods: def test_get_listeners_method_exists(self): """Binary: get_listeners client method must exist.""" client = ELBClient() - assert hasattr(client, 'get_listeners'), "Missing get_listeners method" + assert hasattr(client, "get_listeners"), "Missing get_listeners method" def test_get_target_groups_method_exists(self): """Binary: get_target_groups client method must exist.""" client = ELBClient() - assert hasattr(client, 'get_target_groups'), "Missing get_target_groups method" + assert hasattr(client, "get_target_groups"), "Missing get_target_groups method" def test_get_target_health_method_exists(self): """Binary: get_target_health client method must exist.""" client = ELBClient() - assert hasattr(client, 'get_target_health'), "Missing get_target_health method" + assert hasattr(client, "get_target_health"), "Missing get_target_health method" class TestELBClientFunctionality: @@ -60,7 +59,7 @@ def test_get_listeners_returns_list(self): """Binary: get_listeners returns list of listener dicts.""" elb_arn = list(ELB_FIXTURES.keys())[0] - with patch('boto3.Session') as mock_session_class: + with patch("boto3.Session") as mock_session_class: mock_session = MagicMock() mock_client = MagicMock() mock_session.client.return_value = mock_client @@ -68,11 +67,11 @@ def test_get_listeners_returns_list(self): # Mock AWS API response mock_client.describe_listeners.return_value = { - 'Listeners': list(LISTENER_FIXTURES.values()) + "Listeners": list(LISTENER_FIXTURES.values()) } client = ELBClient(session=mock_session) - result = client.get_listeners(elb_arn, 'us-east-1') + result = client.get_listeners(elb_arn, "us-east-1") # Binary assertions assert isinstance(result, list), "get_listeners must return list" @@ -89,9 +88,9 @@ def test_context_commands_includes_elb_commands(self): module = ELBModule() ctx_commands = module.context_commands - assert 'elb' in ctx_commands, "Must define elb context commands" + assert "elb" in ctx_commands, "Must define elb context commands" - elb_commands = ctx_commands['elb'] - assert 'listeners' in elb_commands, "Missing 'listeners' command" - assert 'targets' in elb_commands, "Missing 'targets' command" - assert 'health' in elb_commands, "Missing 'health' command" + elb_commands = ctx_commands["elb"] + assert "listeners" in elb_commands, "Missing 'listeners' command" + assert "targets" in elb_commands, "Missing 'targets' command" + assert "health" in elb_commands, "Missing 'health' command" diff --git a/themes/catppuccin-latte.json b/themes/catppuccin-latte.json index 7d7afb5..1211ce2 100644 --- a/themes/catppuccin-latte.json +++ b/themes/catppuccin-latte.json @@ -1,20 +1,20 @@ { - "name": "catppuccin-latte", - "description": "Catppuccin Latte - Light theme with soft, muted colors", "author": "Catppuccin Theme", - "upstream": "https://github.com/catppuccin/catppuccin", "colors": { - "root": "#4c4f69", - "global-network": "#8839ef", "core-network": "#ea76cb", - "route-table": "#04a5e5", - "vpc": "#40a02b", - "transit-gateway": "#fe640b", - "firewall": "#d20f39", - "elb": "#df8e1d", - "vpn": "#6c6f85", "ec2-instance": "#8839ef", + "elb": "#df8e1d", + "firewall": "#d20f39", + "global-network": "#8839ef", "prompt_separator": "#6c6f85", - "prompt_text": "#4c4f69" - } -} \ No newline at end of file + "prompt_text": "#4c4f69", + "root": "#4c4f69", + "route-table": "#04a5e5", + "transit-gateway": "#fe640b", + "vpc": "#40a02b", + "vpn": "#6c6f85" + }, + "description": "Catppuccin Latte - Light theme with soft, muted colors", + "name": "catppuccin-latte", + "upstream": "https://github.com/catppuccin/catppuccin" +} diff --git a/themes/catppuccin-macchiato.json b/themes/catppuccin-macchiato.json index 0744e44..f69f5ca 100644 --- a/themes/catppuccin-macchiato.json +++ b/themes/catppuccin-macchiato.json @@ -1,20 +1,20 @@ { - "name": "catppuccin-macchiato", - "description": "Catppuccin Macchiato - Dark theme with vibrant, medium contrast", "author": "Catppuccin Theme", - "upstream": "https://github.com/catppuccin/catppuccin", "colors": { - "root": "#cad3f5", - "global-network": "#c6a0f6", "core-network": "#f5bde6", - "route-table": "#7dc4e4", - "vpc": "#a6da95", - "transit-gateway": "#f5a97f", - "firewall": "#ed8796", - "elb": "#eed49f", - "vpn": "#939ab7", "ec2-instance": "#c6a0f6", + "elb": "#eed49f", + "firewall": "#ed8796", + "global-network": "#c6a0f6", "prompt_separator": "#939ab7", - "prompt_text": "#cad3f5" - } -} \ No newline at end of file + "prompt_text": "#cad3f5", + "root": "#cad3f5", + "route-table": "#7dc4e4", + "transit-gateway": "#f5a97f", + "vpc": "#a6da95", + "vpn": "#939ab7" + }, + "description": "Catppuccin Macchiato - Dark theme with vibrant, medium contrast", + "name": "catppuccin-macchiato", + "upstream": "https://github.com/catppuccin/catppuccin" +} diff --git a/themes/catppuccin-mocha-vibrant.json b/themes/catppuccin-mocha-vibrant.json index 2c2e02b..cb560e0 100644 --- a/themes/catppuccin-mocha-vibrant.json +++ b/themes/catppuccin-mocha-vibrant.json @@ -1,23 +1,23 @@ { - "name": "catppuccin-mocha-vibrant", - "description": "Catppuccin Mocha with more vibrant, saturated colors for dark terminals", "author": "aws-network-shell", "colors": { - "root": "#89b4fa", - "global-network": "#cba6f7", "core-network": "#f5c2e7", - "route-table": "#94e2d5", - "vpc": "#a6e3a1", - "transit-gateway": "#fab387", - "firewall": "#f38ba8", - "elb": "#f9e2af", - "vpn": "#b4befe", "ec2-instance": "#cba6f7", + "elb": "#f9e2af", + "firewall": "#f38ba8", + "global-network": "#cba6f7", "prompt_separator": "#9399b2", "prompt_text": "#89b4fa", - "table_header": "#cba6f7", + "root": "#89b4fa", + "route-table": "#94e2d5", "table_border": "#6c7086", + "table_header": "#cba6f7", + "table_row_even": "#313244", "table_row_odd": "#45475a", - "table_row_even": "#313244" - } + "transit-gateway": "#fab387", + "vpc": "#a6e3a1", + "vpn": "#b4befe" + }, + "description": "Catppuccin Mocha with more vibrant, saturated colors for dark terminals", + "name": "catppuccin-mocha-vibrant" } diff --git a/themes/catppuccin-mocha.json b/themes/catppuccin-mocha.json index 03ea718..eef216a 100644 --- a/themes/catppuccin-mocha.json +++ b/themes/catppuccin-mocha.json @@ -1,20 +1,20 @@ { - "name": "catppuccin-mocha", - "description": "Catppuccin Mocha - Dark theme with vibrant colors for prompts", "author": "Catppuccin Theme", - "upstream": "https://github.com/catppuccin/catppuccin", "colors": { - "root": "#89b4fa", - "global-network": "#cba6f7", "core-network": "#f5c2e7", - "route-table": "#94e2d5", - "vpc": "#a6e3a1", - "transit-gateway": "#fab387", - "firewall": "#f38ba8", - "elb": "#f9e2af", - "vpn": "#b4befe", "ec2-instance": "#cba6f7", + "elb": "#f9e2af", + "firewall": "#f38ba8", + "global-network": "#cba6f7", "prompt_separator": "#9399b2", - "prompt_text": "#89b4fa" - } -} \ No newline at end of file + "prompt_text": "#89b4fa", + "root": "#89b4fa", + "route-table": "#94e2d5", + "transit-gateway": "#fab387", + "vpc": "#a6e3a1", + "vpn": "#b4befe" + }, + "description": "Catppuccin Mocha - Dark theme with vibrant colors for prompts", + "name": "catppuccin-mocha", + "upstream": "https://github.com/catppuccin/catppuccin" +} diff --git a/themes/dracula.json b/themes/dracula.json index 052967d..4142a56 100644 --- a/themes/dracula.json +++ b/themes/dracula.json @@ -1,20 +1,20 @@ { - "name": "dracula", - "description": "Dracula theme for AWS Network Shell", "author": "Dracula Theme", - "upstream": "https://draculatheme.com/", "colors": { - "root": "#f8f8f2", - "global-network": "#bd93f9", "core-network": "#ff79c6", - "route-table": "#8be9fd", - "vpc": "#50fa7b", - "transit-gateway": "#ffb86c", - "firewall": "#ff5555", - "elb": "#f1fa8c", - "vpn": "#6272a4", "ec2-instance": "#bd93f9", + "elb": "#f1fa8c", + "firewall": "#ff5555", + "global-network": "#bd93f9", "prompt_separator": "#6272a4", - "prompt_text": "#f8f8f2" - } -} \ No newline at end of file + "prompt_text": "#f8f8f2", + "root": "#f8f8f2", + "route-table": "#8be9fd", + "transit-gateway": "#ffb86c", + "vpc": "#50fa7b", + "vpn": "#6272a4" + }, + "description": "Dracula theme for AWS Network Shell", + "name": "dracula", + "upstream": "https://draculatheme.com/" +} diff --git a/validate_issues.sh b/validate_issues.sh index 8fc82c9..2ef13d3 100755 --- a/validate_issues.sh +++ b/validate_issues.sh @@ -57,4 +57,4 @@ with mock_aws(): " echo "" -echo "Summary: Issues that still need fixing are marked with ✗" \ No newline at end of file +echo "Summary: Issues that still need fixing are marked with ✗" From 24a9c2593c04c57940e1f8d05a545268acdabfb6 Mon Sep 17 00:00:00 2001 From: d-padmanabhan <88682350+d-padmanabhan@users.noreply.github.com> Date: Mon, 5 Jan 2026 08:41:41 -0500 Subject: [PATCH 2/4] chore(repo): add repository scaffolding and configuration files - Add .gitattributes for line ending normalization and binary handling - Add .editorconfig for consistent formatting across editors - Add .markdownlint.yaml for markdown linting rules - Add .commitlintrc.yaml for conventional commit enforcement - Add .github/dependabot.yml for automated dependency updates - Add .github/CODEOWNERS for code ownership assignments - Add .github/PULL_REQUEST_TEMPLATE.md for standardized PR descriptions - Update .pre-commit-config.yaml with named hooks, shellcheck, markdownlint - Update .gitignore with pyright cache and coverage file patterns - Auto-fix markdown files (trailing whitespace, newlines) --- .commitlintrc.yaml | 130 ++++++++++++++++++ .editorconfig | 78 +++++++++++ .gitattributes | 84 ++++++++++++ .github/CODEOWNERS | 53 ++++++++ .github/PULL_REQUEST_TEMPLATE.md | 77 +++++++++++ .github/dependabot.yml | 68 ++++++++++ .gitignore | 11 ++ .markdownlint.yaml | 163 ++++++++++++++++++++++ .pre-commit-config.yaml | 211 ++++++++++++++++++++--------- README.md | 28 +++- docs/ARCHITECTURE.md | 55 +++++++- docs/README.md | 8 +- docs/command-hierarchy-graph.md | 103 ++++++++------ docs/command-hierarchy-split.md | 3 + scripts/AUTOMATION_README.md | 7 + scripts/README.md | 5 + tests/README.md | 12 +- tests/fixtures/README.md | 26 +++- tests/test_command_graph/README.md | 8 ++ 19 files changed, 1014 insertions(+), 116 deletions(-) create mode 100644 .commitlintrc.yaml create mode 100644 .editorconfig create mode 100644 .gitattributes create mode 100644 .github/CODEOWNERS create mode 100644 .github/PULL_REQUEST_TEMPLATE.md create mode 100644 .github/dependabot.yml create mode 100644 .markdownlint.yaml diff --git a/.commitlintrc.yaml b/.commitlintrc.yaml new file mode 100644 index 0000000..d199191 --- /dev/null +++ b/.commitlintrc.yaml @@ -0,0 +1,130 @@ +--- +# Commitlint configuration +# https://commitlint.js.org/ +# +# Enforces Conventional Commits format: +# (): +# +# Examples: +# feat(vpc): add subnet discovery command +# fix(shell): correct context prompt rendering +# docs(readme): update installation instructions +# refactor(core): extract base handler class +# test(vpn): add tunnel status tests + +extends: + - "@commitlint/config-conventional" + +rules: + # Type must be one of the allowed values + type-enum: + - 2 + - always + - - feat # New feature + - fix # Bug fix + - docs # Documentation changes + - style # Code style (formatting, semicolons, etc.) + - refactor # Code refactoring (no feature or fix) + - perf # Performance improvement + - test # Add or update tests + - build # Build system or dependencies + - ci # CI/CD configuration + - chore # Maintenance tasks + - revert # Revert a commit + + # Type must be lowercase + type-case: + - 2 + - always + - lower-case + + # Type is required + type-empty: + - 2 + - never + + # Scope should be lowercase + scope-case: + - 2 + - always + - lower-case + + # Scope is optional but encouraged + scope-empty: + - 0 + - never + + # Common scopes for this project + scope-enum: + - 1 + - always + - - vpc # VPC-related commands + - tgw # Transit Gateway commands + - firewall # Network Firewall commands + - vpn # VPN commands + - elb # Load balancer commands + - ec2 # EC2 instance commands + - cloudwan # Cloud WAN commands + - shell # Shell framework + - core # Core utilities + - cli # CLI interface + - handlers # Command handlers + - models # Data models + - modules # AWS service modules + - traceroute # Traceroute functionality + - cache # Caching system + - graph # Command graph + - config # Configuration + - tests # Test infrastructure + - docs # Documentation + - deps # Dependencies + - ci # CI/CD + + # Subject must not be empty + subject-empty: + - 2 + - never + + # Subject should be lowercase + subject-case: + - 2 + - always + - - lower-case + - sentence-case + + # Subject should not end with period + subject-full-stop: + - 2 + - never + - "." + + # Header max length (type + scope + subject) + header-max-length: + - 2 + - always + - 72 + + # Body max line length + body-max-line-length: + - 2 + - always + - 100 + + # Footer max line length + footer-max-line-length: + - 2 + - always + - 100 + + # Body should be separated from subject by blank line + body-leading-blank: + - 2 + - always + + # Footer should be separated from body by blank line + footer-leading-blank: + - 2 + - always + +# Help message displayed on failure +helpUrl: "https://www.conventionalcommits.org/" diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..04d30ce --- /dev/null +++ b/.editorconfig @@ -0,0 +1,78 @@ +# EditorConfig - https://editorconfig.org +# Helps maintain consistent coding styles across different editors + +root = true + +# Default settings for all files +[*] +charset = utf-8 +end_of_line = lf +insert_final_newline = true +trim_trailing_whitespace = true +indent_style = space +indent_size = 4 + +# Python files +[*.py] +indent_size = 4 +max_line_length = 120 + +# Python type stubs +[*.pyi] +indent_size = 4 +max_line_length = 120 + +# YAML files +[*.{yaml,yml}] +indent_size = 2 + +# JSON files +[*.json] +indent_size = 2 + +# TOML files +[*.toml] +indent_size = 4 + +# Markdown files +[*.md] +trim_trailing_whitespace = false +max_line_length = 120 + +# Shell scripts +[*.{sh,bash,zsh}] +indent_size = 2 +shell_variant = bash + +# Makefile (requires tabs) +[Makefile] +indent_style = tab +indent_size = 4 + +[makefile] +indent_style = tab +indent_size = 4 + +# Git configuration +[.git*] +indent_size = 4 + +# GitHub Actions +[.github/workflows/*.{yaml,yml}] +indent_size = 2 + +# Pre-commit config +[.pre-commit-config.yaml] +indent_size = 2 + +# Documentation +[docs/**] +indent_size = 2 + +# Test files +[tests/**/*.py] +indent_size = 4 + +# Schema files +[schemas/*.json] +indent_size = 2 diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..ec88afe --- /dev/null +++ b/.gitattributes @@ -0,0 +1,84 @@ +# Auto detect text files and perform LF normalization +* text=auto + +# Source code - always LF +*.py text eol=lf diff=python +*.pyx text eol=lf diff=python +*.pyi text eol=lf diff=python + +# Shell scripts - always LF +*.sh text eol=lf diff=bash +*.bash text eol=lf diff=bash +*.zsh text eol=lf + +# Configuration files +*.json text eol=lf diff=json +*.yaml text eol=lf +*.yml text eol=lf +*.toml text eol=lf +*.cfg text eol=lf +*.ini text eol=lf +*.conf text eol=lf + +# Documentation +*.md text eol=lf diff=markdown +*.txt text eol=lf +*.rst text eol=lf +LICENSE text eol=lf +CHANGELOG text eol=lf +AUTHORS text eol=lf + +# Web assets +*.html text eol=lf diff=html +*.css text eol=lf +*.js text eol=lf +*.ts text eol=lf + +# Data files +*.csv text eol=lf +*.xml text eol=lf + +# Binary files - do not modify +*.png binary +*.jpg binary +*.jpeg binary +*.gif binary +*.ico binary +*.svg binary +*.woff binary +*.woff2 binary +*.ttf binary +*.eot binary +*.pdf binary +*.zip binary +*.tar.gz binary +*.tgz binary +*.gz binary + +# Database files +*.db binary +*.sqlite binary +*.sqlite3 binary + +# Generated files - do not diff +uv.lock -diff linguist-generated=true +*.lock -diff linguist-generated=true + +# Secrets baseline - always LF +.secrets.baseline text eol=lf + +# Exclude from language statistics +docs/* linguist-documentation +tests/* linguist-vendored=false +schemas/*.json linguist-generated=true + +# Export-ignore (not included in archives) +.github export-ignore +.gitattributes export-ignore +.gitignore export-ignore +.pre-commit-config.yaml export-ignore +.editorconfig export-ignore +.markdownlint.yaml export-ignore +.commitlintrc.yaml export-ignore +.secrets.baseline export-ignore +tests/ export-ignore diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000..6e19b93 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,53 @@ +# CODEOWNERS file +# https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners +# +# Each line is a file pattern followed by one or more owners. +# Order matters: later patterns override earlier ones. +# Owners are notified when matching files are changed. + +# Default owner for everything in the repository +* @dpadmanabhan + +# Core framework and shell implementation +/src/aws_network_tools/core/ @dpadmanabhan +/src/aws_network_tools/shell/ @dpadmanabhan + +# AWS service modules +/src/aws_network_tools/modules/ @dpadmanabhan + +# Data models +/src/aws_network_tools/models/ @dpadmanabhan + +# Traceroute functionality +/src/aws_network_tools/traceroute/ @dpadmanabhan + +# CLI interfaces +/src/aws_network_tools/cli/ @dpadmanabhan +/src/aws_network_tools/cli.py @dpadmanabhan + +# Configuration +/src/aws_network_tools/config/ @dpadmanabhan + +# Tests +/tests/ @dpadmanabhan + +# Documentation +/docs/ @dpadmanabhan +*.md @dpadmanabhan + +# Build and packaging +pyproject.toml @dpadmanabhan +uv.lock @dpadmanabhan + +# CI/CD and automation +/.github/ @dpadmanabhan +/.pre-commit-config.yaml @dpadmanabhan + +# Scripts and utilities +/scripts/ @dpadmanabhan + +# Schemas +/schemas/ @dpadmanabhan + +# Themes +/themes/ @dpadmanabhan diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..dfd1331 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,77 @@ +# Pull Request + +## Summary + + + +## Type of Change + + + +- [ ] 🐛 Bug fix (non-breaking change that fixes an issue) +- [ ] ✨ New feature (non-breaking change that adds functionality) +- [ ] 💥 Breaking change (fix or feature that would cause existing functionality to not work as expected) +- [ ] 📚 Documentation update +- [ ] 🔧 Configuration change +- [ ] ♻️ Refactoring (no functional changes) +- [ ] 🧪 Test update +- [ ] 🏗️ Build/CI change + +## Changes + + + +- +- +- + +## Related Issues + + + +## Testing + + + +- [ ] Unit tests pass (`pytest tests/`) +- [ ] Integration tests pass (if applicable) +- [ ] Manual testing performed +- [ ] Pre-commit hooks pass (`pre-commit run --all-files`) + +### Test Commands Run + +```bash +# Example: +pytest tests/ -v +pre-commit run --all-files +``` + +## Screenshots/Recordings + + + +## Checklist + + + +- [ ] My code follows the project's coding standards +- [ ] I have performed a self-review of my code +- [ ] I have commented my code, particularly in hard-to-understand areas +- [ ] I have updated the documentation (if applicable) +- [ ] My changes generate no new warnings +- [ ] I have added tests that prove my fix is effective or my feature works +- [ ] New and existing unit tests pass locally with my changes +- [ ] Any dependent changes have been merged and published + +## Security Considerations + + + +- [ ] No hardcoded secrets or credentials +- [ ] No sensitive data logged or exposed +- [ ] Input validation added where appropriate +- [ ] AWS credentials handled securely + +## Additional Notes + + diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..f66bd2f --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,68 @@ +--- +# Dependabot configuration +# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file + +version: 2 + +updates: + # Python dependencies (pip/uv) + - package-ecosystem: pip + directory: "/" + schedule: + interval: weekly + day: monday + time: "09:00" + timezone: America/New_York + open-pull-requests-limit: 10 + commit-message: + prefix: "chore(deps)" + labels: + - dependencies + - python + reviewers: + - dpadmanabhan + groups: + # Group minor and patch updates together + python-minor: + update-types: + - minor + - patch + patterns: + - "*" + exclude-patterns: + - boto3 + - botocore + # Keep AWS SDK updates separate + aws-sdk: + patterns: + - boto3 + - botocore + - aiobotocore + + # GitHub Actions + - package-ecosystem: github-actions + directory: "/" + schedule: + interval: weekly + day: monday + time: "09:00" + timezone: America/New_York + open-pull-requests-limit: 5 + commit-message: + prefix: "ci(deps)" + labels: + - dependencies + - github-actions + groups: + # Group all GHA updates together + actions: + patterns: + - "*" + +# Private registries can be configured here if needed: +# registries: +# python-private: +# type: python-index +# url: https://pypi.example.com/simple +# username: ${{ secrets.PYPI_USERNAME }} +# password: ${{ secrets.PYPI_PASSWORD }} diff --git a/.gitignore b/.gitignore index a80a4cd..c1c2ab9 100644 --- a/.gitignore +++ b/.gitignore @@ -19,17 +19,21 @@ wheels/ *.egg-info/ .installed.cfg *.egg +MANIFEST # Virtual environments .venv/ venv/ ENV/ env/ +.uv/ # Environment files .env .env.local .env.*.local +.env.development +.env.production # IDE .idea/ @@ -44,15 +48,22 @@ env/ # Testing .pytest_cache/ .coverage +.coverage.* htmlcov/ .tox/ .nox/ +coverage.xml +*.cover +*.py,cover # Type checking / Linting .mypy_cache/ .ruff_cache/ .dmypy.json dmypy.json +.pyright/ +.pytype/ +pyrightconfig.json # Logs *.log diff --git a/.markdownlint.yaml b/.markdownlint.yaml new file mode 100644 index 0000000..52d619e --- /dev/null +++ b/.markdownlint.yaml @@ -0,0 +1,163 @@ +--- +# Markdownlint configuration +# https://github.com/DavidAnson/markdownlint/blob/main/doc/Rules.md + +# Default: all rules enabled +default: true + +# MD001 - Heading levels should only increment by one level at a time +MD001: true + +# MD003 - Heading style (atx style: # Heading) +MD003: + style: atx + +# MD004 - Unordered list style (dash) +MD004: + style: dash + +# MD007 - Unordered list indentation (2 spaces) +MD007: + indent: 2 + start_indented: false + +# MD009 - No trailing spaces (except in code blocks for line breaks) +MD009: + br_spaces: 2 + strict: false + +# MD010 - No hard tabs +MD010: + code_blocks: false + +# MD012 - No multiple consecutive blank lines +MD012: + maximum: 1 + +# MD013 - Line length (disabled - too strict for documentation) +MD013: false + +# MD022 - Headings should be surrounded by blank lines +MD022: + lines_above: 1 + lines_below: 1 + +# MD024 - No duplicate headings (disabled - common in large docs) +MD024: false + +# MD025 - Single H1 per document +MD025: true + +# MD026 - No trailing punctuation in headings +MD026: + punctuation: ".,;:!。,;:!" + +# MD029 - Ordered list item prefix +# Disabled: existing files use sequential numbering (1, 2, 3) +MD029: false + +# MD030 - Spaces after list markers +MD030: + ul_single: 1 + ol_single: 1 + ul_multi: 1 + ol_multi: 1 + +# MD031 - Fenced code blocks should be surrounded by blank lines +MD031: + list_items: true + +# MD032 - Lists should be surrounded by blank lines +MD032: true + +# MD033 - Allow inline HTML (common in READMEs) +MD033: + allowed_elements: + - br + - details + - summary + - kbd + - sup + - sub + - img + - a + - p + - div + - span + - table + - thead + - tbody + - tr + - th + - td + +# MD034 - No bare URLs +MD034: true + +# MD035 - Horizontal rule style (consistent) +MD035: + style: "---" + +# MD036 - No emphasis used instead of heading (disabled - common in READMEs) +MD036: false + +# MD037 - No spaces inside emphasis markers +MD037: true + +# MD038 - No spaces inside code span elements +MD038: true + +# MD039 - No spaces inside link text +MD039: true + +# MD040 - Fenced code blocks should have a language specified +# Disabled: many existing blocks without language +MD040: false + +# MD041 - First line should be a top-level heading +MD041: + level: 1 + front_matter_title: "^\\s*title\\s*[:=]" + +# MD044 - Proper names should have correct capitalization +MD044: + names: + - AWS + - Python + - GitHub + - CLI + - API + code_blocks: false + +# MD045 - Images should have alternate text +MD045: true + +# MD046 - Code block style (fenced) +MD046: + style: fenced + +# MD047 - Files should end with a single newline character +MD047: true + +# MD048 - Code fence style (backtick) +MD048: + style: backtick + +# MD049 - Emphasis style (asterisk) +MD049: + style: asterisk + +# MD050 - Strong style (asterisk) +MD050: + style: asterisk + +# MD051 - Link fragments should be valid +MD051: true + +# MD052 - Reference links and images should use a label that is defined +MD052: true + +# MD053 - Link and image reference definitions should be needed +MD053: + ignored_definitions: + - "//" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c59d694..d6e4602 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,62 +1,151 @@ +--- +# Pre-commit hooks configuration +# Run: pre-commit run --all-files +# Update: pre-commit autoupdate + repos: -- repo: https://github.com/pre-commit/pre-commit-hooks - rev: v6.0.0 - hooks: - - id: check-added-large-files - - id: check-case-conflict - - id: check-executables-have-shebangs - - id: check-illegal-windows-names - - id: check-json - - id: check-merge-conflict - - id: check-shebang-scripts-are-executable - - id: check-symlinks - - id: check-toml - - id: check-xml - - id: check-yaml - - id: end-of-file-fixer - - id: debug-statements - - id: destroyed-symlinks - - id: detect-private-key - - id: detect-aws-credentials - args: - - --allow-missing-credentials - - id: forbid-submodules - - id: pretty-format-json - exclude: ^docusaurus\/package-lock.json$ - args: - - --autofix - - id: trailing-whitespace -- repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.14.7 - hooks: - - id: ruff - args: - - --fix - - id: ruff-format -- repo: https://github.com/Yelp/detect-secrets - rev: v1.5.0 - hooks: - - id: detect-secrets - args: - - --baseline - - .secrets.baseline -- repo: local - hooks: - - id: pyright - name: pyright - pass_filenames: true - language: system - entry: bash -c 'for x in "$@"; do (cd `dirname $x`; pwd; uv run --frozen --all-extras - --dev pyright --stats;); done;' -- - stages: - - pre-push - files: (src|samples)\/.*\/pyproject.toml - - id: pytest - name: pytest - pass_filenames: true - language: system - entry: bash -c 'for x in "$@"; do (cd `dirname $x`; pwd; uv run --frozen pytest - --cov --cov-branch --cov-report=term-missing;); done;' -- - stages: - - pre-push - files: src\/.*\/pyproject.toml + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v6.0.0 + hooks: + - id: check-added-large-files + name: Check for large files + args: ['--maxkb=1000'] + - id: check-case-conflict + name: Check for case conflicts + - id: check-executables-have-shebangs + name: Check executables have shebangs + - id: check-illegal-windows-names + name: Check for illegal Windows filenames + - id: check-json + name: Check JSON syntax + - id: check-merge-conflict + name: Check for merge conflicts + - id: check-shebang-scripts-are-executable + name: Check shebang scripts are executable + - id: check-symlinks + name: Check for broken symlinks + - id: check-toml + name: Check TOML syntax + - id: check-xml + name: Check XML syntax + - id: check-yaml + name: Check YAML syntax + - id: end-of-file-fixer + name: Fix end of file + - id: debug-statements + name: Check for debug statements + - id: destroyed-symlinks + name: Check for destroyed symlinks + - id: detect-private-key + name: Detect private keys + - id: detect-aws-credentials + name: Detect AWS credentials + args: + - --allow-missing-credentials + - id: forbid-submodules + name: Forbid submodules + - id: pretty-format-json + name: Pretty format JSON + exclude: ^docusaurus\/package-lock.json$ + args: + - --autofix + - id: trailing-whitespace + name: Trim trailing whitespace + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.14.0 + hooks: + - id: ruff + name: Ruff linter + args: + - --fix + - id: ruff-format + name: Ruff formatter + + - repo: https://github.com/Yelp/detect-secrets + rev: v1.5.0 + hooks: + - id: detect-secrets + name: Detect secrets + args: + - --baseline + - .secrets.baseline + + - repo: local + hooks: + - id: shellcheck + name: ShellCheck + language: system + entry: shellcheck + args: + - --severity=warning + - --shell=bash + types: [shell] + files: \.(sh|bash)$ + + - repo: https://github.com/igorshubovych/markdownlint-cli + rev: v0.43.0 + hooks: + - id: markdownlint + name: Markdown lint + args: + - --config + - .markdownlint.yaml + - --fix + + - repo: https://github.com/alessandrojcm/commitlint-pre-commit-hook + rev: v9.21.0 + hooks: + - id: commitlint + name: Commit message lint + stages: [commit-msg] + additional_dependencies: ['@commitlint/config-conventional'] + + - repo: local + hooks: + - id: pyright + name: Pyright type checker + pass_filenames: true + language: system + entry: >- + bash -c 'for x in "$@"; do + (cd `dirname $x`; pwd; + uv run --frozen --all-extras --dev pyright --stats;); + done;' -- + stages: + - pre-push + files: (src|samples)\/.*\/pyproject.toml + + - id: pytest + name: Pytest unit tests + pass_filenames: true + language: system + entry: >- + bash -c 'for x in "$@"; do + (cd `dirname $x`; pwd; + uv run --frozen pytest --cov --cov-branch + --cov-report=term-missing;); + done;' -- + stages: + - pre-push + files: src\/.*\/pyproject.toml + + - id: pre-commit-autoupdate-reminder + name: Pre-commit autoupdate reminder + language: system + entry: >- + bash -c 'echo "Run: pre-commit autoupdate"' + always_run: false + pass_filenames: false + stages: + - manual + +# To run autoupdate manually: +# pre-commit autoupdate +# +# To run all hooks: +# pre-commit run --all-files +# +# To install hooks: +# pre-commit install +# pre-commit install --hook-type commit-msg diff --git a/README.md b/README.md index a7e7d06..921ad61 100644 --- a/README.md +++ b/README.md @@ -62,6 +62,7 @@ Paths to 'find-prefix': ``` ### Available Graph Operations + - `show graph` - Display command tree structure - `show graph stats` - Show command statistics - `show graph validate` - Verify all handlers implemented @@ -125,6 +126,7 @@ aws-net>fi:1>ru:2> show rule-group ``` ### Firewall Commands Summary + - **Contexts**: firewall → rule-group (2-level hierarchy) - **Commands**: show firewall, show rule-groups, show policy, set rule-group, show rule-group - **Display**: Complete rule details including ports, protocols, actions, and Suricata rules @@ -175,6 +177,7 @@ aws-net>vp:1> show tunnels ``` ### VPN Commands Summary + - **Context**: vpn (1-level hierarchy) - **Commands**: show detail, show tunnels - **Display**: IPSec tunnel status with outside IPs, BGP route counts, status messages @@ -247,6 +250,7 @@ pytest tests/ -v ``` ### Test Coverage + - **Root commands**: 42 commands - **Context commands**: 35+ commands - **Total coverage**: 77+ commands @@ -255,6 +259,7 @@ pytest tests/ -v ## 📖 Usage Examples ### Basic Commands + ```bash aws-net> show vpcs aws-net> show global-networks @@ -263,6 +268,7 @@ aws-net> show detail ``` ### Context Navigation + ```bash # Enter VPC context aws-net> set vpc 1 @@ -278,6 +284,7 @@ tgw> exit ``` ### AWS Operations + ```bash # Trace between IPs aws-net> trace 192.168.1.10 10.0.0.5 @@ -290,6 +297,7 @@ aws-net> find_null_routes ``` ### Cache Management + ```bash # Scenario: ELBs haven't finished provisioning yet aws-net> show elbs @@ -320,6 +328,7 @@ Cleared 5 cache entries ## 📊 Commands by Category ### Cache Management (2) + - `clear_cache` - Clear all cached data permanently - `refresh [target|all]` - Refresh cached data selectively - `refresh` - Refresh current context (e.g., in ELB context, clears ELB cache) @@ -328,6 +337,7 @@ Cleared 5 cache entries - `refresh all` - Clear all caches globally ### Show Commands (34) + - Network Resources: `vpcs`, `transit_gateways`, `firewalls`, `elbs`, `vpns` - Compute: `ec2-instances`, `enis` - Connectivity: `dx-connections`, `peering-connections` @@ -338,10 +348,12 @@ Cleared 5 cache entries - System: `config`, `cache`, `routing-cache` ### Set Commands (8 Contexts) + - `vpc`, `transit-gateway`, `global-network`, `core-network` - `firewall`, `ec2-instance`, `elb`, `vpn` ### Action Commands (9) + - `write `, `trace `, `find_ip ` - `find_prefix `, `find_null_routes` - `reachability`, `populate_cache`, `clear_cache` @@ -354,6 +366,7 @@ Cleared 5 cache entries **Purpose**: Pre-fetch ALL topology data across all modules for comprehensive analysis **What it caches**: + - VPCs, subnets, route tables, security groups, NACLs - Transit Gateways, attachments, peerings, route tables - Cloud WAN (global networks, core networks, segments, attachments) @@ -382,11 +395,13 @@ Cache populated **Purpose**: Build specialized cache of ONLY routing data for fast route lookups and analysis **What it caches**: + - VPC route table entries (all route tables across all VPCs) - Transit Gateway route table entries (all TGW route tables) - Cloud WAN route table entries (by core network, segment, and region) **Enables Commands**: + - `find_prefix ` - Find which route tables contain a prefix - `find_null_routes` - Find blackhole/null routes across all routing domains - `show routing-cache ` - View cached routes with filtering @@ -403,6 +418,7 @@ Building routing cache... ``` **View Cached Routes**: + ```bash # Summary aws-net> show routing-cache @@ -428,12 +444,14 @@ aws-net> show routing-cache all # Everything (comprehensive view) | **When to use** | Before exploration/demos | Before routing troubleshooting | **Recommendation**: + - Use `populate_cache` for general exploration and comprehensive analysis - Use `create_routing_cache` specifically for routing troubleshooting and prefix searches ## 🔧 Configuration Default configuration in `pyproject.toml`: + - **Timeout**: 120 seconds per command - **Regions**: All enabled regions - **Cache**: Enabled by default @@ -457,6 +475,7 @@ pytest tests/ -v ## 📦 Dependencies Core dependencies: + - **boto3**: AWS SDK - **rich**: Terminal formatting - **cmd2**: Shell framework @@ -478,9 +497,10 @@ MIT License - see LICENSE file for details ## 📝 Changelog ### 2024-12-08 + - ✅ VPN tunnel inspection: show tunnels displays VgwTelemetry data with UP/DOWN status - ✅ VPN detail view: show detail includes tunnel summary with outside IPs and BGP routes -- ✅ Debug logging: aws-net-runner --debug flag with comprehensive logging to /tmp/ +- ✅ Debug logging: AWS-net-runner --debug flag with comprehensive logging to /tmp/ - ✅ Network Firewall enhancements: rule-group context with detailed rule inspection - ✅ Enhanced firewall commands: show firewall, show rule-groups with indexes - ✅ STATELESS rules: Complete display with ports, protocols, actions @@ -489,21 +509,23 @@ MIT License - see LICENSE file for details - ✅ Persistent routing cache with SQLite (save/load commands) - ✅ Enhanced routing cache display: vpc, transit-gateway, cloud-wan filters - ✅ Terminal width detection for proper Rich table rendering -- ✅ aws-net-runner tool for programmatic shell execution +- ✅ AWS-net-runner tool for programmatic shell execution ### 2024-12-05 + - ✅ ELB commands implementation (listeners, targets, health) - ✅ VPN context commands (detail, tunnels) - ✅ Firewall policy command - ✅ Core-network commands registration fix - ✅ Direct resource selection without show command - ✅ Automated issue resolution workflow -- ✅ Consolidated CLI to aws-net-shell only +- ✅ Consolidated CLI to AWS-net-shell only - ✅ Multi-level context prompt fix - ✅ Comprehensive testing framework with pexpect integration - ✅ Graph-based command testing ### 2024-12-02 + - ✅ Comprehensive command graph with context navigation - ✅ Dynamic command discovery (78+ commands) - ✅ Command graph Mermaid diagrams diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index bf7f2b9..68ac0be 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -164,6 +164,7 @@ src/aws_network_tools/ #### `base.py` - Foundation Classes **Classes**: + - `BaseClient` - Boto3 session management with custom config - Handles AWS credentials (profile or default) - Standardized retry/timeout configuration @@ -201,6 +202,7 @@ graph LR ``` **Classes**: + - `Cache(namespace)` - Namespace-isolated cache - `get(ignore_expiry, current_account)` - Retrieve with validation - `set(data, ttl_seconds, account_id)` - Store with metadata @@ -208,6 +210,7 @@ graph LR - `get_info()` - Cache metadata (age, TTL, expiry status) **Features**: + - TTL-based expiration (default 15min, configurable) - Account safety (auto-clear on account switch) - Namespace isolation (separate cache per service) @@ -218,6 +221,7 @@ graph LR #### `decorators.py` - Command Decorators **Functions**: + - `@requires_context(ctx_type)` - Ensure command runs in correct context - `@cached(key, ttl)` - Auto-cache function results - `@spinner(message)` - Show progress spinner during execution @@ -225,6 +229,7 @@ graph LR #### `display.py` & `renderer.py` - UI Components **Classes**: + - `BaseDisplay` - Abstract display interface - Defines `show_detail()`, `render_table()` methods - Used by service-specific display classes @@ -237,12 +242,14 @@ graph LR #### `spinner.py` - Progress Feedback **Functions**: + - `run_with_spinner(fn, message)` - Execute with progress indicator - Handles exceptions and displays errors gracefully #### `ip_resolver.py` - Network Utilities **Functions**: + - IP address parsing and validation - CIDR manipulation and comparison - Subnet calculations @@ -293,6 +300,7 @@ classDiagram ``` **Purpose**: + - Type-safe data structures - Automatic validation on construction - Backward compatibility via `to_dict()` @@ -414,6 +422,7 @@ sequenceDiagram #### `base.py` - AWSNetShellBase **Core Responsibilities**: + 1. **Context Stack Management** - `context_stack: list[Context]` - Navigation history - `_enter(ctx_type, ref, name, data, index)` - Push context @@ -441,6 +450,7 @@ sequenceDiagram #### `main.py` - AWSNetShell **Mixin Composition**: + ```python class AWSNetShell( RootHandlersMixin, # Root-level: show, set, trace, find_ip @@ -457,6 +467,7 @@ class AWSNetShell( ``` **Key Methods**: + - `_cached(key, fetch_fn, msg)` - Cache wrapper with spinner - `_emit_json_or_table(data, render_fn)` - Format-aware output - `do_show(args)` - Route show commands to handlers @@ -469,6 +480,7 @@ class AWSNetShell( Each handler mixin provides commands for a specific AWS service or context. **Pattern**: + ```python class ServiceHandlersMixin: def _show_[resource](self, args): @@ -634,10 +646,12 @@ graph TB ``` **Two-Level Caching**: + 1. **Memory Cache** (`self._cache`) - Session-scoped, cleared by refresh 2. **File Cache** (`Cache` class) - Persistent, TTL + account-aware **Cache Keys**: + - `vpcs`, `transit_gateways`, `firewalls`, `elb`, `vpns`, `ec2_instances` - `global_networks`, `core_networks`, `enis` - Namespaced by service type @@ -757,6 +771,7 @@ sequenceDiagram #### Step 1: Create Module File `modules/my_service.py`: + ```python from ..core.base import BaseClient, ModuleInterface, BaseDisplay from ..models.base import AWSResource @@ -827,6 +842,7 @@ class MyServiceModule(ModuleInterface): #### Step 2: Add Handler Mixin `shell/handlers/my_service.py`: + ```python from rich.console import Console @@ -899,6 +915,7 @@ class MyServiceHandlersMixin: #### Step 3: Update Hierarchy `shell/base.py` - Add to `HIERARCHY`: + ```python HIERARCHY = { None: { @@ -917,6 +934,7 @@ HIERARCHY = { #### Step 4: Register in Main Shell `shell/main.py`: + ```python from .handlers import ( ..., @@ -933,6 +951,7 @@ class AWSNetShell( #### Step 5: Add Tests `tests/test_my_service.py`: + ```python def test_show_my_resources(shell): """Test showing resources""" @@ -1073,6 +1092,7 @@ graph TB **Example**: Add `show performance` to ELB context 1. **Update Hierarchy** (`shell/base.py`): + ```python "elb": { "show": ["detail", "listeners", "targets", "health", "performance"], @@ -1081,6 +1101,7 @@ graph TB ``` 2. **Add Handler** (`shell/handlers/elb.py`): + ```python def _show_performance(self, _): """Show ELB performance metrics""" @@ -1094,6 +1115,7 @@ def _show_performance(self, _): ``` 3. **Add Test** (`tests/test_elb_handler.py`): + ```python def test_show_performance(shell): # Setup ELB context @@ -1106,11 +1128,13 @@ def test_show_performance(shell): **Example**: Add CSV export 1. **Add to Base** (`shell/base.py`): + ```python "set": [..., "output-format"], ``` 2. **Update Handler** (`shell/main.py`): + ```python def _emit_json_or_table(self, data, render_table_fn): if self.output_format == "json": @@ -1132,6 +1156,7 @@ def _emit_json_or_table(self, data, render_table_fn): ### ModuleInterface (Abstract Base Class) **Contract**: + ```python class ModuleInterface(ABC): @property @@ -1166,6 +1191,7 @@ class ModuleInterface(ABC): ### BaseClient Interface **Contract**: + ```python class BaseClient: def __init__(self, profile: Optional[str] = None, session: Optional[boto3.Session] = None): @@ -1176,6 +1202,7 @@ class BaseClient: ``` **Guarantees**: + - Automatic retry on throttling (10 attempts, exponential backoff) - 5s connect timeout, 20s read timeout - User agent tracking for API metrics @@ -1184,6 +1211,7 @@ class BaseClient: ### BaseDisplay Interface **Contract**: + ```python class BaseDisplay: def __init__(self, console: Console): @@ -1242,6 +1270,7 @@ graph TB ``` **Test Coverage** (12/09/2025): + - **Total Tests**: 200+ across 40+ test files - **Shell Tests**: Context navigation, command validation - **Handler Tests**: Each service handler validated @@ -1269,12 +1298,14 @@ graph TB ### Theme System **Available Themes**: + - `catppuccin-mocha` (default) - Dark theme with pastel colors - `catppuccin-latte` - Light theme - `catppuccin-macchiato` - Mid-tone theme - `dracula` - Purple-focused dark theme **Theme Structure**: + ```json { "prompt_text": "white", @@ -1290,6 +1321,7 @@ graph TB ``` **Customization** (`shell/base.py:236-335`): + - Prompt styles: "short" (compact) vs "long" (multi-line) - Index display: show/hide selection numbers - Max length: truncate long names @@ -1302,6 +1334,7 @@ graph TB ### Concurrent API Calls **Pattern** (`modules/*.py`): + ```python def discover_multi_region(self, regions: List[str]) -> List[dict]: from concurrent.futures import ThreadPoolExecutor, as_completed @@ -1322,17 +1355,20 @@ def discover_multi_region(self, regions: List[str]) -> List[dict]: ### Smart Caching Strategy **Level 1 - Memory Cache** (`shell.main._cache`): + - Stores list views (show vpcs, show elbs) - Cleared by `refresh` command - Session-scoped only **Level 2 - File Cache** (`core/cache.py`): + - Stores expensive API results - TTL-based expiration (15 min default) - Account-aware (auto-clear on profile switch) - Survives shell restarts **Level 3 - Routing Cache** (`shell/utilities.py`): + - Pre-computed route tables across all services - Used by `find_prefix` at root level - Database-backed for complex queries @@ -1340,6 +1376,7 @@ def discover_multi_region(self, regions: List[str]) -> List[dict]: ### Lazy Loading **Pattern**: + ```python def _show_detail(self, _): # Fetch full details only when user enters context @@ -1356,11 +1393,13 @@ def _show_detail(self, _): ### AWS Credentials **Priority Order**: + 1. `--profile` flag → Use specific AWS profile 2. `AWS_PROFILE` env var -3. Default credentials chain (IAM role, env vars, ~/.aws/credentials) +3. Default credentials chain (IAM role, env vars, ~/.AWS/credentials) **Account Safety**: + - Cache stores `account_id` with each entry - Automatic cache invalidation on account switch - Prevents cross-account data leakage @@ -1368,10 +1407,12 @@ def _show_detail(self, _): ### Sensitive Data Handling **Not Logged**: + - AWS credentials or temporary tokens - Resource content (S3 objects, secrets) **Logged** (debug mode): + - API call parameters (resource IDs, filters) - Response metadata (status codes, timing) - Command execution trace @@ -1379,6 +1420,7 @@ def _show_detail(self, _): ### Input Validation **Models Layer** (`models/*.py`): + - Pydantic validation on all AWS responses - CIDR format validation - Resource ID format checking @@ -1391,12 +1433,14 @@ def _show_detail(self, _): ### Debug Mode **Enable** via runner: + ```bash aws-net-runner --debug "show vpcs" "set vpc 1" # Logs to: /tmp/aws_net_runner_debug_.log ``` **Log Contents**: + - Command execution timeline - AWS API calls with parameters - Cache hits/misses @@ -1406,6 +1450,7 @@ aws-net-runner --debug "show vpcs" "set vpc 1" ### Graph Validation **Check command hierarchy integrity**: + ```bash aws-net> show graph validate ✓ Graph is valid - all handlers implemented @@ -1417,14 +1462,17 @@ aws-net> show graph validate ### Common Issues **Issue**: Commands not appearing in context + - **Cause**: Missing in `HIERARCHY` dict - **Fix**: Add to context's "commands" list **Issue**: Cache not clearing + - **Cause**: Using wrong cache key name - **Fix**: Use `refresh all` or check `cache_mappings` in `do_refresh()` **Issue**: AWS API throttling + - **Cause**: Too many concurrent requests - **Fix**: Reduce `AWS_NET_MAX_WORKERS` env var (default 10) @@ -1458,6 +1506,7 @@ aws-net> show graph validate ### Complete Module List **AWS Service Modules** (23 total): + 1. `cloudwan.py` - Cloud WAN & Global Networks 2. `vpc.py` - VPCs, Subnets, Route Tables 3. `tgw.py` - Transit Gateways, Attachments @@ -1554,7 +1603,7 @@ cache_mappings = { - **Context**: Current CLI scope (vpc, transit-gateway, etc.) - **Context Stack**: Navigation history (breadcrumb trail) -- **Handler**: Shell command implementation (do_show, _set_vpc, etc.) +- **Handler**: Shell command implementation (do_show,_set_vpc, etc.) - **Module**: AWS service integration (CloudWANClient, VPCClient) - **Mixin**: Composable class adding commands to shell - **Cache Key**: String identifier for cached data ("vpcs", "elb", etc.) @@ -1564,5 +1613,5 @@ cache_mappings = { --- **Generated**: 2025-12-09 -**Repository**: https://github.com/[your-org]/aws-network-shell +**Repository**: **Documentation**: See `docs/` for command hierarchy and testing guides diff --git a/docs/README.md b/docs/README.md index 6f05231..126f90d 100644 --- a/docs/README.md +++ b/docs/README.md @@ -3,6 +3,7 @@ ## Architecture & Codemap **ARCHITECTURE.md** - Comprehensive technical documentation (15KB) + - System architecture with Mermaid diagrams - Module breakdown and interactions - Data flow and core workflows @@ -41,12 +42,14 @@ aws-net> export-graph [filename] ### Static Documentation **command-hierarchy-split.md** - Multi-diagram Mermaid format (11KB) + - Multiple small, focused diagrams - One diagram per context (VPC, Transit Gateway, Firewall, etc.) - Left-to-right layout for readability - **Most readable option for static viewing** **command-hierarchy-graph.md** - Single unified graph (12KB) + - Complete command hierarchy in one diagram - Shows all context relationships - Top-down tree layout @@ -61,6 +64,7 @@ typora docs/command-hierarchy-graph.md ## Testing Documentation See `tests/README.md` for: + - Test framework architecture - Running tests - Writing new tests @@ -69,7 +73,8 @@ See `tests/README.md` for: ## Scripts Documentation See `scripts/README.md` and `scripts/AUTOMATION_README.md` for: -- aws-net-runner usage + +- AWS-net-runner usage - Workflow automation - Issue resolution automation - Shell runner API @@ -77,6 +82,7 @@ See `scripts/README.md` and `scripts/AUTOMATION_README.md` for: ## Main Documentation See root `README.md` for: + - Installation and setup - Command categories - Usage examples diff --git a/docs/command-hierarchy-graph.md b/docs/command-hierarchy-graph.md index 64d10ee..7da7e61 100644 --- a/docs/command-hierarchy-graph.md +++ b/docs/command-hierarchy-graph.md @@ -171,107 +171,120 @@ Commands available immediately after starting the shell: **Entry Command**: ✓ `set global-network` → enters `global-network` context **Show Commands**: - - ✓ `show core-networks` - - ✓ `show detail` + +- ✓ `show core-networks` +- ✓ `show detail` ### Core Network Context **Entry Command**: ✓ `set core-network` → enters `core-network` context **Show Commands**: - - ✓ `show blackhole-routes` - - ✓ `show connect-attachments` - - ✓ `show connect-peers` - - ✓ `show detail` - - ✓ `show policy` - - ✓ `show policy-change-events` - - ✓ `show rib` - - ✓ `show route-tables` - - ✓ `show routes` - - ✓ `show segments` + +- ✓ `show blackhole-routes` +- ✓ `show connect-attachments` +- ✓ `show connect-peers` +- ✓ `show detail` +- ✓ `show policy` +- ✓ `show policy-change-events` +- ✓ `show rib` +- ✓ `show route-tables` +- ✓ `show routes` +- ✓ `show segments` **Action Commands**: - - ✓ `find_null_routes` - - ✓ `find_prefix` + +- ✓ `find_null_routes` +- ✓ `find_prefix` ### Route Table Context **Entry Command**: ✓ `set route-table` → enters `route-table` context **Show Commands**: - - ✓ `show routes` + +- ✓ `show routes` **Action Commands**: - - ✓ `find_null_routes` - - ✓ `find_prefix` + +- ✓ `find_null_routes` +- ✓ `find_prefix` ### Vpc Context **Entry Command**: ✓ `set vpc` → enters `vpc` context **Show Commands**: - - ✓ `show detail` - - ✓ `show endpoints` - - ✓ `show internet-gateways` - - ✓ `show nacls` - - ✓ `show nat-gateways` - - ✓ `show route-tables` - - ✓ `show security-groups` - - ✓ `show subnets` + +- ✓ `show detail` +- ✓ `show endpoints` +- ✓ `show internet-gateways` +- ✓ `show nacls` +- ✓ `show nat-gateways` +- ✓ `show route-tables` +- ✓ `show security-groups` +- ✓ `show subnets` **Action Commands**: - - ✓ `find_null_routes` - - ✓ `find_prefix` + +- ✓ `find_null_routes` +- ✓ `find_prefix` ### Transit Gateway Context **Entry Command**: ✓ `set transit-gateway` → enters `transit-gateway` context **Show Commands**: - - ✓ `show attachments` - - ✓ `show detail` - - ✓ `show route-tables` + +- ✓ `show attachments` +- ✓ `show detail` +- ✓ `show route-tables` **Action Commands**: - - ✓ `find_null_routes` - - ✓ `find_prefix` + +- ✓ `find_null_routes` +- ✓ `find_prefix` ### Firewall Context **Entry Command**: ✓ `set firewall` → enters `firewall` context **Show Commands**: - - ✓ `show detail` - - ✓ `show policy` - - ✓ `show rule-groups` + +- ✓ `show detail` +- ✓ `show policy` +- ✓ `show rule-groups` ### Ec2 Instance Context **Entry Command**: ✓ `set ec2-instance` → enters `ec2-instance` context **Show Commands**: - - ✓ `show detail` - - ✓ `show enis` - - ✓ `show routes` - - ✓ `show security-groups` + +- ✓ `show detail` +- ✓ `show enis` +- ✓ `show routes` +- ✓ `show security-groups` ### Elb Context **Entry Command**: ✓ `set elb` → enters `elb` context **Show Commands**: - - ✓ `show detail` - - ✓ `show health` - - ✓ `show listeners` - - ✓ `show targets` + +- ✓ `show detail` +- ✓ `show health` +- ✓ `show listeners` +- ✓ `show targets` ### Vpn Context **Entry Command**: ✓ `set vpn` → enters `vpn` context **Show Commands**: - - ✓ `show detail` - - ✓ `show tunnels` + +- ✓ `show detail` +- ✓ `show tunnels` ## Entity Relationships diff --git a/docs/command-hierarchy-split.md b/docs/command-hierarchy-split.md index 7a11955..d9cf453 100644 --- a/docs/command-hierarchy-split.md +++ b/docs/command-hierarchy-split.md @@ -8,6 +8,7 @@ This document shows the command hierarchy using multiple smaller, readable diagr ## Commands Overview ### Cache Management + - `clear_cache` - Clear all cached data (permanent) - `refresh [target|all]` - Refresh specific or all cached data - `refresh` - Refresh current context data @@ -16,11 +17,13 @@ This document shows the command hierarchy using multiple smaller, readable diagr - Available in all contexts for immediate cache invalidation ### Navigation + - `exit` - Go back one context level - `end` - Return to root level - `clear` - Clear the screen ### Resource Discovery + - `find_ip ` - Locate IP address across AWS resources - `find_prefix ` - Find routes matching CIDR prefix - `find_null_routes` - Show blackhole routes diff --git a/scripts/AUTOMATION_README.md b/scripts/AUTOMATION_README.md index ddf94de..de2641e 100644 --- a/scripts/AUTOMATION_README.md +++ b/scripts/AUTOMATION_README.md @@ -19,11 +19,13 @@ GitHub Issue → issue_investigator.py → Agent Prompt (XML) ## Components ### 1. issue_investigator.py (Existing) + - Fetches GitHub issues - Reproduces issue with shell_runner.py - Generates agent prompt with complete context ### 2. automated_issue_resolver.py (New) + - Orchestrates the full resolution workflow - Manages agent prompt execution - Creates validation tests @@ -31,6 +33,7 @@ GitHub Issue → issue_investigator.py → Agent Prompt (XML) - Creates PRs for successful fixes ### 3. Agent Prompts (Generated) + - XML format for AI agent consumption - Contains: - Issue description @@ -43,6 +46,7 @@ GitHub Issue → issue_investigator.py → Agent Prompt (XML) ## Usage ### Step 1: Generate Agent Prompt + ```bash # Generate prompt for Issue #9 uv run python scripts/issue_investigator.py --issue 9 --agent-prompt @@ -52,6 +56,7 @@ uv run python scripts/issue_investigator.py --issue 9 --agent-prompt > agent_pro ``` ### Step 2: Execute with AI Agent + ```bash # Manual: Copy XML prompt to Claude Code or AI agent # Automated: Use automated_issue_resolver.py @@ -60,6 +65,7 @@ uv run python scripts/automated_issue_resolver.py --issue 9 ``` ### Step 3: Validate Fix + ```bash # Run issue-specific test pytest tests/integration/workflows/issue_9_*.yaml -v @@ -69,6 +75,7 @@ pytest tests/integration/test_workflows.py -k "issue_9" ``` ### Step 4: Create PR (if tests pass) + ```bash # Automated in resolver script gh pr create --title "Fix Issue #9" --body "$(cat agent_prompts/issue_9_fix_summary.md)" diff --git a/scripts/README.md b/scripts/README.md index d246b66..061b649 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -125,6 +125,7 @@ The tool displays a formatted summary with status indicators: #### Agent Prompt (`--agent-prompt`) Generates structured prompts for AI agents. **XML is the default** and recommended for agents due to: + - Clear, unambiguous delimiters - Lower token overhead - Easier programmatic parsing @@ -284,6 +285,7 @@ uv run python scripts/shell_runner.py --debug "show vpns" "set vpn 1" "show tunn ``` **Debug Logging** (`--debug` or `-d`): + - **Purpose**: Capture comprehensive execution data for troubleshooting GitHub issues - **Log Location**: `/tmp/aws_net_runner_debug_.log` - **Includes**: @@ -418,12 +420,14 @@ graph TB Cleans terminal output for git commit messages or documentation. **Features**: + - Removes ANSI color codes - Converts box-drawing characters to ASCII - Normalizes whitespace - Optional compact mode (removes blank lines) **Usage**: + ```bash # From clipboard (macOS) pbpaste | python scripts/clean-output.py @@ -439,6 +443,7 @@ python scripts/clean-output.py < output.txt > cleaned.txt ``` **Example**: + ```bash # Before (with ANSI codes and box drawing) ┏━━━┳━━━━━━┳━━━━━━━━━━━┓ diff --git a/tests/README.md b/tests/README.md index 9de36fe..73d0527 100644 --- a/tests/README.md +++ b/tests/README.md @@ -28,21 +28,25 @@ tests/ ## Running Tests ### Quick Test (Core Framework) + ```bash pytest tests/test_command_graph/ tests/test_utils/ tests/unit/ -v ``` ### Full Test Suite + ```bash pytest tests/ -v ``` ### Specific Context Tests + ```bash pytest tests/test_command_graph/test_context_commands.py -v ``` ### Integration Tests (Real AWS) + ```bash AWS_PROFILE=your-profile pytest tests/integration/ -m integration ``` @@ -58,6 +62,7 @@ AWS_PROFILE=your-profile pytest tests/integration/ -m integration ## Key Components ### BaseContextTestCase + Reusable test helpers for command graph testing. ```python @@ -68,24 +73,29 @@ class MyTest(BaseContextTestCase): ``` ### Parametrized Tests + Efficient test generation using test_data_generator.py. ### pexpect Integration -Real CLI testing using aws-net-shell process. + +Real CLI testing using AWS-net-shell process. ## Automation ### Issue Investigation + ```bash uv run python scripts/issue_investigator.py --issue 9 --agent-prompt ``` ### Automated Resolution + ```bash uv run python scripts/automated_issue_resolver.py --issue 9 ``` ### Workflow Execution + ```bash uv run python scripts/shell_runner.py "show vpcs" "set vpc 1" "show subnets" ``` diff --git a/tests/fixtures/README.md b/tests/fixtures/README.md index 2b9f429..d274211 100644 --- a/tests/fixtures/README.md +++ b/tests/fixtures/README.md @@ -5,27 +5,33 @@ High-quality mock data for comprehensive testing without AWS resource deployment ## 📁 Available Fixtures ### Core Network Resources + - **`vpc.py`** - VPCs, Subnets, Route Tables, Security Groups, NACLs - **`tgw.py`** - Transit Gateways, Attachments, Route Tables, Peerings - **`cloudwan.py`** - Core Networks, Segments, Attachments, Policies, Routes - **`cloudwan_connect.py`** - Connect Peers, BGP Sessions, GRE Tunnels ### Compute & Network Interfaces + - **`ec2.py`** - EC2 Instances, ENIs (including Lambda, RDS, ALB ENIs) - **`elb.py`** - Application/Network Load Balancers, Target Groups, Listeners, Health Checks ### Hybrid Connectivity + - **`vpn.py`** - Site-to-Site VPN, Customer Gateways, VPN Gateways, Direct Connect - **`client_vpn.py`** - Client VPN Endpoints, Routes, Authorization Rules ### Gateway Resources + - **`gateways.py`** - Internet Gateways, NAT Gateways, Elastic IPs, Egress-only IGWs ### Security & Filtering + - **`firewall.py`** - Network Firewalls, Policies, Rule Groups (Suricata/5-tuple/Domain) - **`prefix_lists.py`** - Customer-Managed and AWS-Managed Prefix Lists ### Additional Services + - **`peering.py`** - VPC Peering Connections (intra-region, cross-region, cross-account) - **`vpc_endpoints.py`** - Interface/Gateway Endpoints, PrivateLink Services - **`route53_resolver.py`** - Resolver Endpoints, Rules, Query Logging @@ -235,6 +241,7 @@ Network Firewall ## 🎯 Best Practices ### 1. **Consistent ID Patterns** + ```python # Follow AWS ID patterns vpc_id = "vpc-0prod1234567890ab" # 17 hex chars after prefix @@ -243,6 +250,7 @@ tgw_id = "tgw-0prod12345678901" # 17 hex chars after prefix ``` ### 2. **Multi-Region Coverage** + ```python # Always include 3 regions minimum regions = ["eu-west-1", "us-east-1", "ap-southeast-2"] @@ -250,6 +258,7 @@ environments = ["production", "staging", "development"] ``` ### 3. **Cross-References** + ```python # Reference existing fixture IDs nat_gateway = { @@ -262,6 +271,7 @@ nat_gateway = { ``` ### 4. **Include All States** + ```python # Cover operational and transitional states states = ["available", "pending", "deleting", "deleted", "failed"] @@ -275,7 +285,9 @@ nat_failed = { ``` ### 5. **Helper Functions** + Every fixture file should include: + - `get_*_by_id()` - Primary ID lookup - `get_*s_by_vpc()` - VPC-scoped queries - `get_*s_by_state()` - State filtering @@ -283,6 +295,7 @@ Every fixture file should include: - Custom queries specific to resource type ### 6. **Docstrings** + ```python """Get comprehensive [resource] detail with all associated resources. @@ -337,18 +350,23 @@ def test_nat_gateway_has_eip(): ## 📚 Resources for Creating Fixtures ### AWS Documentation + Use AWS MCP server to read official docs: + ```python # In Claude Code: # "Read AWS documentation for [resource] API using MCP server" ``` ### Boto3 Documentation -- https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ec2.html -- https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/networkmanager.html + +- +- ### Module Source Code + Your modules are the **best reference** for expected data structure: + - `src/aws_network_tools/modules/[resource].py` - Look for `describe_*` boto3 API calls - Check what fields are accessed in display/processing code @@ -381,12 +399,14 @@ done ## 🎓 Learning Resources ### Understanding AWS API Responses + 1. **AWS CLI with --debug**: See raw API responses 2. **AWS MCP Server**: Fetch real resource details 3. **boto3 documentation**: Official API reference 4. **Module code**: Your own code shows expected structure ### Creating Realistic Test Data + 1. **Use realistic CIDR blocks**: RFC1918 private ranges 2. **Follow naming conventions**: Environment-tier-region-az pattern 3. **Include tags**: Name, Environment, ManagedBy, Purpose @@ -394,7 +414,9 @@ done 5. **Include failure cases**: Test error handling paths ### Cross-Reference Validation + Run this check to ensure fixture integrity: + ```python from tests.fixtures import get_all_gateways_by_vpc, VPC_FIXTURES diff --git a/tests/test_command_graph/README.md b/tests/test_command_graph/README.md index 5320f0b..2d81eac 100644 --- a/tests/test_command_graph/README.md +++ b/tests/test_command_graph/README.md @@ -1,9 +1,11 @@ # Command Graph Test Suite ## Overview + Comprehensive test suite for AWS Network Shell command graph testing with binary pass/fail validation. ## Test Structure + ``` tests/test_command_graph/ ├── README.md (this file) @@ -25,6 +27,7 @@ tests/test_command_graph/ ## Testing Methodology ### TDD Approach + 1. Write failing test first 2. Run test and capture failure 3. Implement minimal code to pass @@ -32,12 +35,15 @@ tests/test_command_graph/ 5. Move to next test ### Binary Pass/Fail + All tests use binary assertions: + - `assert result.exit_code == 0` for success - `assert result.exit_code != 0` for expected failures - `assert "expected_text" in result.output` for data validation ### Mock Strategy + - ALL boto3 calls are mocked - Use fixture data from `tests/fixtures/` - No real AWS API calls @@ -60,6 +66,7 @@ pytest tests/test_command_graph/test_top_level_commands.py::test_show_version -v ``` ## Test Coverage Goals + - 100% of HIERARCHY commands tested - All context transitions validated - All show/set command combinations @@ -67,4 +74,5 @@ pytest tests/test_command_graph/test_top_level_commands.py::test_show_version -v - Output format validation (table, json, yaml) ## Known Limitations + See `test_coverage_report.py` for commands that cannot be tested and justification. From 8be83e0f010356e53f85ae4bc74f0c8da33bd152 Mon Sep 17 00:00:00 2001 From: d-padmanabhan <88682350+d-padmanabhan@users.noreply.github.com> Date: Mon, 5 Jan 2026 09:11:54 -0500 Subject: [PATCH 3/4] chore(repo): remove personal info from config files --- .github/CODEOWNERS | 49 +-------------------- .github/{dependabot.yml => dependabot.yaml} | 2 - .gitignore | 3 ++ 3 files changed, 5 insertions(+), 49 deletions(-) rename .github/{dependabot.yml => dependabot.yaml} (97%) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 6e19b93..175970a 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -4,50 +4,5 @@ # Each line is a file pattern followed by one or more owners. # Order matters: later patterns override earlier ones. # Owners are notified when matching files are changed. - -# Default owner for everything in the repository -* @dpadmanabhan - -# Core framework and shell implementation -/src/aws_network_tools/core/ @dpadmanabhan -/src/aws_network_tools/shell/ @dpadmanabhan - -# AWS service modules -/src/aws_network_tools/modules/ @dpadmanabhan - -# Data models -/src/aws_network_tools/models/ @dpadmanabhan - -# Traceroute functionality -/src/aws_network_tools/traceroute/ @dpadmanabhan - -# CLI interfaces -/src/aws_network_tools/cli/ @dpadmanabhan -/src/aws_network_tools/cli.py @dpadmanabhan - -# Configuration -/src/aws_network_tools/config/ @dpadmanabhan - -# Tests -/tests/ @dpadmanabhan - -# Documentation -/docs/ @dpadmanabhan -*.md @dpadmanabhan - -# Build and packaging -pyproject.toml @dpadmanabhan -uv.lock @dpadmanabhan - -# CI/CD and automation -/.github/ @dpadmanabhan -/.pre-commit-config.yaml @dpadmanabhan - -# Scripts and utilities -/scripts/ @dpadmanabhan - -# Schemas -/schemas/ @dpadmanabhan - -# Themes -/themes/ @dpadmanabhan +# +# Update the GitHub username below to match the repository owner/maintainer. diff --git a/.github/dependabot.yml b/.github/dependabot.yaml similarity index 97% rename from .github/dependabot.yml rename to .github/dependabot.yaml index f66bd2f..18b869e 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yaml @@ -19,8 +19,6 @@ updates: labels: - dependencies - python - reviewers: - - dpadmanabhan groups: # Group minor and patch updates together python-minor: diff --git a/.gitignore b/.gitignore index c1c2ab9..a76423a 100644 --- a/.gitignore +++ b/.gitignore @@ -171,3 +171,6 @@ hive-mind-prompt-*.txt # Benchmark data .benchmarks/ + +# Local extras folder +**extras/ From 74c6ba387cf5734eb11a55e7d115c6936bf03782 Mon Sep 17 00:00:00 2001 From: d-padmanabhan <88682350+d-padmanabhan@users.noreply.github.com> Date: Mon, 5 Jan 2026 09:21:04 -0500 Subject: [PATCH 4/4] chore(deps): change dependabot schedule to 00:00 UTC Monday --- .github/dependabot.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/dependabot.yaml b/.github/dependabot.yaml index 18b869e..f78e576 100644 --- a/.github/dependabot.yaml +++ b/.github/dependabot.yaml @@ -11,8 +11,8 @@ updates: schedule: interval: weekly day: monday - time: "09:00" - timezone: America/New_York + time: "00:00" + timezone: UTC open-pull-requests-limit: 10 commit-message: prefix: "chore(deps)" @@ -43,8 +43,8 @@ updates: schedule: interval: weekly day: monday - time: "09:00" - timezone: America/New_York + time: "00:00" + timezone: UTC open-pull-requests-limit: 5 commit-message: prefix: "ci(deps)"