| 
 | 1 | +from unittest.mock import patch, MagicMock  | 
 | 2 | + | 
 | 3 | +from sqlalchemy.exc import SQLAlchemyError  | 
 | 4 | + | 
 | 5 | +from nettacker.config import Config  | 
 | 6 | +from nettacker.database.models import Base  | 
 | 7 | +from nettacker.database.mysql import mysql_create_database, mysql_create_tables  | 
 | 8 | +from tests.common import TestCase  | 
 | 9 | + | 
 | 10 | + | 
 | 11 | +class TestMySQLFunctions(TestCase):  | 
 | 12 | +    """Test cases for mysql.py functions"""  | 
 | 13 | + | 
 | 14 | +    @patch("nettacker.database.mysql.create_engine")  | 
 | 15 | +    def test_mysql_create_database_success(self, mock_create_engine):  | 
 | 16 | +        """Test successful database creation"""  | 
 | 17 | +        # Set up mock config  | 
 | 18 | +        Config.db = MagicMock()  | 
 | 19 | +        Config.db.as_dict.return_value = {  | 
 | 20 | +            "username": "test_user",  | 
 | 21 | +            "password": "test_pass",  | 
 | 22 | +            "host": "localhost",  | 
 | 23 | +            "port": "3306",  | 
 | 24 | +            "name": "test_db",  | 
 | 25 | +        }  | 
 | 26 | +        Config.db.name = "test_db"  | 
 | 27 | + | 
 | 28 | +        # Set up mock connection and execution  | 
 | 29 | +        mock_conn = MagicMock()  | 
 | 30 | +        mock_engine = MagicMock()  | 
 | 31 | +        mock_create_engine.return_value = mock_engine  | 
 | 32 | +        mock_engine.connect.return_value.__enter__.return_value = mock_conn  | 
 | 33 | + | 
 | 34 | +        # Mock database query results - database doesn't exist yet  | 
 | 35 | +        mock_conn.execute.return_value = [("mysql",), ("information_schema",)]  | 
 | 36 | + | 
 | 37 | +        # Call the function  | 
 | 38 | +        mysql_create_database()  | 
 | 39 | + | 
 | 40 | +        # Assertions  | 
 | 41 | +        mock_create_engine.assert_called_once_with(  | 
 | 42 | +            "mysql+pymysql://test_user:test_pass@localhost:3306"  | 
 | 43 | +        )  | 
 | 44 | + | 
 | 45 | +        # Check that execute was called with any text object that has the expected SQL  | 
 | 46 | +        call_args_list = mock_conn.execute.call_args_list  | 
 | 47 | +        self.assertEqual(len(call_args_list), 2)  # Two calls to execute  | 
 | 48 | + | 
 | 49 | +        # Check that the first call is SHOW DATABASES  | 
 | 50 | +        first_call_arg = call_args_list[0][0][0]  | 
 | 51 | +        self.assertEqual(str(first_call_arg), "SHOW DATABASES;")  | 
 | 52 | + | 
 | 53 | +        # Check that the second call is CREATE DATABASE  | 
 | 54 | +        second_call_arg = call_args_list[1][0][0]  | 
 | 55 | +        self.assertEqual(str(second_call_arg), "CREATE DATABASE test_db ")  | 
 | 56 | + | 
 | 57 | +    @patch("nettacker.database.mysql.create_engine")  | 
 | 58 | +    def test_mysql_create_database_already_exists(self, mock_create_engine):  | 
 | 59 | +        """Test when database already exists"""  | 
 | 60 | +        # Set up mock config  | 
 | 61 | +        Config.db = MagicMock()  | 
 | 62 | +        Config.db.as_dict.return_value = {  | 
 | 63 | +            "username": "test_user",  | 
 | 64 | +            "password": "test_pass",  | 
 | 65 | +            "host": "localhost",  | 
 | 66 | +            "port": "3306",  | 
 | 67 | +            "name": "test_db",  | 
 | 68 | +        }  | 
 | 69 | +        Config.db.name = "test_db"  | 
 | 70 | + | 
 | 71 | +        # Set up mock connection and execution  | 
 | 72 | +        mock_conn = MagicMock()  | 
 | 73 | +        mock_engine = MagicMock()  | 
 | 74 | +        mock_create_engine.return_value = mock_engine  | 
 | 75 | +        mock_engine.connect.return_value.__enter__.return_value = mock_conn  | 
 | 76 | + | 
 | 77 | +        # Mock database query results - database already exists  | 
 | 78 | +        mock_conn.execute.return_value = [("mysql",), ("information_schema",), ("test_db",)]  | 
 | 79 | + | 
 | 80 | +        # Call the function  | 
 | 81 | +        mysql_create_database()  | 
 | 82 | + | 
 | 83 | +        # Assertions  | 
 | 84 | +        mock_create_engine.assert_called_once_with(  | 
 | 85 | +            "mysql+pymysql://test_user:test_pass@localhost:3306"  | 
 | 86 | +        )  | 
 | 87 | + | 
 | 88 | +        # Check that execute was called once with SHOW DATABASES  | 
 | 89 | +        self.assertEqual(mock_conn.execute.call_count, 1)  | 
 | 90 | +        call_arg = mock_conn.execute.call_args[0][0]  | 
 | 91 | +        self.assertEqual(str(call_arg), "SHOW DATABASES;")  | 
 | 92 | + | 
 | 93 | +    @patch("nettacker.database.mysql.create_engine")  | 
 | 94 | +    def test_mysql_create_database_exception(self, mock_create_engine):  | 
 | 95 | +        """Test exception handling in create database"""  | 
 | 96 | +        # Set up mock config  | 
 | 97 | +        Config.db = MagicMock()  | 
 | 98 | +        Config.db.as_dict.return_value = {  | 
 | 99 | +            "username": "test_user",  | 
 | 100 | +            "password": "test_pass",  | 
 | 101 | +            "host": "localhost",  | 
 | 102 | +            "port": "3306",  | 
 | 103 | +            "name": "test_db",  | 
 | 104 | +        }  | 
 | 105 | + | 
 | 106 | +        # Set up mock to raise exception  | 
 | 107 | +        mock_engine = MagicMock()  | 
 | 108 | +        mock_create_engine.return_value = mock_engine  | 
 | 109 | +        mock_engine.connect.side_effect = SQLAlchemyError("Connection error")  | 
 | 110 | + | 
 | 111 | +        # Call the function (should not raise exception)  | 
 | 112 | +        with patch("builtins.print") as mock_print:  | 
 | 113 | +            mysql_create_database()  | 
 | 114 | +            mock_print.assert_called_once()  | 
 | 115 | + | 
 | 116 | +    @patch("nettacker.database.mysql.create_engine")  | 
 | 117 | +    def test_mysql_create_tables(self, mock_create_engine):  | 
 | 118 | +        """Test table creation function"""  | 
 | 119 | +        # Set up mock config  | 
 | 120 | +        Config.db = MagicMock()  | 
 | 121 | +        Config.db.as_dict.return_value = {  | 
 | 122 | +            "username": "test_user",  | 
 | 123 | +            "password": "test_pass",  | 
 | 124 | +            "host": "localhost",  | 
 | 125 | +            "port": "3306",  | 
 | 126 | +            "name": "test_db",  | 
 | 127 | +        }  | 
 | 128 | + | 
 | 129 | +        # Set up mock engine  | 
 | 130 | +        mock_engine = MagicMock()  | 
 | 131 | +        mock_create_engine.return_value = mock_engine  | 
 | 132 | + | 
 | 133 | +        # Call the function  | 
 | 134 | +        with patch.object(Base.metadata, "create_all") as mock_create_all:  | 
 | 135 | +            mysql_create_tables()  | 
 | 136 | + | 
 | 137 | +            # Assertions  | 
 | 138 | +            mock_create_engine.assert_called_once_with(  | 
 | 139 | +                "mysql+pymysql://test_user:test_pass@localhost:3306/test_db"  | 
 | 140 | +            )  | 
 | 141 | +            mock_create_all.assert_called_once_with(mock_engine)  | 
0 commit comments