diff --git a/sql_compare/__init__.py b/sql_compare/__init__.py index 7ae7080..b67c1cb 100644 --- a/sql_compare/__init__.py +++ b/sql_compare/__init__.py @@ -41,6 +41,19 @@ def compare(first_sql: str, second_sql: str) -> bool: return first_sql_statements == second_sql_statements +def get_diff( + first_sql: str, + second_sql: str, +) -> list[list[list[str]]]: + """Show the difference between two SQL schemas, ignoring differences due to column order and other non-significant SQL changes.""" + first_set = {Statement(t) for t in sqlparse.parse(first_sql)} + second_set = {Statement(t) for t in sqlparse.parse(second_sql)} + first_diffs = sorted([stmt.str_tokens for stmt in first_set - second_set]) + second_diffs = sorted([stmt.str_tokens for stmt in second_set - first_set]) + + return [first_diffs, second_diffs] + + @dataclasses.dataclass class Token: """Wrapper around `sqlparse.sql.Token`.""" @@ -74,6 +87,13 @@ def is_separator(self) -> bool: and self.token.normalized == ",", ) + @property + def str_tokens(self) -> list[str]: + """Return the token value.""" + if self.hash.strip(): + return [self.hash] + return [] + @dataclasses.dataclass class TokenList: @@ -128,6 +148,11 @@ def statement_type(self) -> str: return Statement.UNKNOWN_TYPE + @property + def str_tokens(self) -> list[str]: + """Return the reconstructed SQL statement from tokens as a list of strings.""" + return [t.hash for t in self.tokens if not t.ignore] + class Statement(TokenList): """SQL statement.""" @@ -152,6 +177,11 @@ def statement_type(self) -> str: # Only one keyword (e.g.: SELECT, INSERT, DELETE, etc.) return keywords[0] + @property + def str_tokens(self) -> list[str]: + """Return the reconstructed SQL statement from tokens as a list of strings.""" + return [t for token in self.tokens for t in token.str_tokens] + class UnorderedTokenList(TokenList): """Unordered token list.""" diff --git a/tests/test_sql_compare.py b/tests/test_sql_compare.py index d41d2e7..a334dca 100644 --- a/tests/test_sql_compare.py +++ b/tests/test_sql_compare.py @@ -184,3 +184,133 @@ def test_compare_neq(first_sql: str, second_sql: str) -> None: def test_statement_type(sql: str, expected_type: str) -> None: statement = sql_compare.Statement(sqlparse.parse(sql)[0]) assert statement.statement_type == expected_type + + +@pytest.mark.parametrize( + ("first_sql", "second_sql", "expected_diff"), + [ + ( + "CREATE TABLE foo (id INT PRIMARY KEY)", + "CREATE TABLE foo (id INT UNIQUE)", + [ + [["CREATE", "TABLE", "foo", "(", "id", "INT", "PRIMARY KEY"]], + [["CREATE", "TABLE", "foo", "(", "id", "INT", "UNIQUE"]], + ], + ), + ( + "CREATE TYPE public.colors AS ENUM ('RED', 'GREEN', 'BLUE')", + "CREATE TYPE public.colors AS ENUM ('BLUE', 'GREEN', 'RED')", + [[], []], + ), + ( + "CREATE TYPE public.colors AS ENUM ('RED', 'GREEN', 'BLUE')", + "CREATE TYPE public.colors AS ENUM ('YELLOW', 'BLUE', 'RED')", + [ + [ + [ + "CREATE", + "TYPE", + "public", + ".", + "colors", + "AS", + "ENUM", + "(", + "'BLUE'", + ",", + "'GREEN'", + ",", + "'RED'", + ], + ], + [ + [ + "CREATE", + "TYPE", + "public", + ".", + "colors", + "AS", + "ENUM", + "(", + "'BLUE'", + ",", + "'RED'", + ",", + "'YELLOW'", + ], + ], + ], + ), + ( + """ + CREATE TYPE public.status AS ENUM ('PENDING', 'APPROVED', 'REJECTED'); + CREATE TABLE users (id INT, name VARCHAR(100), status public.status); + CREATE INDEX user_status_idx ON users (status); + """, + """ + CREATE TYPE public.status AS ENUM ('PENDING', 'APPROVED', 'ARCHIVED'); + CREATE TABLE logs (id INT, message TEXT); + CREATE TABLE users (id INT, name VARCHAR(100), status public.status); + CREATE INDEX user_status_idx ON users (status); + """, + [ + [ + [ + "CREATE", + "TYPE", + "public", + ".", + "status", + "AS", + "ENUM", + "(", + "'APPROVED'", + ",", + "'PENDING'", + ",", + "'REJECTED'", + ";", + ], + ], + [ + [ + "CREATE", + "TABLE", + "logs", + "(", + "id", + "INT", + ",", + "message", + "TEXT", + ";", + ], + [ + "CREATE", + "TYPE", + "public", + ".", + "status", + "AS", + "ENUM", + "(", + "'APPROVED'", + ",", + "'ARCHIVED'", + ",", + "'PENDING'", + ";", + ], + ], + ], + ), + ], +) +def test_get_diff( + first_sql: str, + second_sql: str, + expected_diff: list[list[list[str]]], +) -> None: + result = sql_compare.get_diff(first_sql, second_sql) + assert result == expected_diff