|
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
| 5 | +import importlib.util |
5 | 6 | from typing import TYPE_CHECKING, Any |
6 | 7 |
|
| 8 | +from ..exceptions import MissingDriverError |
7 | 9 | from ..schema import get_default_port |
8 | 10 | from .base import MySQLBaseAdapter, import_driver_module |
9 | 11 |
|
10 | 12 | if TYPE_CHECKING: |
11 | 13 | from ...config import ConnectionConfig |
12 | 14 |
|
13 | 15 |
|
| 16 | +def _check_old_mysql_connector() -> bool: |
| 17 | + """Check if the old mysql-connector-python package is installed.""" |
| 18 | + return importlib.util.find_spec("mysql.connector") is not None |
| 19 | + |
| 20 | + |
14 | 21 | class MySQLAdapter(MySQLBaseAdapter): |
15 | 22 | """Adapter for MySQL using PyMySQL.""" |
16 | 23 |
|
@@ -56,12 +63,29 @@ def driver_import_names(self) -> tuple[str, ...]: |
56 | 63 |
|
57 | 64 | def connect(self, config: ConnectionConfig) -> Any: |
58 | 65 | """Connect to MySQL database.""" |
59 | | - pymysql = import_driver_module( |
60 | | - "pymysql", |
61 | | - driver_name=self.name, |
62 | | - extra_name=self.install_extra, |
63 | | - package_name=self.install_package, |
64 | | - ) |
| 66 | + try: |
| 67 | + pymysql = import_driver_module( |
| 68 | + "pymysql", |
| 69 | + driver_name=self.name, |
| 70 | + extra_name=self.install_extra, |
| 71 | + package_name=self.install_package, |
| 72 | + ) |
| 73 | + except MissingDriverError: |
| 74 | + # Check if user has the old mysql-connector-python installed |
| 75 | + if _check_old_mysql_connector(): |
| 76 | + raise MissingDriverError( |
| 77 | + self.name, |
| 78 | + self.install_extra, |
| 79 | + self.install_package, |
| 80 | + module_name="pymysql", |
| 81 | + import_error=( |
| 82 | + "MySQL driver has changed from mysql-connector-python to PyMySQL.\n" |
| 83 | + "Please uninstall the old package and install PyMySQL:\n" |
| 84 | + " pip uninstall mysql-connector-python\n" |
| 85 | + " pip install PyMySQL" |
| 86 | + ), |
| 87 | + ) from None |
| 88 | + raise |
65 | 89 |
|
66 | 90 | port = int(config.port or get_default_port("mysql")) |
67 | 91 | return pymysql.connect( |
|
0 commit comments