diff --git a/mysql_ch_replicator/clickhouse_api.py b/mysql_ch_replicator/clickhouse_api.py index 3e5eac1..3a6916c 100644 --- a/mysql_ch_replicator/clickhouse_api.py +++ b/mysql_ch_replicator/clickhouse_api.py @@ -16,8 +16,7 @@ ( {fields}, `_version` UInt64, - INDEX _version _version TYPE minmax GRANULARITY 1, - INDEX idx_id {primary_key} TYPE bloom_filter GRANULARITY 1 + {indexes} ) ENGINE = ReplacingMergeTree(_version) {partition_by}ORDER BY {primary_key} @@ -25,7 +24,7 @@ ''' DELETE_QUERY = ''' -DELETE FROM {db_name}.{table_name} WHERE {field_name} IN ({field_values}) +DELETE FROM {db_name}.{table_name} WHERE ({field_name}) IN ({field_values}) ''' @@ -63,8 +62,6 @@ def get_databases(self): return database_list def execute_command(self, query): - #print(' === executing ch query', query) - for attempt in range(ClickhouseApi.MAX_RETRIES): try: self.client.command(query) @@ -76,7 +73,6 @@ def execute_command(self, query): time.sleep(ClickhouseApi.RETRY_INTERVAL) def recreate_database(self): - #print(' === creating database', self.database) self.execute_command(f'DROP DATABASE IF EXISTS {self.database}') self.execute_command(f'CREATE DATABASE {self.database}') @@ -87,31 +83,39 @@ def set_last_used_version(self, table_name, last_used_version): self.tables_last_record_version[table_name] = last_used_version def create_table(self, structure: TableStructure): - if not structure.primary_key: + if not structure.primary_keys: raise Exception(f'missing primary key for {structure.table_name}') - primary_key_type = '' - for field in structure.fields: - if field.name == structure.primary_key: - primary_key_type = field.field_type - if not primary_key_type: - raise Exception(f'failed to get type of primary key {structure.table_name} {structure.primary_key}') - fields = [ f' `{field.name}` {field.field_type}' for field in structure.fields ] fields = ',\n'.join(fields) partition_by = '' - if 'int' in primary_key_type.lower(): - partition_by = f'PARTITION BY 
intDiv({structure.primary_key}, 4294967)\n' + if len(structure.primary_keys) == 1: + if 'int' in structure.fields[structure.primary_key_ids[0]].field_type.lower(): + partition_by = f'PARTITION BY intDiv({structure.primary_keys[0]}, 4294967)\n' + + indexes = [ + 'INDEX _version _version TYPE minmax GRANULARITY 1', + ] + if len(structure.primary_keys) == 1: + indexes.append( + f'INDEX idx_id {structure.primary_keys[0]} TYPE bloom_filter GRANULARITY 1', + ) + + indexes = ',\n'.join(indexes) + primary_key = ','.join(structure.primary_keys) + if len(structure.primary_keys) > 1: + primary_key = f'({primary_key})' query = CREATE_TABLE_QUERY.format(**{ 'db_name': self.database, 'table_name': structure.table_name, 'fields': fields, - 'primary_key': structure.primary_key, + 'primary_key': primary_key, 'partition_by': partition_by, + 'indexes': indexes, }) self.execute_command(query) @@ -161,6 +165,7 @@ def insert(self, table_name, records, table_structure: TableStructure = None): self.set_last_used_version(table_name, current_version) def erase(self, table_name, field_name, field_values): + field_name = ','.join(field_name) field_values = ', '.join(list(map(str, field_values))) query = DELETE_QUERY.format(**{ 'db_name': self.database, diff --git a/mysql_ch_replicator/converter.py b/mysql_ch_replicator/converter.py index a2dd586..0f9381d 100644 --- a/mysql_ch_replicator/converter.py +++ b/mysql_ch_replicator/converter.py @@ -2,7 +2,7 @@ import json import sqlparse import re -from pyparsing import Word, alphas, alphanums +from pyparsing import Suppress, CaselessKeyword, Word, alphas, alphanums, delimitedList from .table_structure import TableStructure, TableField @@ -218,7 +218,7 @@ def convert_table_structure(self, mysql_structure: TableStructure) -> TableStruc name=field.name, field_type=clickhouse_field_type, )) - clickhouse_structure.primary_key = mysql_structure.primary_key + clickhouse_structure.primary_keys = mysql_structure.primary_keys 
clickhouse_structure.preprocess() return clickhouse_structure @@ -521,9 +521,22 @@ def parse_mysql_table_structure(self, create_statement, required_table_name=None if line.lower().startswith('constraint'): continue if line.lower().startswith('primary key'): - pattern = 'PRIMARY KEY (' + Word(alphanums + '_`') + ')' + # Define identifier to match column names, handling backticks and unquoted names + identifier = (Suppress('`') + Word(alphas + alphanums + '_') + Suppress('`')) | Word( + alphas + alphanums + '_') + + # Build the parsing pattern + pattern = CaselessKeyword('PRIMARY') + CaselessKeyword('KEY') + Suppress('(') + delimitedList( + identifier)('column_names') + Suppress(')') + + # Parse the line result = pattern.parseString(line) - structure.primary_key = strip_sql_name(result[1]) + + # Extract and process the primary key column names + primary_keys = [strip_sql_name(name) for name in result['column_names']] + + structure.primary_keys = primary_keys + continue #print(" === processing line", line) @@ -543,16 +556,16 @@ def parse_mysql_table_structure(self, create_statement, required_table_name=None #print(' ---- params:', field_parameters) - if not structure.primary_key: + if not structure.primary_keys: for field in structure.fields: if 'primary key' in field.parameters.lower(): - structure.primary_key = field.name + structure.primary_keys.append(field.name) - if not structure.primary_key: + if not structure.primary_keys: if structure.has_field('id'): - structure.primary_key = 'id' + structure.primary_keys = ['id'] - if not structure.primary_key: + if not structure.primary_keys: raise Exception(f'No primary key for table {structure.table_name}, {create_statement}') structure.preprocess() diff --git a/mysql_ch_replicator/db_replicator.py b/mysql_ch_replicator/db_replicator.py index d7ab85e..94178ed 100644 --- a/mysql_ch_replicator/db_replicator.py +++ b/mysql_ch_replicator/db_replicator.py @@ -148,15 +148,16 @@ def validate_database_settings(self): ) def 
validate_mysql_structure(self, mysql_structure: TableStructure): - primary_field: TableField = mysql_structure.fields[mysql_structure.primary_key_idx] - if 'not null' not in primary_field.parameters.lower(): - logger.warning('primary key validation failed') - logger.warning( - f'\n\n\n !!! WARNING - PRIMARY KEY NULLABLE (field "{primary_field.name}", table "{mysql_structure.table_name}") !!!\n\n' - 'There could be errors replicating nullable primary key\n' - 'Please ensure all tables has NOT NULL parameter for primary key\n' - 'Or mark tables as skipped, see "exclude_tables" option\n\n\n' - ) + for key_idx in mysql_structure.primary_key_ids: + primary_field: TableField = mysql_structure.fields[key_idx] + if 'not null' not in primary_field.parameters.lower(): + logger.warning('primary key validation failed') + logger.warning( + f'\n\n\n !!! WARNING - PRIMARY KEY NULLABLE (field "{primary_field.name}", table "{mysql_structure.table_name}") !!!\n\n' + 'There could be errors replicating nullable primary key\n' + 'Please ensure all tables has NOT NULL parameter for primary key\n' + 'Or mark tables as skipped, see "exclude_tables" option\n\n\n' + ) def run(self): try: @@ -276,29 +277,33 @@ def perform_initial_replication_table(self, table_name): logger.debug(f'mysql table structure: {mysql_table_structure}') logger.debug(f'clickhouse table structure: {clickhouse_table_structure}') - field_names = [field.name for field in clickhouse_table_structure.fields] field_types = [field.field_type for field in clickhouse_table_structure.fields] - primary_key = clickhouse_table_structure.primary_key - primary_key_index = field_names.index(primary_key) - primary_key_type = field_types[primary_key_index] + primary_keys = clickhouse_table_structure.primary_keys + primary_key_ids = clickhouse_table_structure.primary_key_ids + primary_key_types = [field_types[key_idx] for key_idx in primary_key_ids] - logger.debug(f'primary key name: {primary_key}, type: {primary_key_type}') + 
#logger.debug(f'primary key name: {primary_key}, type: {primary_key_type}') stats_number_of_records = 0 last_stats_dump_time = time.time() while True: - query_start_value = max_primary_key - if 'int' not in primary_key_type.lower() and query_start_value is not None: - query_start_value = f"'{query_start_value}'" + query_start_values = list(max_primary_key) if max_primary_key is not None else None + if query_start_values is not None: + for i in range(len(query_start_values)): + key_type = primary_key_types[i] + value = query_start_values[i] + if 'int' not in key_type.lower(): + value = f"'{value}'" + query_start_values[i] = value records = self.mysql_api.get_records( table_name=table_name, - order_by=primary_key, + order_by=primary_keys, limit=DbReplicator.INITIAL_REPLICATION_BATCH_SIZE, - start_value=query_start_value, + start_value=query_start_values, ) logger.debug(f'extracted {len(records)} records from mysql') @@ -311,7 +316,7 @@ def perform_initial_replication_table(self, table_name): break self.clickhouse_api.insert(table_name, records, table_structure=clickhouse_table_structure) for record in records: - record_primary_key = record[primary_key_index] + record_primary_key = [record[key_idx] for key_idx in primary_key_ids] if max_primary_key is None: max_primary_key = record_primary_key else: @@ -404,6 +409,16 @@ def save_state_if_required(self, force=False): self.state.tables_last_record_version = self.clickhouse_api.tables_last_record_version self.state.save() + def _get_record_id(self, ch_table_structure, record: list): + result = [] + for idx in ch_table_structure.primary_key_ids: + field_type = ch_table_structure.fields[idx].field_type + if field_type == 'String': + result.append(f"'{record[idx]}'") + else: + result.append(record[idx]) + return '(' + ','.join(map(str, result)) + ')' + def handle_insert_event(self, event: LogEvent): if self.config.debug_log_level: logger.debug( @@ -418,12 +433,10 @@ def handle_insert_event(self, event: LogEvent): clickhouse_table_structure = 
self.state.tables_structure[event.table_name][1] records = self.converter.convert_records(event.records, mysql_table_structure, clickhouse_table_structure) - primary_key_ids = mysql_table_structure.primary_key_idx - current_table_records_to_insert = self.records_to_insert[event.table_name] current_table_records_to_delete = self.records_to_delete[event.table_name] for record in records: - record_id = record[primary_key_ids] + record_id = self._get_record_id(clickhouse_table_structure, record) current_table_records_to_insert[record_id] = record current_table_records_to_delete.discard(record_id) @@ -437,16 +450,9 @@ def handle_erase_event(self, event: LogEvent): self.stats.erase_events_count += 1 self.stats.erase_records_count += len(event.records) - table_structure: TableStructure = self.state.tables_structure[event.table_name][0] table_structure_ch: TableStructure = self.state.tables_structure[event.table_name][1] - primary_key_name_idx = table_structure.primary_key_idx - field_type_ch = table_structure_ch.fields[primary_key_name_idx].field_type - - if field_type_ch == 'String': - keys_to_remove = [f"'{record[primary_key_name_idx]}'" for record in event.records] - else: - keys_to_remove = [record[primary_key_name_idx] for record in event.records] + keys_to_remove = [self._get_record_id(table_structure_ch, record) for record in event.records] current_table_records_to_insert = self.records_to_insert[event.table_name] current_table_records_to_delete = self.records_to_delete[event.table_name] @@ -546,12 +552,12 @@ def upload_records(self): if not keys_to_remove: continue table_structure: TableStructure = self.state.tables_structure[table_name][0] - primary_key_name = table_structure.primary_key + primary_key_names = table_structure.primary_keys if self.config.debug_log_level: - logger.debug(f'erasing from {table_name}, primary key: {primary_key_name}, values: {keys_to_remove}') + logger.debug(f'erasing from {table_name}, primary key: {primary_key_names}, values: 
{keys_to_remove}') self.clickhouse_api.erase( table_name=table_name, - field_name=primary_key_name, + field_name=primary_key_names, field_values=keys_to_remove, ) diff --git a/mysql_ch_replicator/mysql_api.py b/mysql_ch_replicator/mysql_api.py index ee34b7c..2af5dbf 100644 --- a/mysql_ch_replicator/mysql_api.py +++ b/mysql_ch_replicator/mysql_api.py @@ -48,7 +48,6 @@ def create_database(self, db_name): self.cursor.execute(f'CREATE DATABASE {db_name}') def execute(self, command, commit=False): - #print(f'Executing: <{command}>') self.cursor.execute(command) if commit: self.db.commit() @@ -88,9 +87,11 @@ def get_table_create_statement(self, table_name) -> str: def get_records(self, table_name, order_by, limit, start_value=None): self.reconnect_if_required() + order_by = ','.join(order_by) where = '' if start_value is not None: - where = f'WHERE {order_by} > {start_value} ' + start_value = ','.join(map(str, start_value)) + where = f'WHERE ({order_by}) > ({start_value}) ' query = f'SELECT * FROM {table_name} {where}ORDER BY {order_by} LIMIT {limit}' self.cursor.execute(query) res = self.cursor.fetchall() diff --git a/mysql_ch_replicator/table_structure.py b/mysql_ch_replicator/table_structure.py index fc2fd26..027710e 100644 --- a/mysql_ch_replicator/table_structure.py +++ b/mysql_ch_replicator/table_structure.py @@ -9,13 +9,15 @@ class TableField: @dataclass class TableStructure: fields: list = field(default_factory=list) - primary_key: str = '' - primary_key_idx: int = 0 + primary_keys: list = field(default_factory=list) + primary_key_ids: list = field(default_factory=list) table_name: str = '' def preprocess(self): field_names = [f.name for f in self.fields] - self.primary_key_idx = field_names.index(self.primary_key) + self.primary_key_ids = [ + field_names.index(key) for key in self.primary_keys + ] def add_field_after(self, new_field: TableField, after: str): diff --git a/test_mysql_ch_replicator.py b/test_mysql_ch_replicator.py index 6027257..0d5c3ff 100644 --- a/test_mysql_ch_replicator.py +++ 
b/test_mysql_ch_replicator.py @@ -227,7 +227,7 @@ def test_e2e_multistatement(): id int NOT NULL AUTO_INCREMENT, name varchar(255), age int, - PRIMARY KEY (id) + PRIMARY KEY (id, `name`) ); ''') @@ -259,6 +259,9 @@ def test_e2e_multistatement(): assert_wait(lambda: ch.select(TEST_TABLE_NAME, where="name='Mary'")[0].get('last_name') is None) assert_wait(lambda: ch.select(TEST_TABLE_NAME, where="name='Mary'")[0].get('city') is None) + mysql.execute(f"DELETE FROM {TEST_TABLE_NAME} WHERE name='Ivan';", commit=True) + assert_wait(lambda: len(ch.select(TEST_TABLE_NAME)) == 1) + mysql.execute( f"CREATE TABLE {TEST_TABLE_NAME_2} " f"(id int NOT NULL AUTO_INCREMENT, name varchar(255), age int, "