diff --git a/tests/cli/commands/test_classify.py b/tests/cli/commands/test_classify.py new file mode 100644 index 0000000..ea76ae6 --- /dev/null +++ b/tests/cli/commands/test_classify.py @@ -0,0 +1,849 @@ +"""Tests for classify CLI command.""" + +import json +import os +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +from click.testing import CliRunner +from rich.console import Console + +from aieng_bot._cli.commands.classify import ( + _check_environment_variables, + _get_json_output, + _handle_merge_conflict, + _handle_no_failed_checks, + _output_json_format, + _output_result, + _output_rich_format, +) +from aieng_bot._cli.main import cli +from aieng_bot.classifier.models import ( + CheckFailure, + ClassificationResult, + FailureType, + PRContext, +) + + +class TestCheckEnvironmentVariables: + """Test suite for _check_environment_variables function.""" + + @patch.dict( + os.environ, + {"ANTHROPIC_API_KEY": "test-key", "GITHUB_TOKEN": "gh-token"}, + clear=True, + ) + def test_all_variables_set(self): + """Test when all required variables are set.""" + ok, missing = _check_environment_variables() + assert ok is True + assert len(missing) == 0 + + @patch.dict( + os.environ, + {"ANTHROPIC_API_KEY": "test-key", "GH_TOKEN": "gh-token"}, + clear=True, + ) + def test_gh_token_alternative(self): + """Test with GH_TOKEN instead of GITHUB_TOKEN.""" + ok, missing = _check_environment_variables() + assert ok is True + assert len(missing) == 0 + + @patch.dict(os.environ, {}, clear=True) + def test_no_variables_set(self): + """Test when no variables are set.""" + ok, missing = _check_environment_variables() + assert ok is False + assert len(missing) == 2 + assert any("ANTHROPIC_API_KEY" in m for m in missing) + assert any("GITHUB_TOKEN" in m for m in missing) + + @patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}, clear=True) + def test_missing_github_token(self): + """Test when only GitHub token is missing.""" + 
ok, missing = _check_environment_variables() + assert ok is False + assert len(missing) == 1 + assert "GITHUB_TOKEN" in missing[0] + + @patch.dict(os.environ, {"GITHUB_TOKEN": "gh-token"}, clear=True) + def test_missing_anthropic_key(self): + """Test when only Anthropic key is missing.""" + ok, missing = _check_environment_variables() + assert ok is False + assert len(missing) == 1 + assert "ANTHROPIC_API_KEY" in missing[0] + + @patch.dict(os.environ, {"ANTHROPIC_API_KEY": "", "GITHUB_TOKEN": ""}, clear=True) + def test_empty_variables(self): + """Test with empty string values.""" + ok, missing = _check_environment_variables() + assert ok is False + assert len(missing) == 2 + + +class TestGetJsonOutput: + """Test suite for _get_json_output function.""" + + def test_get_json_output_complete(self): + """Test JSON output with complete result.""" + result = ClassificationResult( + failure_type=FailureType.TEST, + confidence=0.95, + reasoning="Tests failed due to API changes", + failed_check_names=["unit-tests", "integration-tests"], + recommended_action="Update test assertions", + ) + + output = _get_json_output(result) + + assert output["failure_type"] == "test" + assert output["confidence"] == 0.95 + assert output["reasoning"] == "Tests failed due to API changes" + assert output["failed_check_names"] == ["unit-tests", "integration-tests"] + assert output["recommended_action"] == "Update test assertions" + + def test_get_json_output_unknown(self): + """Test JSON output with unknown failure type.""" + result = ClassificationResult( + failure_type=FailureType.UNKNOWN, + confidence=0.0, + reasoning="Cannot determine failure type", + failed_check_names=[], + recommended_action="Manual investigation required", + ) + + output = _get_json_output(result) + + assert output["failure_type"] == "unknown" + assert output["confidence"] == 0.0 + assert output["failed_check_names"] == [] + + def test_get_json_output_all_failure_types(self): + """Test JSON output for all failure 
types.""" + failure_types = [ + FailureType.TEST, + FailureType.LINT, + FailureType.SECURITY, + FailureType.BUILD, + FailureType.MERGE_CONFLICT, + FailureType.UNKNOWN, + ] + + for failure_type in failure_types: + result = ClassificationResult( + failure_type=failure_type, + confidence=0.9, + reasoning="Test reasoning", + failed_check_names=["check-1"], + recommended_action="Test action", + ) + output = _get_json_output(result) + assert output["failure_type"] == failure_type.value + + +class TestOutputJsonFormat: + """Test suite for _output_json_format function.""" + + def test_output_json_to_stdout(self): + """Test JSON output to stdout.""" + result = ClassificationResult( + failure_type=FailureType.LINT, + confidence=0.85, + reasoning="Linting errors", + failed_check_names=["eslint"], + recommended_action="Run linter", + ) + + console = Console() + with patch.object(console, "print_json") as mock_print: + _output_json_format(result, console, output_file=None) + mock_print.assert_called_once() + # Verify the data structure + call_data = mock_print.call_args[1]["data"] + assert call_data["failure_type"] == "lint" + assert call_data["confidence"] == 0.85 + + def test_output_json_to_file(self): + """Test JSON output to file.""" + result = ClassificationResult( + failure_type=FailureType.SECURITY, + confidence=0.92, + reasoning="Security vulnerability", + failed_check_names=["security-scan"], + recommended_action="Update dependencies", + ) + + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f: + output_file = f.name + + try: + console = Console() + _output_json_format(result, console, output_file=output_file) + + # Verify file was created and contains correct data + assert Path(output_file).exists() + with open(output_file, "r") as f: + data = json.load(f) + + assert data["failure_type"] == "security" + assert data["confidence"] == 0.92 + assert data["reasoning"] == "Security vulnerability" + finally: + 
Path(output_file).unlink(missing_ok=True) + + def test_output_json_file_formatting(self): + """Test that JSON file is properly formatted with indent.""" + result = ClassificationResult( + failure_type=FailureType.BUILD, + confidence=0.88, + reasoning="Build failed", + failed_check_names=["build"], + recommended_action="Fix build errors", + ) + + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f: + output_file = f.name + + try: + console = Console() + _output_json_format(result, console, output_file=output_file) + + # Read raw content to verify formatting + with open(output_file, "r") as f: + content = f.read() + + # Should have indentation and newlines (pretty-printed) + assert "\n" in content + assert " " in content # 2-space indent + finally: + Path(output_file).unlink(missing_ok=True) + + +class TestOutputRichFormat: + """Test suite for _output_rich_format function.""" + + def test_output_rich_format_successful(self): + """Test Rich format output for successful classification.""" + result = ClassificationResult( + failure_type=FailureType.TEST, + confidence=0.95, + reasoning="Unit tests failed", + failed_check_names=["pytest"], + recommended_action="Fix test assertions", + ) + + console = Console() + with patch.object(console, "print") as mock_print: + _output_rich_format(result, console) + # Should print multiple times (empty line, panel, empty line) + assert mock_print.call_count >= 3 + + def test_output_rich_format_unknown(self): + """Test Rich format output for unknown classification.""" + result = ClassificationResult( + failure_type=FailureType.UNKNOWN, + confidence=0.3, + reasoning="Cannot classify", + failed_check_names=[], + recommended_action="Manual review", + ) + + console = Console() + with patch.object(console, "print") as mock_print: + _output_rich_format(result, console) + assert mock_print.call_count >= 3 + + def test_output_rich_format_low_confidence(self): + """Test Rich format output for low confidence result.""" + 
result = ClassificationResult( + failure_type=FailureType.LINT, + confidence=0.6, # Below 0.7 threshold + reasoning="Might be linting", + failed_check_names=["check"], + recommended_action="Verify", + ) + + console = Console() + with patch.object(console, "print") as mock_print: + _output_rich_format(result, console) + assert mock_print.call_count >= 3 + + +class TestOutputResult: + """Test suite for _output_result function.""" + + def test_output_result_rich_format(self): + """Test routing to Rich format.""" + result = ClassificationResult( + failure_type=FailureType.TEST, + confidence=0.9, + reasoning="Test", + failed_check_names=[], + recommended_action="Fix", + ) + + console = Console() + with ( + patch("aieng_bot._cli.commands.classify._output_rich_format") as mock_rich, + patch("aieng_bot._cli.commands.classify._output_json_format") as mock_json, + ): + _output_result(result, console, json_output=False, output_file=None) + mock_rich.assert_called_once() + mock_json.assert_not_called() + + def test_output_result_json_format(self): + """Test routing to JSON format with --json flag.""" + result = ClassificationResult( + failure_type=FailureType.LINT, + confidence=0.85, + reasoning="Lint", + failed_check_names=[], + recommended_action="Fix", + ) + + console = Console() + with ( + patch("aieng_bot._cli.commands.classify._output_rich_format") as mock_rich, + patch("aieng_bot._cli.commands.classify._output_json_format") as mock_json, + ): + _output_result(result, console, json_output=True, output_file=None) + mock_json.assert_called_once() + mock_rich.assert_not_called() + + def test_output_result_json_file(self): + """Test routing to JSON format with output file.""" + result = ClassificationResult( + failure_type=FailureType.SECURITY, + confidence=0.95, + reasoning="Security", + failed_check_names=[], + recommended_action="Update", + ) + + console = Console() + with ( + patch("aieng_bot._cli.commands.classify._output_rich_format") as mock_rich, + 
patch("aieng_bot._cli.commands.classify._output_json_format") as mock_json, + ): + _output_result(result, console, json_output=False, output_file="out.json") + mock_json.assert_called_once_with(result, console, "out.json") + mock_rich.assert_not_called() + + +class TestHandleMergeConflict: + """Test suite for _handle_merge_conflict function.""" + + def test_handle_merge_conflict_rich_output(self): + """Test merge conflict handler with Rich output.""" + console = Console() + with ( + patch("aieng_bot._cli.commands.classify._output_result") as mock_output, + pytest.raises(SystemExit) as exc_info, + ): + _handle_merge_conflict(console, json_output=False, output_file=None) + + assert exc_info.value.code == 0 + mock_output.assert_called_once() + + # Verify the result passed to output + call_args = mock_output.call_args + result = call_args[0][0] + assert result.failure_type == FailureType.MERGE_CONFLICT + assert result.confidence == 1.0 + assert "merge conflicts" in result.reasoning.lower() + + def test_handle_merge_conflict_json_output(self): + """Test merge conflict handler with JSON output.""" + console = Console() + with ( + patch("aieng_bot._cli.commands.classify._output_result") as mock_output, + pytest.raises(SystemExit) as exc_info, + ): + _handle_merge_conflict(console, json_output=True, output_file=None) + + assert exc_info.value.code == 0 + call_args = mock_output.call_args + assert call_args[0][2] is True # json_output=True + + +class TestHandleNoFailedChecks: + """Test suite for _handle_no_failed_checks function.""" + + def test_handle_no_failed_checks_rich_output(self): + """Test no failed checks handler with Rich output.""" + console = Console() + with ( + patch("aieng_bot._cli.commands.classify._output_result") as mock_output, + pytest.raises(SystemExit) as exc_info, + ): + _handle_no_failed_checks(console, json_output=False, output_file=None) + + assert exc_info.value.code == 0 + mock_output.assert_called_once() + + # Verify the result + result = 
mock_output.call_args[0][0] + assert result.failure_type == FailureType.UNKNOWN + assert result.confidence == 0.0 + assert len(result.failed_check_names) == 0 + + def test_handle_no_failed_checks_json_output(self): + """Test no failed checks handler with JSON output.""" + console = Console() + with ( + patch("aieng_bot._cli.commands.classify._output_result") as mock_output, + pytest.raises(SystemExit) as exc_info, + ): + _handle_no_failed_checks(console, json_output=True, output_file="out.json") + + assert exc_info.value.code == 0 + call_args = mock_output.call_args + assert call_args[0][2] is True # json_output=True + assert call_args[0][3] == "out.json" # output_file + + +class TestClassifyCommand: + """Test suite for classify CLI command.""" + + @patch.dict( + os.environ, + {"ANTHROPIC_API_KEY": "test-key", "GITHUB_TOKEN": "gh-token"}, + clear=True, + ) + @patch("aieng_bot._cli.commands.classify.PRFailureClassifier") + @patch("aieng_bot._cli.commands.classify.GitHubClient") + def test_classify_successful(self, mock_github_client_class, mock_classifier_class): + """Test successful classification.""" + # Setup mocks + mock_github = MagicMock() + mock_github.check_merge_conflicts.return_value = False + mock_github.get_pr_details.return_value = PRContext( + repo="VectorInstitute/test-repo", + pr_number=123, + pr_title="Update deps", + pr_author="app/dependabot", + base_ref="main", + head_ref="update-branch", + ) + mock_github.get_failed_checks.return_value = [ + CheckFailure( + name="test-check", + conclusion="FAILURE", + workflow_name="CI", + details_url="https://github.com/.../runs/123", + started_at="2025-01-01T00:00:00Z", + completed_at="2025-01-01T00:05:00Z", + ) + ] + mock_github.get_failure_logs.return_value = "/tmp/logs.txt" + mock_github_client_class.return_value = mock_github + + mock_classifier = MagicMock() + mock_classifier.classify.return_value = ClassificationResult( + failure_type=FailureType.TEST, + confidence=0.95, + reasoning="Tests failed", + 
failed_check_names=["test-check"], + recommended_action="Fix tests", + ) + mock_classifier_class.return_value = mock_classifier + + runner = CliRunner() + result = runner.invoke( + cli, ["classify", "--repo", "VectorInstitute/test-repo", "--pr", "123"] + ) + + assert result.exit_code == 0 + mock_github.check_merge_conflicts.assert_called_once_with( + "VectorInstitute/test-repo", 123 + ) + mock_github.get_pr_details.assert_called_once() + mock_github.get_failed_checks.assert_called_once() + mock_classifier.classify.assert_called_once() + + @patch.dict( + os.environ, + {"ANTHROPIC_API_KEY": "test-key", "GITHUB_TOKEN": "gh-token"}, + clear=True, + ) + @patch("aieng_bot._cli.commands.classify.GitHubClient") + def test_classify_with_merge_conflicts(self, mock_github_client_class): + """Test classification with merge conflicts.""" + mock_github = MagicMock() + mock_github.check_merge_conflicts.return_value = True + mock_github_client_class.return_value = mock_github + + runner = CliRunner() + result = runner.invoke( + cli, ["classify", "--repo", "VectorInstitute/test-repo", "--pr", "123"] + ) + + assert result.exit_code == 0 + mock_github.check_merge_conflicts.assert_called_once() + # Should not fetch PR details or run classification + mock_github.get_pr_details.assert_not_called() + + @patch.dict( + os.environ, + {"ANTHROPIC_API_KEY": "test-key", "GITHUB_TOKEN": "gh-token"}, + clear=True, + ) + @patch("aieng_bot._cli.commands.classify.GitHubClient") + def test_classify_with_no_failed_checks(self, mock_github_client_class): + """Test classification with no failed checks.""" + mock_github = MagicMock() + mock_github.check_merge_conflicts.return_value = False + mock_github.get_pr_details.return_value = PRContext( + repo="VectorInstitute/test-repo", + pr_number=123, + pr_title="Update deps", + pr_author="app/dependabot", + base_ref="main", + head_ref="update-branch", + ) + mock_github.get_failed_checks.return_value = [] + mock_github_client_class.return_value = mock_github 
+ + runner = CliRunner() + result = runner.invoke( + cli, ["classify", "--repo", "VectorInstitute/test-repo", "--pr", "123"] + ) + + assert result.exit_code == 0 + mock_github.get_failed_checks.assert_called_once() + + @patch.dict(os.environ, {}, clear=True) + def test_classify_missing_env_vars(self): + """Test classification with missing environment variables.""" + runner = CliRunner() + result = runner.invoke( + cli, ["classify", "--repo", "VectorInstitute/test-repo", "--pr", "123"] + ) + + assert result.exit_code == 1 + assert "Missing required environment variables" in result.output + + @patch.dict( + os.environ, + {"ANTHROPIC_API_KEY": "test-key", "GITHUB_TOKEN": "gh-token"}, + clear=True, + ) + @patch("aieng_bot._cli.commands.classify.PRFailureClassifier") + @patch("aieng_bot._cli.commands.classify.GitHubClient") + def test_classify_json_output( + self, mock_github_client_class, mock_classifier_class + ): + """Test classification with JSON output.""" + # Setup mocks + mock_github = MagicMock() + mock_github.check_merge_conflicts.return_value = False + mock_github.get_pr_details.return_value = PRContext( + repo="VectorInstitute/test-repo", + pr_number=123, + pr_title="Test PR", + pr_author="app/dependabot", + base_ref="main", + head_ref="test", + ) + mock_github.get_failed_checks.return_value = [ + CheckFailure( + name="check", + conclusion="FAILURE", + workflow_name="CI", + details_url="url", + started_at="time", + completed_at="time", + ) + ] + mock_github.get_failure_logs.return_value = "/tmp/logs.txt" + mock_github_client_class.return_value = mock_github + + mock_classifier = MagicMock() + mock_classifier.classify.return_value = ClassificationResult( + failure_type=FailureType.LINT, + confidence=0.85, + reasoning="Lint errors", + failed_check_names=["check"], + recommended_action="Fix", + ) + mock_classifier_class.return_value = mock_classifier + + runner = CliRunner() + result = runner.invoke( + cli, + [ + "classify", + "--repo", + 
"VectorInstitute/test-repo", + "--pr", + "123", + "--json", + ], + ) + + assert result.exit_code == 0 + # Output should be JSON + assert "lint" in result.output.lower() + + @patch.dict( + os.environ, + {"ANTHROPIC_API_KEY": "test-key", "GITHUB_TOKEN": "gh-token"}, + clear=True, + ) + @patch("aieng_bot._cli.commands.classify.PRFailureClassifier") + @patch("aieng_bot._cli.commands.classify.GitHubClient") + def test_classify_output_to_file( + self, mock_github_client_class, mock_classifier_class + ): + """Test classification with output to file.""" + # Setup mocks + mock_github = MagicMock() + mock_github.check_merge_conflicts.return_value = False + mock_github.get_pr_details.return_value = PRContext( + repo="VectorInstitute/test-repo", + pr_number=123, + pr_title="Test PR", + pr_author="app/dependabot", + base_ref="main", + head_ref="test", + ) + mock_github.get_failed_checks.return_value = [ + CheckFailure( + name="check", + conclusion="FAILURE", + workflow_name="CI", + details_url="url", + started_at="time", + completed_at="time", + ) + ] + mock_github.get_failure_logs.return_value = "/tmp/logs.txt" + mock_github_client_class.return_value = mock_github + + mock_classifier = MagicMock() + mock_classifier.classify.return_value = ClassificationResult( + failure_type=FailureType.BUILD, + confidence=0.9, + reasoning="Build failed", + failed_check_names=["check"], + recommended_action="Fix build", + ) + mock_classifier_class.return_value = mock_classifier + + with tempfile.NamedTemporaryFile( + mode="w", delete=False, suffix=".json" + ) as temp_file: + output_file = temp_file.name + + try: + runner = CliRunner() + result = runner.invoke( + cli, + [ + "classify", + "--repo", + "VectorInstitute/test-repo", + "--pr", + "123", + "--output", + output_file, + ], + ) + + assert result.exit_code == 0 + assert Path(output_file).exists() + + # Verify file contents + with open(output_file, "r") as f: + data = json.load(f) + assert data["failure_type"] == "build" + assert 
data["confidence"] == 0.9 + finally: + Path(output_file).unlink(missing_ok=True) + + @patch.dict( + os.environ, + {"ANTHROPIC_API_KEY": "test-key", "GITHUB_TOKEN": "gh-token"}, + clear=True, + ) + @patch("aieng_bot._cli.commands.classify.PRFailureClassifier") + @patch("aieng_bot._cli.commands.classify.GitHubClient") + def test_classify_unknown_exits_with_error( + self, mock_github_client_class, mock_classifier_class + ): + """Test that unknown classification exits with error code.""" + # Setup mocks + mock_github = MagicMock() + mock_github.check_merge_conflicts.return_value = False + mock_github.get_pr_details.return_value = PRContext( + repo="VectorInstitute/test-repo", + pr_number=123, + pr_title="Test PR", + pr_author="app/dependabot", + base_ref="main", + head_ref="test", + ) + mock_github.get_failed_checks.return_value = [ + CheckFailure( + name="check", + conclusion="FAILURE", + workflow_name="CI", + details_url="url", + started_at="time", + completed_at="time", + ) + ] + mock_github.get_failure_logs.return_value = "/tmp/logs.txt" + mock_github_client_class.return_value = mock_github + + mock_classifier = MagicMock() + mock_classifier.classify.return_value = ClassificationResult( + failure_type=FailureType.UNKNOWN, + confidence=0.3, + reasoning="Cannot classify", + failed_check_names=["check"], + recommended_action="Manual review", + ) + mock_classifier_class.return_value = mock_classifier + + runner = CliRunner() + result = runner.invoke( + cli, ["classify", "--repo", "VectorInstitute/test-repo", "--pr", "123"] + ) + + # Should exit with error code for unknown + assert result.exit_code == 1 + + @patch.dict( + os.environ, + {"ANTHROPIC_API_KEY": "test-key", "GITHUB_TOKEN": "gh-token"}, + clear=True, + ) + @patch("aieng_bot._cli.commands.classify.GitHubClient") + def test_classify_github_api_error(self, mock_github_client_class): + """Test classification with GitHub API error.""" + mock_github_client_class.side_effect = ValueError("GitHub API error") + + 
runner = CliRunner() + result = runner.invoke( + cli, ["classify", "--repo", "VectorInstitute/test-repo", "--pr", "123"] + ) + + assert result.exit_code == 1 + assert "GitHub API error" in result.output + + @patch.dict( + os.environ, + {"ANTHROPIC_API_KEY": "test-key", "GITHUB_TOKEN": "gh-token"}, + clear=True, + ) + @patch("aieng_bot._cli.commands.classify.PRFailureClassifier") + @patch("aieng_bot._cli.commands.classify.GitHubClient") + def test_classify_unexpected_error( + self, mock_github_client_class, mock_classifier_class + ): + """Test classification with unexpected error.""" + mock_github = MagicMock() + mock_github.check_merge_conflicts.side_effect = Exception("Unexpected error") + mock_github_client_class.return_value = mock_github + + runner = CliRunner() + result = runner.invoke( + cli, ["classify", "--repo", "VectorInstitute/test-repo", "--pr", "123"] + ) + + assert result.exit_code == 1 + assert "Unexpected error" in result.output + + @patch.dict( + os.environ, + {"ANTHROPIC_API_KEY": "test-key", "GITHUB_TOKEN": "gh-token"}, + clear=True, + ) + @patch("aieng_bot._cli.commands.classify.PRFailureClassifier") + @patch("aieng_bot._cli.commands.classify.GitHubClient") + @patch("os.unlink") + def test_classify_cleans_up_temp_file( + self, mock_unlink, mock_github_client_class, mock_classifier_class + ): + """Test that temporary log file is cleaned up.""" + # Setup mocks + mock_github = MagicMock() + mock_github.check_merge_conflicts.return_value = False + mock_github.get_pr_details.return_value = PRContext( + repo="VectorInstitute/test-repo", + pr_number=123, + pr_title="Test", + pr_author="app/dependabot", + base_ref="main", + head_ref="test", + ) + mock_github.get_failed_checks.return_value = [ + CheckFailure( + name="check", + conclusion="FAILURE", + workflow_name="CI", + details_url="url", + started_at="time", + completed_at="time", + ) + ] + mock_github.get_failure_logs.return_value = "/tmp/test-logs.txt" + mock_github_client_class.return_value = 
mock_github + + mock_classifier = MagicMock() + mock_classifier.classify.return_value = ClassificationResult( + failure_type=FailureType.TEST, + confidence=0.9, + reasoning="Test", + failed_check_names=["check"], + recommended_action="Fix", + ) + mock_classifier_class.return_value = mock_classifier + + runner = CliRunner() + result = runner.invoke( + cli, ["classify", "--repo", "VectorInstitute/test-repo", "--pr", "123"] + ) + + assert result.exit_code == 0 + # Verify temp file cleanup was attempted + mock_unlink.assert_called_with("/tmp/test-logs.txt") + + @patch.dict( + os.environ, + {"ANTHROPIC_API_KEY": "test-key", "GITHUB_TOKEN": "gh-token"}, + clear=True, + ) + def test_classify_with_explicit_tokens(self): + """Test classification with explicitly provided tokens.""" + with patch("aieng_bot._cli.commands.classify.GitHubClient") as mock_gh: + mock_github = MagicMock() + mock_github.check_merge_conflicts.return_value = True + mock_gh.return_value = mock_github + + runner = CliRunner() + result = runner.invoke( + cli, + [ + "classify", + "--repo", + "VectorInstitute/test-repo", + "--pr", + "123", + "--github-token", + "explicit-gh-token", + "--anthropic-api-key", + "explicit-api-key", + ], + ) + + # Should exit successfully with merge conflict detected + assert result.exit_code == 0 + # Should use explicit tokens + mock_gh.assert_called_once_with(github_token="explicit-gh-token") diff --git a/tests/cli/test_utils.py b/tests/cli/test_utils.py new file mode 100644 index 0000000..f2191e8 --- /dev/null +++ b/tests/cli/test_utils.py @@ -0,0 +1,533 @@ +"""Tests for CLI utility functions.""" + +import argparse +import json +import tempfile +from pathlib import Path +from unittest.mock import patch + +import pytest + +from aieng_bot._cli.utils import get_version, parse_pr_inputs, read_failure_logs +from aieng_bot.classifier.models import CheckFailure, PRContext + + +class TestGetVersion: + """Test suite for get_version function.""" + + def 
test_get_version_installed(self): + """Test get_version returns version string when package is installed.""" + with patch("aieng_bot._cli.utils.version") as mock_version: + mock_version.return_value = "1.2.3" + result = get_version() + assert result == "1.2.3" + mock_version.assert_called_once_with("aieng-bot") + + def test_get_version_not_installed(self): + """Test get_version returns 'unknown' when package is not installed.""" + from importlib.metadata import ( # noqa: PLC0415 + PackageNotFoundError, + ) + + with patch("aieng_bot._cli.utils.version") as mock_version: + mock_version.side_effect = PackageNotFoundError() + result = get_version() + assert result == "unknown" + + def test_get_version_with_dev_version(self): + """Test get_version with development version.""" + with patch("aieng_bot._cli.utils.version") as mock_version: + mock_version.return_value = "0.4.0.dev0+g1234567" + result = get_version() + assert result == "0.4.0.dev0+g1234567" + + def test_get_version_with_rc_version(self): + """Test get_version with release candidate version.""" + with patch("aieng_bot._cli.utils.version") as mock_version: + mock_version.return_value = "2.0.0rc1" + result = get_version() + assert result == "2.0.0rc1" + + def test_get_version_calls_correct_package(self): + """Test that get_version queries the correct package name.""" + with patch("aieng_bot._cli.utils.version") as mock_version: + mock_version.return_value = "1.0.0" + get_version() + # Verify it queries "aieng-bot" not "aieng_bot" + mock_version.assert_called_with("aieng-bot") + + +class TestReadFailureLogs: + """Test suite for read_failure_logs function.""" + + def test_read_failure_logs_from_file(self): + """Test reading failure logs from file.""" + with tempfile.NamedTemporaryFile( + mode="w", delete=False, suffix=".txt" + ) as temp_file: + temp_file.write("Error: Test failed\nStack trace here") + temp_file_path = temp_file.name + + try: + args = argparse.Namespace( + failure_logs_file=temp_file_path, 
failure_logs=None + ) + result = read_failure_logs(args) + + assert result == "Error: Test failed\nStack trace here" + finally: + Path(temp_file_path).unlink(missing_ok=True) + + def test_read_failure_logs_from_argument(self): + """Test reading failure logs from command-line argument.""" + args = argparse.Namespace( + failure_logs_file=None, failure_logs="Error from argument" + ) + result = read_failure_logs(args) + + assert result == "Error from argument" + + def test_read_failure_logs_prefers_file_over_argument(self): + """Test that file is preferred when both file and argument are provided.""" + with tempfile.NamedTemporaryFile( + mode="w", delete=False, suffix=".txt" + ) as temp_file: + temp_file.write("Error from file") + temp_file_path = temp_file.name + + try: + args = argparse.Namespace( + failure_logs_file=temp_file_path, failure_logs="Error from argument" + ) + result = read_failure_logs(args) + + # Should prefer file over argument + assert result == "Error from file" + finally: + Path(temp_file_path).unlink(missing_ok=True) + + def test_read_failure_logs_file_not_found(self): + """Test reading failure logs when file doesn't exist.""" + args = argparse.Namespace( + failure_logs_file="/nonexistent/file.txt", failure_logs=None + ) + result = read_failure_logs(args) + + assert result == "" + + def test_read_failure_logs_neither_provided(self): + """Test reading failure logs when neither file nor argument is provided.""" + args = argparse.Namespace(failure_logs_file=None, failure_logs=None) + result = read_failure_logs(args) + + assert result == "" + + def test_read_failure_logs_empty_file(self): + """Test reading failure logs from empty file.""" + with tempfile.NamedTemporaryFile( + mode="w", delete=False, suffix=".txt" + ) as temp_file: + temp_file_path = temp_file.name + # Write nothing - file is empty + + try: + args = argparse.Namespace( + failure_logs_file=temp_file_path, failure_logs=None + ) + result = read_failure_logs(args) + + assert result == "" + 
finally: + Path(temp_file_path).unlink(missing_ok=True) + + def test_read_failure_logs_large_file(self): + """Test reading failure logs from large file.""" + large_content = "Error line\n" * 10000 + with tempfile.NamedTemporaryFile( + mode="w", delete=False, suffix=".txt" + ) as temp_file: + temp_file.write(large_content) + temp_file_path = temp_file.name + + try: + args = argparse.Namespace( + failure_logs_file=temp_file_path, failure_logs=None + ) + result = read_failure_logs(args) + + assert result == large_content + assert len(result) == len(large_content) + finally: + Path(temp_file_path).unlink(missing_ok=True) + + def test_read_failure_logs_with_special_characters(self): + """Test reading failure logs with special characters and unicode.""" + special_content = ( + "Error: Test failed 💥\n" + "Stack trace: \n" + " at función() línea 42\n" + " Encoding test: 中文 日本語 한국어" + ) + with tempfile.NamedTemporaryFile( + mode="w", delete=False, suffix=".txt", encoding="utf-8" + ) as temp_file: + temp_file.write(special_content) + temp_file_path = temp_file.name + + try: + args = argparse.Namespace( + failure_logs_file=temp_file_path, failure_logs=None + ) + result = read_failure_logs(args) + + assert result == special_content + finally: + Path(temp_file_path).unlink(missing_ok=True) + + def test_read_failure_logs_empty_argument(self): + """Test reading failure logs from empty string argument.""" + args = argparse.Namespace(failure_logs_file=None, failure_logs="") + result = read_failure_logs(args) + + assert result == "" + + def test_read_failure_logs_whitespace_argument(self): + """Test reading failure logs from whitespace-only argument.""" + args = argparse.Namespace(failure_logs_file=None, failure_logs=" \n\t ") + result = read_failure_logs(args) + + assert result == " \n\t " + + +class TestParsePrInputs: + """Test suite for parse_pr_inputs function.""" + + def test_parse_pr_inputs_success(self): + """Test successful parsing of PR inputs.""" + pr_info = json.dumps( + 
{ + "repo": "VectorInstitute/test-repo", + "pr_number": 123, + "pr_title": "Update dependencies", + "pr_author": "app/dependabot", + "base_ref": "main", + "head_ref": "dependabot/npm/package-1.0.0", + } + ) + + failed_checks = json.dumps( + [ + { + "name": "test-check", + "conclusion": "FAILURE", + "workflowName": "CI Tests", + "detailsUrl": "https://github.com/.../runs/123/job/456", + "startedAt": "2025-01-01T00:00:00Z", + "completedAt": "2025-01-01T00:05:00Z", + }, + { + "name": "lint-check", + "conclusion": "FAILURE", + "workflowName": "Linting", + "detailsUrl": "https://github.com/.../runs/124/job/457", + "startedAt": "2025-01-01T00:00:00Z", + "completedAt": "2025-01-01T00:03:00Z", + }, + ] + ) + + args = argparse.Namespace(pr_info=pr_info, failed_checks=failed_checks) + pr_context, checks = parse_pr_inputs(args) + + # Verify PR context + assert isinstance(pr_context, PRContext) + assert pr_context.repo == "VectorInstitute/test-repo" + assert pr_context.pr_number == 123 + assert pr_context.pr_title == "Update dependencies" + assert pr_context.pr_author == "app/dependabot" + assert pr_context.base_ref == "main" + assert pr_context.head_ref == "dependabot/npm/package-1.0.0" + + # Verify failed checks + assert len(checks) == 2 + assert all(isinstance(check, CheckFailure) for check in checks) + + assert checks[0].name == "test-check" + assert checks[0].conclusion == "FAILURE" + assert checks[0].workflow_name == "CI Tests" + assert "runs/123" in checks[0].details_url + + assert checks[1].name == "lint-check" + assert checks[1].conclusion == "FAILURE" + + def test_parse_pr_inputs_minimal_checks(self): + """Test parsing PR inputs with minimal check information.""" + pr_info = json.dumps( + { + "repo": "VectorInstitute/test-repo", + "pr_number": 123, + "pr_title": "Test PR", + "pr_author": "app/dependabot", + "base_ref": "main", + "head_ref": "test-branch", + } + ) + + # Checks with only required fields + failed_checks = json.dumps( + [ + { + "name": "test-check", + 
"conclusion": "FAILURE", + # Missing optional fields + } + ] + ) + + args = argparse.Namespace(pr_info=pr_info, failed_checks=failed_checks) + pr_context, checks = parse_pr_inputs(args) + + assert len(checks) == 1 + assert checks[0].name == "test-check" + assert checks[0].conclusion == "FAILURE" + assert checks[0].workflow_name == "" + assert checks[0].details_url == "" + assert checks[0].started_at == "" + assert checks[0].completed_at == "" + + def test_parse_pr_inputs_empty_checks_list(self): + """Test parsing PR inputs with empty checks list.""" + pr_info = json.dumps( + { + "repo": "VectorInstitute/test-repo", + "pr_number": 123, + "pr_title": "Test PR", + "pr_author": "app/dependabot", + "base_ref": "main", + "head_ref": "test-branch", + } + ) + + failed_checks = json.dumps([]) + + args = argparse.Namespace(pr_info=pr_info, failed_checks=failed_checks) + pr_context, checks = parse_pr_inputs(args) + + assert isinstance(pr_context, PRContext) + assert len(checks) == 0 + + def test_parse_pr_inputs_invalid_json_pr_info(self): + """Test parsing PR inputs with invalid JSON in pr_info.""" + args = argparse.Namespace( + pr_info="not valid json", failed_checks=json.dumps([]) + ) + + with pytest.raises(json.JSONDecodeError): + parse_pr_inputs(args) + + def test_parse_pr_inputs_invalid_json_failed_checks(self): + """Test parsing PR inputs with invalid JSON in failed_checks.""" + pr_info = json.dumps( + { + "repo": "VectorInstitute/test-repo", + "pr_number": 123, + "pr_title": "Test PR", + "pr_author": "app/dependabot", + "base_ref": "main", + "head_ref": "test-branch", + } + ) + + args = argparse.Namespace(pr_info=pr_info, failed_checks="not valid json") + + with pytest.raises(json.JSONDecodeError): + parse_pr_inputs(args) + + def test_parse_pr_inputs_missing_pr_fields(self): + """Test parsing PR inputs with missing required PR fields.""" + pr_info = json.dumps( + { + "repo": "VectorInstitute/test-repo", + "pr_number": 123, + # Missing required fields + } + ) + + 
failed_checks = json.dumps([]) + + args = argparse.Namespace(pr_info=pr_info, failed_checks=failed_checks) + + with pytest.raises(KeyError): + parse_pr_inputs(args) + + def test_parse_pr_inputs_pr_number_as_string(self): + """Test parsing PR inputs when pr_number is a string.""" + pr_info = json.dumps( + { + "repo": "VectorInstitute/test-repo", + "pr_number": "456", # String instead of int + "pr_title": "Test PR", + "pr_author": "app/dependabot", + "base_ref": "main", + "head_ref": "test-branch", + } + ) + + failed_checks = json.dumps([]) + + args = argparse.Namespace(pr_info=pr_info, failed_checks=failed_checks) + pr_context, checks = parse_pr_inputs(args) + + # Should convert string to int + assert pr_context.pr_number == 456 + assert isinstance(pr_context.pr_number, int) + + def test_parse_pr_inputs_multiple_checks_same_run(self): + """Test parsing multiple checks from the same workflow run.""" + pr_info = json.dumps( + { + "repo": "VectorInstitute/test-repo", + "pr_number": 123, + "pr_title": "Test PR", + "pr_author": "app/dependabot", + "base_ref": "main", + "head_ref": "test-branch", + } + ) + + failed_checks = json.dumps( + [ + { + "name": "test-job-1", + "conclusion": "FAILURE", + "workflowName": "CI Tests", + "detailsUrl": "https://github.com/.../runs/123/job/456", + "startedAt": "2025-01-01T00:00:00Z", + "completedAt": "2025-01-01T00:05:00Z", + }, + { + "name": "test-job-2", + "conclusion": "FAILURE", + "workflowName": "CI Tests", + "detailsUrl": "https://github.com/.../runs/123/job/457", + "startedAt": "2025-01-01T00:00:00Z", + "completedAt": "2025-01-01T00:05:00Z", + }, + ] + ) + + args = argparse.Namespace(pr_info=pr_info, failed_checks=failed_checks) + pr_context, checks = parse_pr_inputs(args) + + assert len(checks) == 2 + assert checks[0].name == "test-job-1" + assert checks[1].name == "test-job-2" + # Both from same run + assert "runs/123" in checks[0].details_url + assert "runs/123" in checks[1].details_url + + def 
test_parse_pr_inputs_with_special_characters(self): + """Test parsing PR inputs with special characters in strings.""" + pr_info = json.dumps( + { + "repo": "VectorInstitute/test-repo", + "pr_number": 123, + "pr_title": "Fix bug in función() with emoji 🐛", + "pr_author": "app/dependabot", + "base_ref": "main", + "head_ref": "fix/bug-with-特殊字符", + } + ) + + failed_checks = json.dumps([]) + + args = argparse.Namespace(pr_info=pr_info, failed_checks=failed_checks) + pr_context, checks = parse_pr_inputs(args) + + assert "emoji 🐛" in pr_context.pr_title + assert "特殊字符" in pr_context.head_ref + + def test_parse_pr_inputs_pre_commit_ci_author(self): + """Test parsing PR inputs with pre-commit-ci bot author.""" + pr_info = json.dumps( + { + "repo": "VectorInstitute/test-repo", + "pr_number": 789, + "pr_title": "[pre-commit.ci] auto fixes", + "pr_author": "app/pre-commit-ci", + "base_ref": "main", + "head_ref": "pre-commit-ci-update", + } + ) + + failed_checks = json.dumps([]) + + args = argparse.Namespace(pr_info=pr_info, failed_checks=failed_checks) + pr_context, checks = parse_pr_inputs(args) + + assert pr_context.pr_author == "app/pre-commit-ci" + assert "[pre-commit.ci]" in pr_context.pr_title + + def test_parse_pr_inputs_check_with_null_fields(self): + """Test parsing checks with null values in optional fields.""" + pr_info = json.dumps( + { + "repo": "VectorInstitute/test-repo", + "pr_number": 123, + "pr_title": "Test PR", + "pr_author": "app/dependabot", + "base_ref": "main", + "head_ref": "test-branch", + } + ) + + failed_checks = json.dumps( + [ + { + "name": "test-check", + "conclusion": "FAILURE", + "workflowName": None, + "detailsUrl": None, + "startedAt": None, + "completedAt": None, + } + ] + ) + + args = argparse.Namespace(pr_info=pr_info, failed_checks=failed_checks) + pr_context, checks = parse_pr_inputs(args) + + # Explicit nulls in optional fields must parse without raising + assert len(checks) == 1 + # Null optional fields may come back as None or as empty strings; both are accepted +
assert checks[0].workflow_name in (None, "") + assert checks[0].details_url in (None, "") + + def test_parse_pr_inputs_preserves_check_order(self): + """Test that parse_pr_inputs preserves the order of checks.""" + pr_info = json.dumps( + { + "repo": "VectorInstitute/test-repo", + "pr_number": 123, + "pr_title": "Test PR", + "pr_author": "app/dependabot", + "base_ref": "main", + "head_ref": "test-branch", + } + ) + + failed_checks = json.dumps( + [ + {"name": "check-3", "conclusion": "FAILURE"}, + {"name": "check-1", "conclusion": "FAILURE"}, + {"name": "check-2", "conclusion": "FAILURE"}, + ] + ) + + args = argparse.Namespace(pr_info=pr_info, failed_checks=failed_checks) + pr_context, checks = parse_pr_inputs(args) + + # Order should be preserved + assert checks[0].name == "check-3" + assert checks[1].name == "check-1" + assert checks[2].name == "check-2" diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 0000000..69e211f --- /dev/null +++ b/tests/utils/__init__.py @@ -0,0 +1 @@ +"""Tests for utilities module.""" diff --git a/tests/utils/test_github_client.py b/tests/utils/test_github_client.py new file mode 100644 index 0000000..bc9bc3a --- /dev/null +++ b/tests/utils/test_github_client.py @@ -0,0 +1,765 @@ +"""Tests for GitHub API client.""" + +import json +import os +import subprocess +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from aieng_bot.classifier.models import CheckFailure, PRContext +from aieng_bot.utils.github_client import GitHubClient + + +class TestGitHubClientInit: + """Test suite for GitHubClient initialization.""" + + def test_init_with_explicit_token(self): + """Test initialization with explicit token.""" + client = GitHubClient(github_token="test-token-123") + assert client.github_token == "test-token-123" + + @patch.dict(os.environ, {"GITHUB_TOKEN": "env-token-456"}, clear=True) + def test_init_with_github_token_env(self): + """Test 
initialization with GITHUB_TOKEN from environment.""" + client = GitHubClient() + assert client.github_token == "env-token-456" + + @patch.dict(os.environ, {"GH_TOKEN": "gh-token-789"}, clear=True) + def test_init_with_gh_token_env(self): + """Test initialization with GH_TOKEN from environment.""" + client = GitHubClient() + assert client.github_token == "gh-token-789" + + @patch.dict( + os.environ, {"GITHUB_TOKEN": "github-token", "GH_TOKEN": "gh-token"}, clear=True + ) + def test_init_prefers_github_token_env(self): + """Test that GITHUB_TOKEN is preferred over GH_TOKEN.""" + client = GitHubClient() + assert client.github_token == "github-token" + + @patch.dict(os.environ, {}, clear=True) + def test_init_without_token_raises_error(self): + """Test that initialization fails without token.""" + with pytest.raises( + ValueError, + match="GitHub token not found. Please set GITHUB_TOKEN or GH_TOKEN", + ): + GitHubClient() + + def test_explicit_token_overrides_env(self): + """Test that explicit token overrides environment variables.""" + with patch.dict(os.environ, {"GITHUB_TOKEN": "env-token"}, clear=True): + client = GitHubClient(github_token="explicit-token") + assert client.github_token == "explicit-token" + + +class TestGitHubClientRunGhCommand: + """Test suite for _run_gh_command method.""" + + @patch("subprocess.run") + def test_run_gh_command_success(self, mock_run): + """Test successful gh command execution.""" + mock_result = MagicMock() + mock_result.stdout = "success output" + mock_result.stderr = "" + mock_result.returncode = 0 + mock_run.return_value = mock_result + + client = GitHubClient(github_token="test-token") + result = client._run_gh_command(["api", "repos/VectorInstitute/test"]) + + assert result.stdout == "success output" + assert result.returncode == 0 + mock_run.assert_called_once() + + # Verify gh CLI is called with correct arguments + call_args = mock_run.call_args + assert call_args[0][0] == ["gh", "api", "repos/VectorInstitute/test"] + + # 
Verify environment contains token + assert call_args[1]["env"]["GH_TOKEN"] == "test-token" + + @patch("subprocess.run") + def test_run_gh_command_with_check_false(self, mock_run): + """Test gh command execution with check=False.""" + mock_result = MagicMock() + mock_result.stdout = "" + mock_result.stderr = "error" + mock_result.returncode = 1 + mock_run.return_value = mock_result + + client = GitHubClient(github_token="test-token") + result = client._run_gh_command(["api", "test"], check=False) + + # Should not raise even with non-zero exit code + assert result.returncode == 1 + assert result.stderr == "error" + + @patch("subprocess.run") + def test_run_gh_command_timeout(self, mock_run): + """Test gh command timeout.""" + mock_run.side_effect = subprocess.TimeoutExpired(cmd="gh api test", timeout=120) + + client = GitHubClient(github_token="test-token") + with pytest.raises(subprocess.TimeoutExpired): + client._run_gh_command(["api", "test"]) + + @patch("subprocess.run") + def test_run_gh_command_not_found(self, mock_run): + """Test gh CLI not installed.""" + mock_run.side_effect = FileNotFoundError("gh: command not found") + + client = GitHubClient(github_token="test-token") + with pytest.raises(FileNotFoundError): + client._run_gh_command(["api", "test"]) + + @patch("subprocess.run") + def test_run_gh_command_timeout_value(self, mock_run): + """Test that timeout is set to 120 seconds.""" + mock_result = MagicMock() + mock_result.returncode = 0 + mock_run.return_value = mock_result + + client = GitHubClient(github_token="test-token") + client._run_gh_command(["api", "test"]) + + # Verify timeout is set + call_kwargs = mock_run.call_args[1] + assert call_kwargs["timeout"] == 120 + + @patch("subprocess.run") + def test_run_gh_command_preserves_environment(self, mock_run): + """Test that existing environment variables are preserved.""" + mock_result = MagicMock() + mock_result.returncode = 0 + mock_run.return_value = mock_result + + with patch.dict(os.environ, 
{"PATH": "/usr/bin", "HOME": "/home/user"}): + client = GitHubClient(github_token="test-token") + client._run_gh_command(["api", "test"]) + + call_kwargs = mock_run.call_args[1] + assert "PATH" in call_kwargs["env"] + assert "HOME" in call_kwargs["env"] + assert call_kwargs["env"]["GH_TOKEN"] == "test-token" + + +class TestGitHubClientGetPRDetails: + """Test suite for get_pr_details method.""" + + @patch.object(GitHubClient, "_run_gh_command") + def test_get_pr_details_success(self, mock_run): + """Test successful PR details retrieval.""" + mock_result = MagicMock() + mock_result.stdout = json.dumps( + { + "title": "Update dependencies", + "author": {"login": "app/dependabot"}, + "headRefName": "dependabot/npm/package-1.0.0", + "baseRefName": "main", + "mergeable": "MERGEABLE", + } + ) + mock_run.return_value = mock_result + + client = GitHubClient(github_token="test-token") + pr_context = client.get_pr_details("VectorInstitute/test-repo", 123) + + assert isinstance(pr_context, PRContext) + assert pr_context.repo == "VectorInstitute/test-repo" + assert pr_context.pr_number == 123 + assert pr_context.pr_title == "Update dependencies" + assert pr_context.pr_author == "app/dependabot" + assert pr_context.head_ref == "dependabot/npm/package-1.0.0" + assert pr_context.base_ref == "main" + + # Verify gh CLI was called with correct arguments + mock_run.assert_called_once() + call_args = mock_run.call_args[0][0] + assert "pr" in call_args + assert "view" in call_args + assert "123" in call_args + assert "--repo" in call_args + assert "VectorInstitute/test-repo" in call_args + + @patch.object(GitHubClient, "_run_gh_command") + def test_get_pr_details_invalid_json(self, mock_run): + """Test PR details with invalid JSON response.""" + mock_result = MagicMock() + mock_result.stdout = "not valid json" + mock_run.return_value = mock_result + + client = GitHubClient(github_token="test-token") + with pytest.raises(json.JSONDecodeError): + 
client.get_pr_details("VectorInstitute/test-repo", 123) + + @patch.object(GitHubClient, "_run_gh_command") + def test_get_pr_details_missing_fields(self, mock_run): + """Test PR details with missing required fields.""" + mock_result = MagicMock() + mock_result.stdout = json.dumps( + { + "title": "Update dependencies", + # Missing author, headRefName, etc. + } + ) + mock_run.return_value = mock_result + + client = GitHubClient(github_token="test-token") + with pytest.raises(KeyError): + client.get_pr_details("VectorInstitute/test-repo", 123) + + @patch.object(GitHubClient, "_run_gh_command") + def test_get_pr_details_api_error(self, mock_run): + """Test PR details with API error.""" + mock_run.side_effect = subprocess.CalledProcessError( + 1, "gh", stderr="API rate limit exceeded" + ) + + client = GitHubClient(github_token="test-token") + with pytest.raises(subprocess.CalledProcessError): + client.get_pr_details("VectorInstitute/test-repo", 123) + + @patch.object(GitHubClient, "_run_gh_command") + def test_get_pr_details_with_pre_commit_ci_bot(self, mock_run): + """Test PR details with pre-commit-ci bot author.""" + mock_result = MagicMock() + mock_result.stdout = json.dumps( + { + "title": "[pre-commit.ci] auto fixes", + "author": {"login": "app/pre-commit-ci"}, + "headRefName": "pre-commit-ci-update", + "baseRefName": "main", + "mergeable": "MERGEABLE", + } + ) + mock_run.return_value = mock_result + + client = GitHubClient(github_token="test-token") + pr_context = client.get_pr_details("VectorInstitute/test-repo", 456) + + assert pr_context.pr_author == "app/pre-commit-ci" + assert pr_context.pr_title == "[pre-commit.ci] auto fixes" + + +class TestGitHubClientGetFailedChecks: + """Test suite for get_failed_checks method.""" + + @patch.object(GitHubClient, "_run_gh_command") + def test_get_failed_checks_success(self, mock_run): + """Test successful failed checks retrieval.""" + mock_result = MagicMock() + mock_result.stdout = json.dumps( + { + "statusCheckRollup": 
[ + { + "name": "test-check", + "conclusion": "FAILURE", + "workflowName": "CI Tests", + "detailsUrl": "https://github.com/.../runs/123/job/456", + "startedAt": "2025-01-01T00:00:00Z", + "completedAt": "2025-01-01T00:05:00Z", + }, + { + "name": "lint-check", + "conclusion": "SUCCESS", + "workflowName": "Linting", + "detailsUrl": "https://github.com/.../runs/124/job/457", + "startedAt": "2025-01-01T00:00:00Z", + "completedAt": "2025-01-01T00:03:00Z", + }, + { + "name": "security-check", + "conclusion": "FAILURE", + "workflowName": "Security Audit", + "detailsUrl": "https://github.com/.../runs/125/job/458", + "startedAt": "2025-01-01T00:00:00Z", + "completedAt": "2025-01-01T00:10:00Z", + }, + ] + } + ) + mock_run.return_value = mock_result + + client = GitHubClient(github_token="test-token") + failed_checks = client.get_failed_checks("VectorInstitute/test-repo", 123) + + assert len(failed_checks) == 2 + assert all(isinstance(check, CheckFailure) for check in failed_checks) + + # Check first failure + assert failed_checks[0].name == "test-check" + assert failed_checks[0].conclusion == "FAILURE" + assert failed_checks[0].workflow_name == "CI Tests" + assert "runs/123" in failed_checks[0].details_url + + # Check second failure + assert failed_checks[1].name == "security-check" + assert failed_checks[1].conclusion == "FAILURE" + + @patch.object(GitHubClient, "_run_gh_command") + def test_get_failed_checks_no_failures(self, mock_run): + """Test get_failed_checks when all checks pass.""" + mock_result = MagicMock() + mock_result.stdout = json.dumps( + { + "statusCheckRollup": [ + { + "name": "test-check", + "conclusion": "SUCCESS", + "workflowName": "CI Tests", + "detailsUrl": "https://github.com/.../runs/123/job/456", + "startedAt": "2025-01-01T00:00:00Z", + "completedAt": "2025-01-01T00:05:00Z", + } + ] + } + ) + mock_run.return_value = mock_result + + client = GitHubClient(github_token="test-token") + failed_checks = client.get_failed_checks("VectorInstitute/test-repo", 
123) + + assert len(failed_checks) == 0 + + @patch.object(GitHubClient, "_run_gh_command") + def test_get_failed_checks_empty_rollup(self, mock_run): + """Test get_failed_checks with empty status check rollup.""" + mock_result = MagicMock() + mock_result.stdout = json.dumps({"statusCheckRollup": []}) + mock_run.return_value = mock_result + + client = GitHubClient(github_token="test-token") + failed_checks = client.get_failed_checks("VectorInstitute/test-repo", 123) + + assert len(failed_checks) == 0 + + @patch.object(GitHubClient, "_run_gh_command") + def test_get_failed_checks_no_rollup(self, mock_run): + """Test get_failed_checks when statusCheckRollup is missing.""" + mock_result = MagicMock() + mock_result.stdout = json.dumps({}) + mock_run.return_value = mock_result + + client = GitHubClient(github_token="test-token") + failed_checks = client.get_failed_checks("VectorInstitute/test-repo", 123) + + assert len(failed_checks) == 0 + + @patch.object(GitHubClient, "_run_gh_command") + def test_get_failed_checks_missing_optional_fields(self, mock_run): + """Test get_failed_checks with missing optional fields.""" + mock_result = MagicMock() + mock_result.stdout = json.dumps( + { + "statusCheckRollup": [ + { + "conclusion": "FAILURE", + # Missing name, workflowName, detailsUrl, etc. 
+ } + ] + } + ) + mock_run.return_value = mock_result + + client = GitHubClient(github_token="test-token") + failed_checks = client.get_failed_checks("VectorInstitute/test-repo", 123) + + assert len(failed_checks) == 1 + assert failed_checks[0].name == "" + assert failed_checks[0].workflow_name == "" + assert failed_checks[0].details_url == "" + + +class TestGitHubClientGetFailureLogs: + """Test suite for get_failure_logs method.""" + + @patch.object(GitHubClient, "_run_gh_command") + def test_get_failure_logs_success(self, mock_run): + """Test successful failure logs extraction.""" + mock_result = MagicMock() + mock_result.returncode = 0 + mock_result.stdout = "Error: Test failed\nAssertion error at line 42" + mock_run.return_value = mock_result + + failed_checks = [ + CheckFailure( + name="test-check", + conclusion="FAILURE", + workflow_name="CI Tests", + details_url="https://github.com/VectorInstitute/test-repo/actions/runs/123/job/456", + started_at="2025-01-01T00:00:00Z", + completed_at="2025-01-01T00:05:00Z", + ) + ] + + client = GitHubClient(github_token="test-token") + logs_file = client.get_failure_logs("VectorInstitute/test-repo", failed_checks) + + try: + assert Path(logs_file).exists() + logs_content = Path(logs_file).read_text() + assert "test-check" in logs_content + assert "Error: Test failed" in logs_content + assert "run 123" in logs_content + finally: + Path(logs_file).unlink(missing_ok=True) + + @patch.object(GitHubClient, "_run_gh_command") + def test_get_failure_logs_multiple_checks(self, mock_run): + """Test failure logs extraction from multiple checks.""" + mock_result1 = MagicMock() + mock_result1.returncode = 0 + mock_result1.stdout = "Error: Test failed" + + mock_result2 = MagicMock() + mock_result2.returncode = 0 + mock_result2.stdout = "Error: Lint failed" + + mock_run.side_effect = [mock_result1, mock_result2] + + failed_checks = [ + CheckFailure( + name="test-check", + conclusion="FAILURE", + workflow_name="CI Tests", + 
details_url="https://github.com/VectorInstitute/test-repo/actions/runs/123/job/456", + started_at="2025-01-01T00:00:00Z", + completed_at="2025-01-01T00:05:00Z", + ), + CheckFailure( + name="lint-check", + conclusion="FAILURE", + workflow_name="Linting", + details_url="https://github.com/VectorInstitute/test-repo/actions/runs/124/job/457", + started_at="2025-01-01T00:00:00Z", + completed_at="2025-01-01T00:03:00Z", + ), + ] + + client = GitHubClient(github_token="test-token") + logs_file = client.get_failure_logs("VectorInstitute/test-repo", failed_checks) + + try: + logs_content = Path(logs_file).read_text() + assert "test-check" in logs_content + assert "lint-check" in logs_content + assert "Error: Test failed" in logs_content + assert "Error: Lint failed" in logs_content + finally: + Path(logs_file).unlink(missing_ok=True) + + @patch.object(GitHubClient, "_run_gh_command") + def test_get_failure_logs_deduplicates_runs(self, mock_run): + """Test that same run ID is only fetched once.""" + mock_result = MagicMock() + mock_result.returncode = 0 + mock_result.stdout = "Error: Multiple jobs failed" + mock_run.return_value = mock_result + + # Multiple checks from the same run (different jobs) + failed_checks = [ + CheckFailure( + name="test-job-1", + conclusion="FAILURE", + workflow_name="CI Tests", + details_url="https://github.com/VectorInstitute/test-repo/actions/runs/123/job/456", + started_at="2025-01-01T00:00:00Z", + completed_at="2025-01-01T00:05:00Z", + ), + CheckFailure( + name="test-job-2", + conclusion="FAILURE", + workflow_name="CI Tests", + details_url="https://github.com/VectorInstitute/test-repo/actions/runs/123/job/457", + started_at="2025-01-01T00:00:00Z", + completed_at="2025-01-01T00:05:00Z", + ), + ] + + client = GitHubClient(github_token="test-token") + logs_file = client.get_failure_logs("VectorInstitute/test-repo", failed_checks) + + try: + # Should only call gh once (deduplication) + assert mock_run.call_count == 1 + logs_content = 
Path(logs_file).read_text() + # Only one set of logs should be written + assert logs_content.count("Error: Multiple jobs failed") == 1 + finally: + Path(logs_file).unlink(missing_ok=True) + + @patch.object(GitHubClient, "_run_gh_command") + def test_get_failure_logs_no_details_url(self, mock_run): + """Test failure logs when check has no details URL.""" + failed_checks = [ + CheckFailure( + name="test-check", + conclusion="FAILURE", + workflow_name="CI Tests", + details_url="", # Empty URL + started_at="2025-01-01T00:00:00Z", + completed_at="2025-01-01T00:05:00Z", + ) + ] + + client = GitHubClient(github_token="test-token") + logs_file = client.get_failure_logs("VectorInstitute/test-repo", failed_checks) + + try: + logs_content = Path(logs_file).read_text() + assert "No failure logs could be extracted" in logs_content + # Should not call gh CLI + mock_run.assert_not_called() + finally: + Path(logs_file).unlink(missing_ok=True) + + @patch.object(GitHubClient, "_run_gh_command") + def test_get_failure_logs_invalid_url_format(self, mock_run): + """Test failure logs with invalid URL format.""" + failed_checks = [ + CheckFailure( + name="test-check", + conclusion="FAILURE", + workflow_name="CI Tests", + details_url="https://github.com/invalid/url", # No run ID + started_at="2025-01-01T00:00:00Z", + completed_at="2025-01-01T00:05:00Z", + ) + ] + + client = GitHubClient(github_token="test-token") + logs_file = client.get_failure_logs("VectorInstitute/test-repo", failed_checks) + + try: + logs_content = Path(logs_file).read_text() + assert "No failure logs could be extracted" in logs_content + mock_run.assert_not_called() + finally: + Path(logs_file).unlink(missing_ok=True) + + @patch.object(GitHubClient, "_run_gh_command") + def test_get_failure_logs_api_error(self, mock_run): + """Test failure logs when API call fails.""" + mock_result = MagicMock() + mock_result.returncode = 1 + mock_result.stderr = "API rate limit exceeded" + mock_run.return_value = mock_result + + 
failed_checks = [ + CheckFailure( + name="test-check", + conclusion="FAILURE", + workflow_name="CI Tests", + details_url="https://github.com/VectorInstitute/test-repo/actions/runs/123/job/456", + started_at="2025-01-01T00:00:00Z", + completed_at="2025-01-01T00:05:00Z", + ) + ] + + client = GitHubClient(github_token="test-token") + logs_file = client.get_failure_logs("VectorInstitute/test-repo", failed_checks) + + try: + logs_content = Path(logs_file).read_text() + assert "No failure logs could be extracted" in logs_content + finally: + Path(logs_file).unlink(missing_ok=True) + + @patch.object(GitHubClient, "_run_gh_command") + def test_get_failure_logs_handles_exceptions(self, mock_run): + """Test failure logs handles unexpected exceptions.""" + mock_run.side_effect = Exception("Unexpected error") + + failed_checks = [ + CheckFailure( + name="test-check", + conclusion="FAILURE", + workflow_name="CI Tests", + details_url="https://github.com/VectorInstitute/test-repo/actions/runs/123/job/456", + started_at="2025-01-01T00:00:00Z", + completed_at="2025-01-01T00:05:00Z", + ) + ] + + client = GitHubClient(github_token="test-token") + logs_file = client.get_failure_logs("VectorInstitute/test-repo", failed_checks) + + try: + logs_content = Path(logs_file).read_text() + assert "No failure logs could be extracted" in logs_content + finally: + Path(logs_file).unlink(missing_ok=True) + + @patch.object(GitHubClient, "_run_gh_command") + def test_get_failure_logs_empty_checks_list(self, mock_run): + """Test failure logs with empty checks list.""" + client = GitHubClient(github_token="test-token") + logs_file = client.get_failure_logs("VectorInstitute/test-repo", []) + + try: + logs_content = Path(logs_file).read_text() + assert "No failure logs could be extracted" in logs_content + mock_run.assert_not_called() + finally: + Path(logs_file).unlink(missing_ok=True) + + @patch.object(GitHubClient, "_run_gh_command") + def test_get_failure_logs_creates_temp_file(self, mock_run): + 
"""Test that failure logs creates a temporary file.""" + mock_result = MagicMock() + mock_result.returncode = 0 + mock_result.stdout = "Error logs" + mock_run.return_value = mock_result + + failed_checks = [ + CheckFailure( + name="test-check", + conclusion="FAILURE", + workflow_name="CI Tests", + details_url="https://github.com/VectorInstitute/test-repo/actions/runs/123/job/456", + started_at="2025-01-01T00:00:00Z", + completed_at="2025-01-01T00:05:00Z", + ) + ] + + client = GitHubClient(github_token="test-token") + logs_file = client.get_failure_logs("VectorInstitute/test-repo", failed_checks) + + try: + assert logs_file.startswith(tempfile.gettempdir()) + assert "failure-logs-" in logs_file + assert logs_file.endswith(".txt") + finally: + Path(logs_file).unlink(missing_ok=True) + + @patch.object(GitHubClient, "_run_gh_command") + def test_get_failure_logs_large_output(self, mock_run): + """Test failure logs with large output.""" + # Generate large log output (100,000 lines, about 1.1 MB) + large_logs = "Error line\n" * 100000 + mock_result = MagicMock() + mock_result.returncode = 0 + mock_result.stdout = large_logs + mock_run.return_value = mock_result + + failed_checks = [ + CheckFailure( + name="test-check", + conclusion="FAILURE", + workflow_name="CI Tests", + details_url="https://github.com/VectorInstitute/test-repo/actions/runs/123/job/456", + started_at="2025-01-01T00:00:00Z", + completed_at="2025-01-01T00:05:00Z", + ) + ] + + client = GitHubClient(github_token="test-token") + logs_file = client.get_failure_logs("VectorInstitute/test-repo", failed_checks) + + try: + logs_content = Path(logs_file).read_text() + # Verify all logs are written (no truncation) + assert large_logs in logs_content + assert len(logs_content) > len(large_logs) # Includes headers + finally: + Path(logs_file).unlink(missing_ok=True) + + + class TestGitHubClientCheckMergeConflicts: + """Test suite for check_merge_conflicts method.""" + + @patch.object(GitHubClient, "_run_gh_command") + def
test_check_merge_conflicts_has_conflicts(self, mock_run): + """Test checking merge conflicts when PR has conflicts.""" + mock_result = MagicMock() + mock_result.stdout = json.dumps({"mergeable": "CONFLICTING"}) + mock_run.return_value = mock_result + + client = GitHubClient(github_token="test-token") + has_conflicts = client.check_merge_conflicts("VectorInstitute/test-repo", 123) + + assert has_conflicts is True + + @patch.object(GitHubClient, "_run_gh_command") + def test_check_merge_conflicts_no_conflicts(self, mock_run): + """Test checking merge conflicts when PR is mergeable.""" + mock_result = MagicMock() + mock_result.stdout = json.dumps({"mergeable": "MERGEABLE"}) + mock_run.return_value = mock_result + + client = GitHubClient(github_token="test-token") + has_conflicts = client.check_merge_conflicts("VectorInstitute/test-repo", 123) + + assert has_conflicts is False + + @patch.object(GitHubClient, "_run_gh_command") + def test_check_merge_conflicts_unknown_status(self, mock_run): + """Test checking merge conflicts with unknown mergeable status.""" + mock_result = MagicMock() + mock_result.stdout = json.dumps({"mergeable": "UNKNOWN"}) + mock_run.return_value = mock_result + + client = GitHubClient(github_token="test-token") + has_conflicts = client.check_merge_conflicts("VectorInstitute/test-repo", 123) + + assert has_conflicts is False + + @patch.object(GitHubClient, "_run_gh_command") + def test_check_merge_conflicts_empty_status(self, mock_run): + """Test checking merge conflicts with empty mergeable status.""" + mock_result = MagicMock() + mock_result.stdout = json.dumps({"mergeable": ""}) + mock_run.return_value = mock_result + + client = GitHubClient(github_token="test-token") + has_conflicts = client.check_merge_conflicts("VectorInstitute/test-repo", 123) + + assert has_conflicts is False + + @patch.object(GitHubClient, "_run_gh_command") + def test_check_merge_conflicts_missing_field(self, mock_run): + """Test checking merge conflicts when mergeable 
field is missing.""" + mock_result = MagicMock() + mock_result.stdout = json.dumps({}) + mock_run.return_value = mock_result + + client = GitHubClient(github_token="test-token") + has_conflicts = client.check_merge_conflicts("VectorInstitute/test-repo", 123) + + assert has_conflicts is False + + @patch.object(GitHubClient, "_run_gh_command") + def test_check_merge_conflicts_api_error(self, mock_run): + """Test checking merge conflicts with API error.""" + mock_run.side_effect = subprocess.CalledProcessError( + 1, "gh", stderr="API error" + ) + + client = GitHubClient(github_token="test-token") + with pytest.raises(subprocess.CalledProcessError): + client.check_merge_conflicts("VectorInstitute/test-repo", 123) + + @patch.object(GitHubClient, "_run_gh_command") + def test_check_merge_conflicts_verifies_call_args(self, mock_run): + """Test that check_merge_conflicts calls gh with correct arguments.""" + mock_result = MagicMock() + mock_result.stdout = json.dumps({"mergeable": "MERGEABLE"}) + mock_run.return_value = mock_result + + client = GitHubClient(github_token="test-token") + client.check_merge_conflicts("VectorInstitute/test-repo", 123) + + mock_run.assert_called_once() + call_args = mock_run.call_args[0][0] + assert "pr" in call_args + assert "view" in call_args + assert "123" in call_args + assert "--repo" in call_args + assert "VectorInstitute/test-repo" in call_args + assert "--json" in call_args + assert "mergeable" in call_args