diff --git a/core/rule_generator.py b/core/rule_generator.py index 58a88f9..0035b5f 100644 --- a/core/rule_generator.py +++ b/core/rule_generator.py @@ -293,7 +293,7 @@ def _fingerPrint(fingerPrint: str) -> str: @staticmethod def _fingerPrint2(fingerPrint: str) -> str: # get rid of the numbers inside each var/varList - fingerPrint = re.sub(r"V(\d+)>", "V", fingerPrint) + fingerPrint = re.sub(r"V(\d+)", "V", fingerPrint) fingerPrint = re.sub(r"VL(\d+)", "VL", fingerPrint) return fingerPrint @@ -2563,11 +2563,6 @@ def variablize_all_subtrees(rule: dict, subtrees: list) -> dict: new_rule_rewrite_json = json.loads(new_rule['rewrite_json']) for subtree in subtrees: - # Prevents already variablized subtrees from being variablized again in an infinite loop - # ex. {"value": "V22"} should not be variablized again - if len(subtree) == 1 and 'value' in subtree and QueryRewriter.is_var(subtree['value']): - continue - # Find a variable name for the given subtree # new_rule_mapping, newVarInternal = RuleGenerator.findNextVarInternal(new_rule_mapping) diff --git a/tests/test_rule_generator.py b/tests/test_rule_generator.py index 2bb5fe8..27c9e33 100644 --- a/tests/test_rule_generator.py +++ b/tests/test_rule_generator.py @@ -1644,7 +1644,7 @@ def test_generate_rule_graph_0(): def unify_variable_names(q0, q1): # Variable pattern - pattern = r'<<[^>]*>>|<[^>]*>' + pattern = r'<<[^>]+>>|<[a-zA-Z][^>]*>' # Find all variables in q0 and q1, and unify their names into xi in ascending order substrings = re.findall(pattern, q0 + q1) @@ -1920,30 +1920,21 @@ def test_generate_general_rule_7(): FROM ''')) -# TODO - fix issue with becoming <> -# def test_generate_general_rule_8(): +def test_generate_general_rule_8(): -# q0 = ''' -# SELECT * FROM t WHERE CAST(created_at AS DATE) = TIMESTAMP '2016-10-01 00:00:00.000' -# ''' -# q1 = ''' -# SELECT * FROM t WHERE created_at = TIMESTAMP '2016-10-01 00:00:00.000' -# ''' + q0 = ''' + SELECT * FROM t WHERE CAST(created_at AS DATE) = TIMESTAMP '2016-10-01 00:00:00.000' + ''' + q1 = ''' + SELECT * FROM t WHERE created_at = TIMESTAMP '2016-10-01 00:00:00.000' + ''' -# rule = RuleGenerator.generate_general_rule(q0, q1) -# assert type(rule) is dict + rule = RuleGenerator.generate_general_rule(q0, q1) + assert type(rule) is dict -# assert StringUtil.strim(RuleGenerator._fingerPrint(rule['pattern'])) == StringUtil.strim(RuleGenerator._fingerPrint(''' -# CAST( AS DATE) -# ''')) or StringUtil.strim(RuleGenerator._fingerPrint(rule['rewrite'])) == StringUtil.strim(RuleGenerator._fingerPrint(''' -# CAST(<> AS DATE) -# ''')) + assert StringUtil.strim(RuleGenerator._fingerPrint(rule['pattern'])) == StringUtil.strim(RuleGenerator._fingerPrint("CAST( AS DATE)")) -# assert StringUtil.strim(RuleGenerator._fingerPrint(rule['rewrite'])) == StringUtil.strim(RuleGenerator._fingerPrint(''' -# -# ''')) or StringUtil.strim(RuleGenerator._fingerPrint(rule['rewrite'])) == StringUtil.strim(RuleGenerator._fingerPrint(''' -# <> -# ''')) + assert StringUtil.strim(RuleGenerator._fingerPrint(rule['rewrite'])) == StringUtil.strim(RuleGenerator._fingerPrint("")) def test_generate_general_rule_9(): @@ -2286,6 +2277,54 @@ def test_generate_general_rule_22(): assert q0_rule == "SELECT <>, DATE(.), CASE WHEN SUM(CASE WHEN . = THEN ELSE END) >= THEN ELSE END FROM GROUP BY <>, DATE(.)" assert q1_rule == "SELECT <>, . FROM (SELECT , DATE() FROM WHERE = ) AS t1 GROUP BY <>, ." +# TODO: fix issue with test random +# def test_generate_spreadsheet_id_2(): +# q0 = """SELECT * +# FROM place +# WHERE "select" = TRUE +# OR exists (SELECT id +# FROM bookmark +# WHERE user IN (1,2,3,4) +# AND bookmark.place = place.id) +# LIMIT 10;""" + +# q1 = """SELECT * +# FROM ( +# (SELECT * +# FROM place +# WHERE "select" = True +# LIMIT 10) +# UNION +# (SELECT * +# FROM place +# WHERE EXISTS +# (SELECT 1 +# FROM bookmark +# WHERE user IN (1, 2, 3, 4) +# AND bookmark.place = place.id) +# LIMIT 10)) +# LIMIT 10;""" + +# rule = RuleGenerator.generate_general_rule(q0, q1) +# assert type(rule) is dict + +# q0_rule, q1_rule = unify_variable_names(rule['pattern'], rule['rewrite']) +# assert q0_rule == "FROM WHERE OR EXISTS (SELECT FROM WHERE IN (, , , ) AND <>)" +# assert q1_rule == "FROM ((SELECT <> FROM WHERE LIMIT ) UNION (SELECT <> FROM WHERE EXISTS (SELECT FROM WHERE IN (, , , ) AND <>) LIMIT ))" + +def test_generate_spreadsheet_id_3(): + q0 = """SELECT EMPNO FROM EMP WHERE EMPNO > 10 AND EMPNO <= 10""" + + q1 = """SELECT EMPNO FROM EMP WHERE FALSE""" + + rule = RuleGenerator.generate_general_rule(q0, q1) + assert type(rule) is dict + + q0_rule, q1_rule = unify_variable_names(rule['pattern'], rule['rewrite']) + assert q0_rule == " > AND <= " + assert q1_rule == "False" + + def test_generate_spreadsheet_id_4(): q0 = """SELECT entities.data FROM entities WHERE entities._id IN (SELECT index_users_email._id FROM index_users_email WHERE index_users_email.key = 'test') @@ -2517,8 +2556,8 @@ def test_generate_spreadsheet_id_18(): assert type(rule) is dict q0_rule, q1_rule = unify_variable_names(rule['pattern'], rule['rewrite']) - assert q0_rule == "SELECT DISTINCT ON () , , , COALESCE(., ), <> FROM LEFT JOIN ON <> LEFT JOIN ON <> WHERE <> AND . IN (, , , , , , ) AND <> ORDER BY . DESC" - assert q1_rule == "SELECT , , , COALESCE((SELECT . FROM WHERE <> AND <> LIMIT ), ), (SELECT <> FROM WHERE <> AND . IN (, , , , , , ) LIMIT ) FROM WHERE <>" + assert StringUtil.strim(RuleGenerator._fingerPrint(q0_rule)) == StringUtil.strim(RuleGenerator._fingerPrint("SELECT DISTINCT ON () , , , COALESCE(., ), <> FROM LEFT JOIN ON <> LEFT JOIN ON <> WHERE <> AND . IN (, , , , , , ) AND <> ORDER BY . DESC")) + assert StringUtil.strim(RuleGenerator._fingerPrint(q1_rule)) == StringUtil.strim(RuleGenerator._fingerPrint("SELECT , , , COALESCE((SELECT . FROM WHERE <> AND <> LIMIT ), ), (SELECT <> FROM WHERE <> AND . IN (, , , , , , ) LIMIT ) FROM WHERE <>")) def test_generate_spreadsheet_id_20(): q0 = """SELECT * FROM (SELECT * FROM (SELECT NULL FROM EMP) WHERE N IS NULL) WHERE N IS NULL"""