Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 58 additions & 6 deletions mysql_ch_replicator/binlog_replicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os.path
import json
import random
import re

from enum import Enum
from logging import getLogger
Expand Down Expand Up @@ -379,6 +380,48 @@ def clear_old_binlog_if_required(self):
self.last_binlog_clear_time = curr_time
self.data_writer.remove_old_files(curr_time - BinlogReplicator.BINLOG_RETENTION_PERIOD)

@classmethod
def _try_parse_db_name_from_query(cls, query: str) -> str:
"""
Extract the database name from a MySQL CREATE TABLE or ALTER TABLE query.
Supports multiline queries and quoted identifiers that may include special characters.

Examples:
- CREATE TABLE `mydb`.`mytable` ( ... )
- ALTER TABLE mydb.mytable ADD COLUMN id int NOT NULL
- CREATE TABLE IF NOT EXISTS mydb.mytable ( ... )
- ALTER TABLE "mydb"."mytable" ...
- CREATE TABLE IF NOT EXISTS `multidb` . `multitable` ( ... )
- CREATE TABLE `replication-test_db`.`test_table_2` ( ... )

Returns the database name, or an empty string if not found.
"""
# Updated regex:
# 1. Matches optional leading whitespace.
# 2. Matches "CREATE TABLE" or "ALTER TABLE" (with optional IF NOT EXISTS).
# 3. Optionally captures a database name, which can be either:
# - Quoted (using backticks or double quotes) and may contain special characters.
# - Unquoted (letters, digits, and underscores only).
# 4. Allows optional whitespace around the separating dot.
# 5. Matches the table name (which we do not capture).
pattern = re.compile(
r'^\s*' # optional leading whitespace/newlines
r'(?i:(?:create|alter))\s+table\s+' # "CREATE TABLE" or "ALTER TABLE"
r'(?:if\s+not\s+exists\s+)?' # optional "IF NOT EXISTS"
# Optional DB name group: either quoted or unquoted, followed by optional whitespace, a dot, and more optional whitespace.
r'(?:(?:[`"](?P<dbname_quoted>[^`"]+)[`"]|(?P<dbname_unquoted>[a-zA-Z0-9_]+))\s*\.\s*)?'
r'[`"]?[a-zA-Z0-9_]+[`"]?', # table name (quoted or not)
re.IGNORECASE | re.DOTALL # case-insensitive, dot matches newline
)

m = pattern.search(query)
if m:
# Return the quoted db name if found; else return the unquoted name if found.
if m.group('dbname_quoted'):
return m.group('dbname_quoted')
elif m.group('dbname_unquoted'):
return m.group('dbname_unquoted')
return ''

def run(self):
last_transaction_id = None
Expand Down Expand Up @@ -425,12 +468,6 @@ def run(self):
if isinstance(log_event.db_name, bytes):
log_event.db_name = log_event.db_name.decode('utf-8')

if not self.settings.is_database_matches(log_event.db_name):
continue

logger.debug(f'event matched {transaction_id}, {log_event.db_name}, {log_event.table_name}')

log_event.transaction_id = transaction_id
if isinstance(event, UpdateRowsEvent) or isinstance(event, WriteRowsEvent):
log_event.event_type = EventType.ADD_EVENT.value

Expand All @@ -440,6 +477,21 @@ def run(self):
if isinstance(event, QueryEvent):
log_event.event_type = EventType.QUERY.value

if log_event.event_type == EventType.UNKNOWN.value:
continue

if log_event.event_type == EventType.QUERY.value:
db_name_from_query = self._try_parse_db_name_from_query(event.query)
if db_name_from_query:
log_event.db_name = db_name_from_query

if not self.settings.is_database_matches(log_event.db_name):
continue

logger.debug(f'event matched {transaction_id}, {log_event.db_name}, {log_event.table_name}')

log_event.transaction_id = transaction_id

if isinstance(event, QueryEvent):
log_event.records = event.query
else:
Expand Down
74 changes: 65 additions & 9 deletions test_mysql_ch_replicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from mysql_ch_replicator import config
from mysql_ch_replicator import mysql_api
from mysql_ch_replicator import clickhouse_api
from mysql_ch_replicator.binlog_replicator import State as BinlogState, FileReader, EventType
from mysql_ch_replicator.binlog_replicator import State as BinlogState, FileReader, EventType, BinlogReplicator
from mysql_ch_replicator.db_replicator import State as DbReplicatorState, DbReplicator
from mysql_ch_replicator.converter import MysqlToClickhouseConverter

Expand Down Expand Up @@ -69,14 +69,16 @@ def prepare_env(
cfg: config.Settings,
mysql: mysql_api.MySQLApi,
ch: clickhouse_api.ClickhouseApi,
db_name: str = TEST_DB_NAME
db_name: str = TEST_DB_NAME,
set_mysql_db: bool = True
):
if os.path.exists(cfg.binlog_replicator.data_dir):
shutil.rmtree(cfg.binlog_replicator.data_dir)
os.mkdir(cfg.binlog_replicator.data_dir)
mysql.drop_database(db_name)
mysql.create_database(db_name)
mysql.set_database(db_name)
if set_mysql_db:
mysql.set_database(db_name)
ch.drop_database(db_name)
assert_wait(lambda: db_name not in ch.get_databases())

Expand Down Expand Up @@ -784,7 +786,7 @@ def _get_last_insert_name():
f"INSERT INTO `{TEST_TABLE_NAME}` (name, age) "
f"VALUES ('TEST_VALUE_{i}_{base_value}', {i});", commit=i % 20 == 0,
)

#`replication-test_db`
mysql.execute(f"INSERT INTO `{TEST_TABLE_NAME}` (name, age) VALUES ('TEST_VALUE_FINAL', 0);", commit=True)

print("running db_replicator")
Expand Down Expand Up @@ -823,12 +825,12 @@ def test_different_types_1():
clickhouse_settings=cfg.clickhouse,
)

prepare_env(cfg, mysql, ch)
prepare_env(cfg, mysql, ch, set_mysql_db=False)

mysql.execute("SET sql_mode = 'ALLOW_INVALID_DATES';")

mysql.execute(f'''
CREATE TABLE `{TEST_TABLE_NAME}` (
CREATE TABLE `{TEST_DB_NAME}`.`{TEST_TABLE_NAME}` (
`id` int unsigned NOT NULL AUTO_INCREMENT,
name varchar(255),
`employee` int unsigned NOT NULL,
Expand Down Expand Up @@ -866,7 +868,7 @@ def test_different_types_1():
''')

mysql.execute(
f"INSERT INTO `{TEST_TABLE_NAME}` (name, modified_date) VALUES ('Ivan', '0000-00-00 00:00:00');",
f"INSERT INTO `{TEST_DB_NAME}`.`{TEST_TABLE_NAME}` (name, modified_date) VALUES ('Ivan', '0000-00-00 00:00:00');",
commit=True,
)

Expand All @@ -883,15 +885,30 @@ def test_different_types_1():
assert_wait(lambda: len(ch.select(TEST_TABLE_NAME)) == 1)

mysql.execute(
f"INSERT INTO `{TEST_TABLE_NAME}` (name, modified_date) VALUES ('Alex', '0000-00-00 00:00:00');",
f"INSERT INTO `{TEST_DB_NAME}`.`{TEST_TABLE_NAME}` (name, modified_date) VALUES ('Alex', '0000-00-00 00:00:00');",
commit=True,
)
mysql.execute(
f"INSERT INTO `{TEST_TABLE_NAME}` (name, modified_date) VALUES ('Givi', '2023-01-08 03:11:09');",
f"INSERT INTO `{TEST_DB_NAME}`.`{TEST_TABLE_NAME}` (name, modified_date) VALUES ('Givi', '2023-01-08 03:11:09');",
commit=True,
)
assert_wait(lambda: len(ch.select(TEST_TABLE_NAME)) == 3)

mysql.execute(f'''
CREATE TABLE `{TEST_DB_NAME}`.`{TEST_TABLE_NAME_2}` (
`id` int unsigned NOT NULL AUTO_INCREMENT,
name varchar(255),
PRIMARY KEY (id)
);
''')

mysql.execute(
f"INSERT INTO `{TEST_DB_NAME}`.`{TEST_TABLE_NAME_2}` (name) VALUES ('Ivan');",
commit=True,
)

assert_wait(lambda: TEST_TABLE_NAME_2 in ch.get_tables())

db_replicator_runner.stop()
binlog_replicator_runner.stop()

Expand Down Expand Up @@ -1535,3 +1552,42 @@ def test_alter_tokens_split():
print("Match? ", result == expected)
print("-" * 60)
assert result == expected


# Each case pairs a SQL statement with the database name the parser is
# expected to report ("" when the statement carries no database qualifier
# or is not a CREATE/ALTER TABLE at all).
_PARSE_DB_NAME_CASES = [
    ("CREATE TABLE `mydb`.`mytable` (id INT)", "mydb"),
    ("CREATE TABLE mydb.mytable (id INT)", "mydb"),
    ("ALTER TABLE `mydb`.mytable ADD COLUMN name VARCHAR(50)", "mydb"),
    ("CREATE TABLE IF NOT EXISTS mydb.mytable (id INT)", "mydb"),
    ("CREATE TABLE mytable (id INT)", ""),
    (" CREATE TABLE `mydb` . `mytable` \n ( id INT )", "mydb"),
    ('ALTER TABLE "testdb"."tablename" ADD COLUMN flag BOOLEAN', "testdb"),
    ("create table mydb.mytable (id int)", "mydb"),
    ("DROP DATABASE mydb", ""),
    ("CREATE TABLE mydbmytable (id int)", ""),  # missing dot between DB and table
    ("""
CREATE TABLE IF NOT EXISTS
`multidb`
.
`multitable`
(
id INT,
name VARCHAR(100)
)
""", "multidb"),
    ("""
ALTER TABLE
`justtable`
ADD COLUMN age INT;
""", ""),
    ("""
CREATE TABLE `replication-test_db`.`test_table_2` (
`id` int unsigned NOT NULL AUTO_INCREMENT,
name varchar(255),
PRIMARY KEY (id)
)
""", "replication-test_db"),
    ("BEGIN", ""),
]


@pytest.mark.parametrize(("query", "expected"), _PARSE_DB_NAME_CASES)
def test_parse_db_name_from_query(query, expected):
    """Check that the database qualifier parsed from `query` equals `expected`."""
    parsed = BinlogReplicator._try_parse_db_name_from_query(query)
    assert parsed == expected