diff --git a/altimate_packages/altimate/utils.py b/altimate_packages/altimate/utils.py index 9c823e38a..ed00c2263 100644 --- a/altimate_packages/altimate/utils.py +++ b/altimate_packages/altimate/utils.py @@ -148,7 +148,7 @@ def sql_parse_errors(sql: str, dialect: str): def get_start_and_end_position(sql: str, invalid_string: str): start, end, num_occurences = find_single_occurrence_indices(sql, invalid_string) - if start and end: + if start is not None and end is not None: return ( list(get_line_and_column_from_position(sql, start)), list(get_line_and_column_from_position(sql, end)), diff --git a/altimate_packages/tests/test_utils.py b/altimate_packages/tests/test_utils.py new file mode 100644 index 000000000..62496d6de --- /dev/null +++ b/altimate_packages/tests/test_utils.py @@ -0,0 +1,15 @@ +import sys +import unittest +sys.path.insert(0, 'altimate_packages') +from altimate.utils import get_start_and_end_position + +class TestGetStartEndPosition(unittest.TestCase): + def test_invalid_token_at_beginning(self): + sql = "invalid_token SELECT * FROM table" + start, end, count = get_start_and_end_position(sql, "invalid_token") + self.assertEqual(start, [0, 1]) + self.assertEqual(end, [0, len("invalid_token") + 1]) + self.assertEqual(count, 1) + +if __name__ == "__main__": + unittest.main()