diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml
new file mode 100644
index 0000000..8eeaa5a
--- /dev/null
+++ b/.github/workflows/pypi.yml
@@ -0,0 +1,33 @@
+name: Build and publish to PyPI
+
+on:
+  workflow_run:
+    workflows: ["Release"]
+    types:
+      - completed
+
+jobs:
+  # Push a new release to PyPI
+  deploy_to_pypi:
+    name: Publish to PyPI
+    runs-on: ubuntu-latest
+    if: github.actor != 'mindsdbadmin'
+    steps:
+    - uses: actions/checkout@v4
+    - name: Set up Python
+      uses: actions/setup-python@v5.1.0
+      with:
+        python-version: ${{ vars.CI_PYTHON_VERSION }}
+    - name: Install dependencies
+      run: |
+        pip install setuptools wheel twine
+    - name: Clean previous builds
+      run: rm -rf dist/ build/ *.egg-info
+    - name: Build and publish
+      env:
+        TWINE_USERNAME: __token__
+        TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
+      run: |
+        # This uses the version string from __about__.py, which we checked matches the git tag above
+        python setup.py sdist
+        twine upload dist/*
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
new file mode 100644
index 0000000..00a5fc4
--- /dev/null
+++ b/.github/workflows/test.yml
@@ -0,0 +1,59 @@
+name: Run unit tests
+
+on:
+  pull_request:
+    branches:
+      - main
+
+jobs:
+  test:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        os: [ubuntu-latest, windows-latest]
+        python-version: [3.8, 3.9, '3.10']
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v2
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install -r requirements_test.txt
+        pip install pytest-cov
+        pip install --no-cache-dir -e .[test]
+    - name: Run unit tests
+      run: pytest -v
+      shell: bash
+
+
+  coverage:
+    needs: test
+    runs-on: ubuntu-latest
+    permissions:
+      pull-requests: write
+    steps:
+    - uses: actions/checkout@v3
+    - name: Set up Python 3.8
+      uses: actions/setup-python@v2
+      with:
+        python-version: 3.8
+
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install flake8 pytest pytest-cov
+        pip install -r requirements_test.txt
+
+    - name: Build coverage file
+      run: |
+        pytest --junitxml=pytest.xml --cov-report=term-missing:skip-covered --cov=mindsdb_sql_parser tests/ | tee pytest-coverage.txt
+
+    - name: Pytest coverage comment
+      uses: MishaKav/pytest-coverage-comment@main
+      with:
+        pytest-coverage-path: ./pytest-coverage.txt
+        junitxml-path: ./pytest.xml
+
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..300c664
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+.venv
+.idea
+venv/
+__pycache__/
diff --git a/README.md b/README.md
index 4803a01..ddda320 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,78 @@
 # MindsDB SQL Parser 🚧
+
+
+# Installation
+
+```
+pip install mindsdb_sql_parser
+```
+
+## How to use
+
+```python
+
+from mindsdb_sql_parser import parse_sql
+
+query = parse_sql('select b from aaa where c=1')
+
+# the result is an abstract syntax tree (AST)
+query
+
+# hierarchical representation of the AST
+query.to_tree()
+
+# representation of the tree as a SQL string; it may not exactly match the original SQL
+query.to_string()
+
+```
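+
+The AST can also be built directly from the node classes, which is handy in tests. A minimal sketch (the constructor arguments shown here follow the node definitions under `mindsdb_sql_parser/ast/` and may differ between versions):
+
+```python
+from mindsdb_sql_parser import parse_sql
+from mindsdb_sql_parser.ast import Select, Identifier, Constant, BinaryOperation
+
+# SELECT b FROM aaa WHERE c = 1, assembled by hand
+expected = Select(
+    targets=[Identifier('b')],
+    from_table=Identifier('aaa'),
+    where=BinaryOperation(op='=', args=[Identifier('c'), Constant(1)]),
+)
+
+# ASTNode.__eq__ compares the tree and SQL representations
+assert parse_sql('select b from aaa where c=1') == expected
+```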
+
+## Architecture
+
+Parsing is done with the [SLY](https://sly.readthedocs.io/en/latest/sly.html) library.
+
+Parsing consists of 2 stages (with a separate module for every dialect):
+- Keywords are defined in the lexer.py module, mostly with regular expressions.
+- Syntax rules are defined in the parser.py module, by describing the rules in [BNF grammar](https://en.wikipedia.org/wiki/Backus%E2%80%93Naur_form):
+  - The syntax is defined in a function's decorator. Inside the decorator you can use the keyword itself or other parser functions.
+  - The output of a function can be used as the input of other parser functions.
+  - The outputs of the parser are listed in "Top-level statements". They have to be abstract syntax tree (AST) objects.
+
+SLY does not support inheritance, so every dialect is described completely, without extending one dialect from another.
+
+### [AST](https://en.wikipedia.org/wiki/Abstract_syntax_tree)
+- The structure of the AST is defined in separate modules (in parser/ast/).
+- AST classes can be inherited.
+- Every class has to have these methods:
+  - to_tree - returns a hierarchical representation of the object
+  - get_string - returns the object as a SQL expression (or sub-expression)
+  - copy - copies the AST tree to a new object
+
+### Error handling
+
+For a better user experience, a parsing error contains useful information about the location of the problem and a possible way to solve it.
+1. It shows the location of the error if:
+   - a character can't be parsed (by the lexer)
+   - a token is unexpected (by the parser)
+2. It tries to propose a correct token at (or before) the error location. The possible options are:
+   - a keyword, shown as is
+   - '[number]' - if a float or an integer is expected
+   - '[string]' - if a string is expected
+   - '[identifier]' - if the name of an object is expected; for example, the identifiers are the bold words here:
+     - "select **x** as **name** from **tbl1** where **col**=1"
+
+How the suggestion works (see the sketch below):
+- It takes the next possible tokens defined by the syntax rules.
+- If the error is at the end of the query, it just shows these tokens.
+- Otherwise:
+  - it replaces the bad token with another token from the list of possible tokens
+  - it tries to parse the query once again; if there is no error, it adds this token to the suggestion list
+  - on the second iteration, it puts the possible token before the bad token (instead of replacing it) and repeats the same operation.
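+
+For example, an incomplete query raises a `ParsingException` whose message points at the error and lists the suggestions. A minimal sketch (the exact message text is illustrative and may vary between versions):
+
+```python
+from mindsdb_sql_parser import parse_sql
+from mindsdb_sql_parser.exceptions import ParsingException
+
+try:
+    # the query ends where a table name is expected
+    parse_sql('select a from')
+except ParsingException as e:
+    print(e)
+    # Syntax error, unexpected end of query:
+    # >select a from
+    # --------------^
+    # Expected symbol: "[identifier]"
+```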
+
+
+# How to test
+
+```bash
+pip install -r requirements_test.txt
+env PYTHONPATH=./ pytest
+```
diff --git a/mindsdb_sql_parser/__about__.py b/mindsdb_sql_parser/__about__.py
new file mode 100644
index 0000000..1ec9c61
--- /dev/null
+++ b/mindsdb_sql_parser/__about__.py
@@ -0,0 +1,10 @@
+__title__ = 'mindsdb_sql_parser'
+__package_name__ = 'mindsdb_sql_parser'
+__version__ = '0.0.1'
+__description__ = "Mindsdb SQL parser"
+__email__ = "jorge@mindsdb.com"
+__author__ = 'MindsDB Inc'
+__github__ = 'https://github.com/mindsdb/mindsdb_sql_parser'
+__pypi__ = 'https://pypi.org/project/mindsdb_sql_parser'
+__license__ = 'MIT'
+__copyright__ = 'Copyright 2024- mindsdb'
diff --git a/mindsdb_sql_parser/__init__.py b/mindsdb_sql_parser/__init__.py
new file mode 100644
index 0000000..3ecc30f
--- /dev/null
+++ b/mindsdb_sql_parser/__init__.py
@@ -0,0 +1,182 @@
+import re
+from collections import defaultdict
+
+from sly.lex import Token
+
+from mindsdb_sql_parser.exceptions import ParsingException
+from mindsdb_sql_parser.ast import *
+
+
+class ErrorHandling:
+
+    def __init__(self, lexer, parser):
+        self.parser = parser
+        self.lexer = lexer
+
+    def process(self, error_info):
+        self.tokens = [t for t in error_info['tokens'] if t is not None]
+        self.bad_token = error_info['bad_token']
+        self.expected_tokens = error_info['expected_tokens']
+
+        if len(self.tokens) == 0:
+            return 'Empty input'
+
+        # show the error location
+        msgs = self.error_location()
+
+        # make a suggestion
+        suggestions = self.make_suggestion()
+
+        if suggestions:
+            prefix = 'Possible inputs: ' if len(suggestions) > 1 else 'Expected symbol: '
+            msgs.append(prefix + ', '.join([f'"{item}"' for item in suggestions]))
+        return '\n'.join(msgs)
+
+    def error_location(self):
+
+        # restore the query text from the tokens
+        lines_idx = defaultdict(str)
+
+        # used + unused tokens
+        for token in self.tokens:
+            if token is None:
+                continue
+            line = lines_idx[token.lineno]
+
+            if len(line) > token.index:
+                line = line[: token.index]
+            else:
+                line = line.ljust(token.index)
+
+            line += token.value
+            lines_idx[token.lineno] = line
+
+        msgs = []
+
+        # error message and location
+        if self.bad_token is None:
+            msgs.append('Syntax error, unexpected end of query:')
+            error_len = 1
+            # last line
+            error_line_num = list(lines_idx.keys())[-1]
+            error_index = len(lines_idx[error_line_num])
+        else:
+            msgs.append('Syntax error, unknown input:')
+            error_len = len(self.bad_token.value)
+            error_line_num = self.bad_token.lineno
+            error_index = self.bad_token.index
+
+        # shift line indexes (this removes spaces from the beginnings of the lines)
+        lines = []
+        shift = 0
+        error_line = 0
+        for i, line_num in enumerate(lines_idx.keys()):
+            if line_num == error_line_num:
+                error_index -= shift
+                error_line = i
+
+            line = lines_idx[line_num]
+            lines.append(line[shift:])
+            shift = len(line)
+
+        # add the source code
+        first_line = error_line - 2 if error_line > 1 else 0
+        for line in lines[first_line: error_line + 1]:
+            msgs.append('>' + line)
+
+        # error position
+        msgs.append('-' * (error_index + 1) + '^' * error_len)
+        return msgs
+
+    def make_suggestion(self):
+        if len(self.expected_tokens) == 0:
+            return []
+
+        # find the error index
+        error_index = None
+        for i, token in enumerate(self.tokens):
+            if token is self.bad_token:
+                error_index = i
+
+        expected = {}  # value: token
+
+        for token_name in self.expected_tokens:
+            value = getattr(self.lexer, token_name, None)
+            if token_name == 'ID':
+                # a lot of other tokens could be an ID
+                expected = {'[identifier]': token_name}
+                break
+            elif token_name in ('FLOAT', 'INTEGER'):
+                expected['[number]'] = token_name
+
+            elif token_name in ('DQUOTE_STRING', 'QUOTE_STRING'):
+                expected['[string]'] = token_name
+
+            elif isinstance(value, str):
+                value = value.replace('\\b', '').replace('\\', '')
+
+                # the value doesn't contain a regexp
+                if '\\s' not in value and '|' not in value:
+                    expected[value] = token_name
+
+        suggestions = []
+        if len(expected) == 1:
+            # use the only option
+            first_value = list(expected.keys())[0]
+            suggestions.append(first_value)
+
+        elif 1 < len(expected) < 20:
+            if self.bad_token is None:
+                # if this is the end of the query, just show the next expected keywords
+                return list(expected.keys())
+
+            # not every suggestion satisfies the end of the query; we have to check that it works
+            for value, token_name in expected.items():
+                # make up a token
+                token = Token()
+                token.type = token_name
+                token.value = value
+                token.end = 0
+                token.index = 0
+                token.lineno = 0
+
+                # try to add the token
+                tokens2 = self.tokens[:error_index] + [token] + self.tokens[error_index:]
+                if self.query_is_valid(tokens2):
+                    suggestions.append(value)
+                    continue
+
+                # try to replace the token
+                tokens2 = self.tokens[:error_index - 1] + [token] + self.tokens[error_index:]
+                if self.query_is_valid(tokens2):
+                    suggestions.append(value)
+                    continue
+
+        return suggestions
+
+    def query_is_valid(self, tokens):
+        # try to parse the list of tokens
+        ast = self.parser.parse(iter(tokens))
+        return ast is not None
+
+
+def parse_sql(sql):
+    from mindsdb_sql_parser.lexer import MindsDBLexer
+    from mindsdb_sql_parser.parser import MindsDBParser
+    lexer, parser = MindsDBLexer(), MindsDBParser()
+
+    # remove the trailing semicolon and spaces
+    sql = re.sub(r'[\s;]+$', '', sql)
+
+    tokens = lexer.tokenize(sql)
+    ast = parser.parse(tokens)
+
+    if ast is None:
+        eh = ErrorHandling(lexer, parser)
+        message = eh.process(parser.error_info)
+
+        raise ParsingException(message)
+
+    return ast
diff --git a/mindsdb_sql_parser/ast/__init__.py b/mindsdb_sql_parser/ast/__init__.py
new file mode 100644
index 0000000..0b999e1
--- /dev/null
+++ b/mindsdb_sql_parser/ast/__init__.py
@@ -0,0 +1,19 @@
+from .base import ASTNode
+from .select import *
+from .show import *
+from .use import *
+from .describe import *
+from .set import *
+from .start_transaction import *
+from .rollback_transaction import *
+from .commit_transaction import *
+from .explain import *
+from .alter_table import *
+from .insert import *
+from .update import *
+from .delete import *
+from .drop import *
+from .create import *
+from .variable import *
+
+from .mindsdb.latest import Latest
diff --git a/mindsdb_sql_parser/ast/alter_table.py b/mindsdb_sql_parser/ast/alter_table.py
new file mode 100644
index 0000000..3592bb8
--- /dev/null
+++ b/mindsdb_sql_parser/ast/alter_table.py
@@ -0,0 +1,31 @@
+from mindsdb_sql_parser.ast.base import ASTNode
+from mindsdb_sql_parser.utils import indent
+
+
+class Alter(ASTNode):
+    ...
+ + +class AlterTable(ASTNode): + def __init__(self, + target, + arg, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.target = target + self.arg = arg + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + target_str = f'target={self.target.to_tree(level=level+2)}, ' + arg_str = f'arg={repr(self.arg)},' + + out_str = f'{ind}AlterTable(' \ + f'{target_str}' \ + f'{arg_str}' \ + f'\n{ind})' + return out_str + + def get_string(self, *args, **kwargs): + return f'ALTER TABLE {str(self.target)} {self.arg}' + diff --git a/mindsdb_sql_parser/ast/base.py b/mindsdb_sql_parser/ast/base.py new file mode 100644 index 0000000..b67890b --- /dev/null +++ b/mindsdb_sql_parser/ast/base.py @@ -0,0 +1,52 @@ +import copy + +from mindsdb_sql_parser.exceptions import ParsingException +from mindsdb_sql_parser.utils import to_single_line + + +class ASTNode: + def __init__(self, alias=None, parentheses=False): + self.alias = alias + self.parentheses = parentheses + + if self.alias and len(self.alias.parts) > 1: + raise ParsingException('Alias can not contain multiple parts (dots).') + + def maybe_add_alias(self, some_str, alias=True): + if self.alias and alias: + return f'{some_str} AS {self.alias.to_string(alias=False)}' + else: + return some_str + + def maybe_add_parentheses(self, some_str): + if self.parentheses: + return f'({some_str})' + else: + return some_str + + def to_tree(self, *args, **kwargs): + pass + + def get_string(self): + pass + + def to_string(self, alias=True): + return self.maybe_add_alias(self.maybe_add_parentheses(self.get_string()), alias=alias) + + def copy(self): + return copy.deepcopy(self) + + def __str__(self): + return self.to_string() + + def __eq__(self, other): + if isinstance(other, ASTNode): + return self.to_tree() == other.to_tree() and to_single_line(str(self)) == to_single_line(str(other)) + else: + return False + + def __repr__(self): + sql = self.to_string().replace('\n', ' ') + if len(sql) > 500: + sql = sql[:500] + '...' 
+ return f'{self.__class__.__name__}:<{sql}>' diff --git a/mindsdb_sql_parser/ast/commit_transaction.py b/mindsdb_sql_parser/ast/commit_transaction.py new file mode 100644 index 0000000..2bdcb48 --- /dev/null +++ b/mindsdb_sql_parser/ast/commit_transaction.py @@ -0,0 +1,16 @@ +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent + + +class CommitTransaction(ASTNode): + def __init__(self, + *args, **kwargs): + super().__init__(*args, **kwargs) + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + out_str = f'{ind}CommitTransaction()' + return out_str + + def get_string(self, *args, **kwargs): + return f'commit' diff --git a/mindsdb_sql_parser/ast/create.py b/mindsdb_sql_parser/ast/create.py new file mode 100644 index 0000000..93798a9 --- /dev/null +++ b/mindsdb_sql_parser/ast/create.py @@ -0,0 +1,115 @@ +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent +from typing import List + +try: + from sqlalchemy import types as sa_types +except ImportError: + sa_types = None + + +class TableColumn(): + def __init__(self, name, type='integer', length=None, default=None, + is_primary_key=False, nullable=None): + self.name = name + self.type = type + self.is_primary_key = is_primary_key + self.default = default + self.length = length + self.nullable = nullable + + def __eq__(self, other): + if type(self) != type(other): + return False + + for k in ['name', 'is_primary_key', 'type', 'default', 'length']: + + if getattr(self, k) != getattr(other, k): + return False + + return True + + +class CreateTable(ASTNode): + def __init__(self, + name, + from_select=None, + columns: List[TableColumn] = None, + is_replace=False, + if_not_exists=False, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.name = name + self.is_replace = is_replace + self.from_select = from_select + self.columns = columns + self.if_not_exists = if_not_exists + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + ind1 = indent(level + 1) + ind2 = indent(level + 2) + + replace_str = '' + if self.is_replace: + replace_str = f'{ind1}is_replace=True\n' + + from_select_str = '' + if self.from_select is not None: + from_select_str = f'{ind1}from_select={self.from_select.to_tree(level=level+1)}\n' + + columns_str = '' + if self.columns is not None: + columns = [ + f'{ind2}{col.name}: {col.type}' + for col in self.columns + ] + + columns_str = f'{ind1}columns=\n' + '\n'.join(columns) + + out_str = f'{ind}CreateTable(\n' \ + f'{ind1}if_not_exists={self.if_not_exists},\n' \ + f'{ind1}name={self.name}\n' \ + f'{replace_str}' \ + f'{from_select_str}' \ + f'{columns_str}\n' \ + f'{ind})\n' + return out_str + + def get_string(self, *args, **kwargs): + + replace_str = '' + if self.is_replace: + replace_str = f' OR REPLACE' + + columns_str = '' + if self.columns is not None: + columns = [] + for col in self.columns: + + if not isinstance(col.type, str) and sa_types is not None: + if issubclass(col.type, sa_types.Integer): + type = 'int' + elif issubclass(col.type, sa_types.Float): + type = 'float' + elif issubclass(col.type, sa_types.Text): + type = 'text' + else: + type = str(col.type) + if col.length is not None: + type = f'{type}({col.length})' + col_str = f'{col.name} {type}' + if col.nullable is True: + col_str += ' NULL' + elif col.nullable is False: + col_str += ' NOT NULL' + columns.append(col_str) + + columns_str = '({})'.format(', '.join(columns)) + + from_select_str = '' + if self.from_select is not None: + 
from_select_str = self.from_select.to_string() + + name_str = str(self.name) + return f'CREATE{replace_str} TABLE {"IF NOT EXISTS " if self.if_not_exists else ""}{name_str} {columns_str} {from_select_str}' diff --git a/mindsdb_sql_parser/ast/delete.py b/mindsdb_sql_parser/ast/delete.py new file mode 100644 index 0000000..4219cfe --- /dev/null +++ b/mindsdb_sql_parser/ast/delete.py @@ -0,0 +1,32 @@ +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent + + +class Delete(ASTNode): + def __init__(self, + table, + where=None, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.table = table + self.where = where + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + ind1 = indent(level + 1) + + where_str = f'where=\n{self.where.to_tree(level=level + 2)},' if self.where else '' + + out_str = f'{ind}Delete(\n' \ + f'{ind1}table={self.table.to_tree()}\n' \ + f'{ind1}{where_str}\n' \ + f'{ind})\n' + return out_str + + def get_string(self, *args, **kwargs): + if self.where is not None: + where_str = f' WHERE {self.where.to_string()}' + else: + where_str = '' + + return f'DELETE FROM {str(self.table)}{where_str}' diff --git a/mindsdb_sql_parser/ast/describe.py b/mindsdb_sql_parser/ast/describe.py new file mode 100644 index 0000000..b634af1 --- /dev/null +++ b/mindsdb_sql_parser/ast/describe.py @@ -0,0 +1,33 @@ +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent + + +class Describe(ASTNode): + def __init__(self, + value, + type=None, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.type = type + self.value = value + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + value_str = f'value={self.value.to_tree()}' + + type_str = '' + if self.type is not None: + type_str = f'type={self.type}, ' + + out_str = f'{ind}Describe(' \ + f'{type_str}' \ + f'{value_str}' \ + f'{ind})' + return out_str + + def get_string(self, *args, **kwargs): + type_str = '' + if self.type is not None: + type_str = f' {self.type}' + return f'DESCRIBE{type_str} {str(self.value)}' + diff --git a/mindsdb_sql_parser/ast/drop.py b/mindsdb_sql_parser/ast/drop.py new file mode 100644 index 0000000..e0dd51b --- /dev/null +++ b/mindsdb_sql_parser/ast/drop.py @@ -0,0 +1,106 @@ +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent + + +class Drop(ASTNode): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def to_tree(self, *args, level=0, **kwargs): + pass + + def get_string(self, *args, **kwargs): + pass + + +class DropTables(Drop): + + def __init__(self, + tables, + if_exists=False, + only_temporary=False, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.tables = tables + self.if_exists = if_exists + self.only_temporary = only_temporary + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + + tables_str = ', '.join([i.to_tree() for i in self.tables]) + + out_str = f'{ind}DropTables(' \ + f'[{tables_str}], ' \ + f'if_exists={self.if_exists}, ' \ + f'only_temporary={self.only_temporary}' \ + f')' + return out_str + + def get_string(self, *args, **kwargs): + temporary_str = f'TEMPORARY' if self.only_temporary else '' + exists_str = f'IF EXISTS' if self.if_exists else '' + tables_str = ', '.join([i.to_string() for i in self.tables]) + + return f'DROP {temporary_str} TABLE {exists_str} {tables_str}' + + +class DropDatabase(Drop): + + def __init__(self, + name, + if_exists=False, + *args, **kwargs): + 
super().__init__(*args, **kwargs) + self.name = name + self.if_exists = if_exists + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + name_str = f'name={self.name.to_tree()}' + + out_str = f'{ind}DropDatabase(' \ + f'{name_str}, ' \ + f'if_exists={self.if_exists}' \ + f')' + return out_str + + def get_string(self, *args, **kwargs): + exists_str = f'IF EXISTS ' if self.if_exists else '' + + return f'DROP DATABASE {exists_str}{self.name}' + + +class DropView(Drop): + + def __init__(self, + names, + if_exists=False, + *args, **kwargs): + super().__init__(*args, **kwargs) + + # DROP VIEW removes one or more views. + # https://dev.mysql.com/doc/refman/8.0/en/drop-view.html + + self.names = names + self.if_exists = if_exists + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + names_str = ', '.join([i.to_tree() for i in self.names]) + + out_str = f'{ind}DropView(' \ + f'[{names_str}], ' \ + f'if_exists={self.if_exists}' \ + f')' + return out_str + + def get_string(self, *args, **kwargs): + exists_str = f'IF EXISTS ' if self.if_exists else '' + names_str = ', '.join(map(str, self.names)) + + return f'DROP VIEW {exists_str}{names_str}' + + + + diff --git a/mindsdb_sql_parser/ast/explain.py b/mindsdb_sql_parser/ast/explain.py new file mode 100644 index 0000000..9475756 --- /dev/null +++ b/mindsdb_sql_parser/ast/explain.py @@ -0,0 +1,23 @@ +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent + + +class Explain(ASTNode): + def __init__(self, + target, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.target = target + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + target_str = f'target={self.target.to_tree(level=level+2)},' + + out_str = f'{ind}Explain(' \ + f'{target_str}' \ + f'\n{ind})' + return out_str + + def get_string(self, *args, **kwargs): + return f'EXPLAIN {str(self.target)}' + diff --git a/mindsdb_sql_parser/ast/insert.py b/mindsdb_sql_parser/ast/insert.py new file mode 100644 index 0000000..9ea520c --- /dev/null +++ b/mindsdb_sql_parser/ast/insert.py @@ -0,0 +1,100 @@ +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent +from mindsdb_sql_parser.ast.create import TableColumn +from mindsdb_sql_parser.ast.select.identifier import Identifier +from mindsdb_sql_parser.ast.select.constant import Constant + +class Insert(ASTNode): + + def __init__(self, + table, + columns=None, + values=None, + from_select=None, + is_plain=False, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.table = table + + if columns is not None: + self.columns = [ + self.to_column(col) + for col in columns + ] + else: + self.columns = None + + # TODO require one of [values, from_select] is set + self.values = values + self.from_select = from_select + + # True if values in query are constant (without subselects and operations) + self.is_plain = is_plain + + def to_column(self, col): + if isinstance(col, str): + return TableColumn(col) + elif isinstance(col, Identifier): + return TableColumn(col.parts[0]) + elif isinstance(col, Constant): + return TableColumn(col.value) + return TableColumn(str(col)) + + def to_value(self, val): + if isinstance(val, ASTNode) : + return val.to_string() + return repr(val) + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + ind1 = indent(level + 1) + ind2 = indent(level + 2) + if self.columns is not None: + columns_str = ', '.join([i.name for i in self.columns]) + else: + columns_str = '' + + if 
self.values is not None: + values = [] + for row in self.values: + row_str = f', '.join([self.to_value(i) for i in row]) + values.append(f'{ind2}[{row_str}]') + values_str = f'\n'.join(values) + values_str = f'{ind1}values=[\n{values_str}]\n' + else: + values_str = '' + + if self.from_select is not None: + from_select_str = f'{ind1}from_select=\n{self.from_select.to_tree(level=level+2)}\n' + else: + from_select_str = '' + + out_str = f'{ind}Insert(table={self.table.to_tree()}\n' \ + f'{ind1}columns=[{columns_str}]\n' \ + f'{values_str}' \ + f'{from_select_str}' \ + f'{ind})\n' + return out_str + + def get_string(self, *args, **kwargs): + if self.columns is not None: + cols = ', '.join([i.name for i in self.columns]) + columns_str = f'({cols})' + else: + columns_str = '' + + if self.values is not None: + values = [] + for row in self.values: + row_str = ', '.join([self.to_value(i) for i in row]) + values.append(f'({row_str})') + values_str = 'VALUES ' + ', '.join(values) + else: + values_str = '' + + if self.from_select is not None: + from_select_str = self.from_select.to_string() + else: + from_select_str = '' + + return f'INSERT INTO {str(self.table)}{columns_str} {values_str}{from_select_str}' diff --git a/mindsdb_sql_parser/ast/mindsdb/__init__.py b/mindsdb_sql_parser/ast/mindsdb/__init__.py new file mode 100644 index 0000000..7323f2f --- /dev/null +++ b/mindsdb_sql_parser/ast/mindsdb/__init__.py @@ -0,0 +1,23 @@ +from .agents import CreateAgent, DropAgent, UpdateAgent +from .create_view import CreateView +from .create_database import CreateDatabase +from .create_predictor import CreatePredictor, CreateAnomalyDetectionModel +from .drop_predictor import DropPredictor +from .retrain_predictor import RetrainPredictor +from .finetune_predictor import FinetunePredictor +from .drop_datasource import DropDatasource +from .drop_dataset import DropDataset +from .evaluate import Evaluate +from .latest import Latest +from .create_ml_engine import CreateMLEngine +from .drop_ml_engine import DropMLEngine +from .create_job import CreateJob +from .drop_job import DropJob +from .chatbot import CreateChatBot, UpdateChatBot, DropChatBot +from .trigger import CreateTrigger, DropTrigger +from .knowledge_base import CreateKnowledgeBase, DropKnowledgeBase +from .skills import CreateSkill, DropSkill, UpdateSkill + +# remove it in next release +CreateDatasource = CreateDatabase + diff --git a/mindsdb_sql_parser/ast/mindsdb/agents.py b/mindsdb_sql_parser/ast/mindsdb/agents.py new file mode 100644 index 0000000..2696ea2 --- /dev/null +++ b/mindsdb_sql_parser/ast/mindsdb/agents.py @@ -0,0 +1,94 @@ +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent + + +class CreateAgent(ASTNode): + """ + Node for creating a new agent + """ + + def __init__(self, name, model, params, if_not_exists=False, *args, **kwargs): + """ + Parameters: + name (Identifier): name of the agent to create + model (str): name of the underlying model to use with the agent + params (dict): USING parameters to create the agent with + if_not_exists (bool): if True, do not raise an error if the agent exists + """ + super().__init__(*args, **kwargs) + self.name = name + self.model = model + self.params = params + self.if_not_exists = if_not_exists + + def to_tree(self, level=0, *args, **kwargs): + ind = indent(level) + out_str = f'{ind}CreateAgent(' \ + f'if_not_exists={self.if_not_exists}' \ + f'name={self.name.to_string()}, ' \ + f'model={self.model}, ' \ + f'params={self.params})' + return out_str + + def 
get_string(self, *args, **kwargs): + using_ar = [f'model={repr(self.model)}'] + using_ar += [f'{k}={repr(v)}' for k, v in self.params.items()] + using_str = ', '.join(using_ar) + + out_str = f'CREATE AGENT {"IF NOT EXISTS " if self.if_not_exists else ""}{self.name.to_string()} USING {using_str}' + return out_str + + +class UpdateAgent(ASTNode): + """ + Node for updating an agent + """ + + def __init__(self, name, updated_params, *args, **kwargs): + """ + Parameters: + name (Identifier): name of the agent to update + updated_params (dict): new SET parameters of the agent to update + """ + super().__init__(*args, **kwargs) + self.name = name + self.params = updated_params + + def to_tree(self, level=0, *args, **kwargs): + ind = indent(level) + out_str = f'{ind}UpdateAgent(' \ + f'name={self.name.to_string()}, ' \ + f'updated_params={self.params})' + return out_str + + def get_string(self, *args, **kwargs): + set_ar = [f'{k}={repr(v)}' for k, v in self.params.items()] + set_str = ', '.join(set_ar) + + out_str = f'UPDATE AGENT {self.name.to_string()} SET {set_str}' + return out_str + + +class DropAgent(ASTNode): + """ + Node for dropping an agent + """ + + def __init__(self, name, if_exists=False, *args, **kwargs): + """ + Parameters: + name (Identifier): name of the agent to drop + if_exists (bool): if True, do not raise an error if the agent does not exist + """ + super().__init__(*args, **kwargs) + self.name = name + self.if_exists = if_exists + + def to_tree(self, level=0, *args, **kwargs): + ind = indent(level) + out_str = f'{ind}DropAgent(if_exists={self.if_exists}, name={self.name.to_string()})' + return out_str + + def get_string(self, *args, **kwargs): + out_str = f'DROP AGENT {"IF EXISTS " if self.if_exists else ""}{str(self.name.to_string())}' + return out_str diff --git a/mindsdb_sql_parser/ast/mindsdb/chatbot.py b/mindsdb_sql_parser/ast/mindsdb/chatbot.py new file mode 100644 index 0000000..22bf75e --- /dev/null +++ b/mindsdb_sql_parser/ast/mindsdb/chatbot.py @@ -0,0 +1,91 @@ +import json +import datetime as dt + +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent + + +class CreateChatBot(ASTNode): + def __init__(self, + name, + database, + model, + agent, + params=None, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.name = name + self.database=database + self.model = model + self.agent = agent + if params is None: + params = {} + self.params = params + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + model_str = self.model.to_string() if self.model else 'NULL' + agent_str = self.agent.to_string() if self.agent else 'NULL' + out_str = f'{ind}CreateChatBot(' \ + f'name={self.name.to_string()}, ' \ + f'database={self.database.to_string()}, ' \ + f'model={model_str}, ' \ + f'agent={agent_str}, ' \ + f'params={self.params})' + return out_str + + def get_string(self, *args, **kwargs): + + params = self.params.copy() + params['model'] = self.model.to_string() if self.model else 'NULL' + params['database'] = self.database.to_string() + if self.agent: + params['agent'] = self.agent.to_string() + + using_ar = [f'{k}={repr(v)}' for k, v in params.items()] + + using_str = ', '.join(using_ar) + + out_str = f'CREATE CHATBOT {self.name.to_string()} USING {using_str}' + return out_str + + +class UpdateChatBot(ASTNode): + def __init__(self, name, updated_params, *args, **kwargs): + super().__init__(*args, **kwargs) + self.name = name + self.params = updated_params + + def to_tree(self, *args, level=0, **kwargs): + ind 
= indent(level) + out_str = f'{ind}UpdateChatBot(' \ + f'name={self.name.to_string()}, ' \ + f'updated_params={self.params})' + return out_str + + def get_string(self, *args, **kwargs): + params = self.params.copy() + + set_ar = [f'{k}={repr(v)}' for k, v in params.items()] + set_str = ', '.join(set_ar) + + out_str = f'UPDATE CHATBOT {self.name.to_string()} SET {set_str}' + return out_str + + +class DropChatBot(ASTNode): + def __init__(self, + name, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.name = name + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + out_str = f'{ind}DropChatBot(name={self.name.to_string()})' + return out_str + + def get_string(self, *args, **kwargs): + out_str = f'DROP CHATBOT {str(self.name.to_string())}' + return out_str + diff --git a/mindsdb_sql_parser/ast/mindsdb/create_database.py b/mindsdb_sql_parser/ast/mindsdb/create_database.py new file mode 100644 index 0000000..ed33998 --- /dev/null +++ b/mindsdb_sql_parser/ast/mindsdb/create_database.py @@ -0,0 +1,54 @@ +import json +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent + + +class CreateDatabase(ASTNode): + def __init__(self, + name, + engine, + parameters, + is_replace=False, + if_not_exists=False, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.name = name + self.engine = engine + self.parameters = parameters + self.is_replace = is_replace + self.if_not_exists = if_not_exists + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + ind1 = indent(level+1) + name_str = f'\n{ind1}name={self.name.to_string()},' + engine_str = f'\n{ind1}engine={repr(self.engine)},' + parameters_str = f'\n{ind1}parameters={str(self.parameters)},' + + replace_str = '' + if self.is_replace: + replace_str = f'\n{ind1}is_replace=True' + + out_str = f'{ind}CreateDatabase(' \ + f'\n{ind1}if_not_exists={self.if_not_exists},' \ + f'{name_str}' \ + f'{engine_str}' \ + f'{parameters_str}' \ + f'{replace_str}' \ + f'\n{ind})' + return out_str + + def get_string(self, *args, **kwargs): + replace_str = '' + if self.is_replace: + replace_str = f' OR REPLACE' + + engine_str = '' + if self.engine: + engine_str = f'ENGINE = {repr(self.engine)} ' + + parameters_str = '' + if self.parameters: + parameters_str = f', PARAMETERS = {json.dumps(self.parameters)}' + out_str = f'CREATE{replace_str} DATABASE {"IF NOT EXISTS " if self.if_not_exists else ""}{self.name.to_string()} {engine_str}{parameters_str}' + return out_str diff --git a/mindsdb_sql_parser/ast/mindsdb/create_job.py b/mindsdb_sql_parser/ast/mindsdb/create_job.py new file mode 100644 index 0000000..9afde27 --- /dev/null +++ b/mindsdb_sql_parser/ast/mindsdb/create_job.py @@ -0,0 +1,85 @@ +import json +import datetime as dt + +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent + + +class CreateJob(ASTNode): + def __init__(self, + name, + query_str, + start_str=None, + end_str=None, + repeat_str=None, + if_query_str=None, + if_not_exists=False, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.name = name + self.query_str=query_str + self.start_str = start_str + self.end_str = end_str + self.repeat_str = repeat_str + self.date_format = '%Y-%m-%d %H:%M:%S' + self.if_not_exists = if_not_exists + self.if_query_str = if_query_str + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + ind1 = indent(level+1) + name_str = f'\n{ind1}name={self.name.to_string()},' + + query_str = 
f'\n{ind1}query_str={repr(self.query_str)},' + + start_str = '' + if self.start_str is not None: + start_str = f'\n{ind1}start_str=\'{self.start_str}\',' + + end_str = '' + if self.end_str is not None: + end_str = f'\n{ind1}end_str=\'{self.end_str}\',' + + repeat_str = '' + if self.repeat_str is not None: + repeat_str = f'\n{ind1}repeat_str={self.repeat_str},' + + if_not_exists_str = '' + if self.if_not_exists: + if_not_exists_str = f'\n{ind1}if_not_exists=True,' + + if_query_str = '' + if self.if_query_str is not None: + if_query_str = f"\n{ind1}if_query='{self.if_query_str}'" + + out_str = f'{ind}CreateJob(' \ + f'{if_not_exists_str}' \ + f'{name_str}' \ + f'{query_str}' \ + f'{start_str}' \ + f'{end_str}' \ + f'{repeat_str}' \ + f'{if_query_str}' \ + f'\n{ind})' + return out_str + + def get_string(self, *args, **kwargs): + + start_str = '' + if self.start_str is not None: + start_str = f" START '{self.start_str}'" + + end_str = '' + if self.end_str is not None: + end_str = f" END '{self.end_str}'" + + repeat_str = '' + if self.repeat_str is not None: + repeat_str = f" EVERY '{self.repeat_str}'" + + if_query_str = '' + if self.if_query_str is not None: + if_query_str = f" IF ({self.if_query_str})" + + out_str = f'CREATE JOB {"IF NOT EXISTS" if self.if_not_exists else ""} {self.name.to_string()} ({self.query_str}){start_str}{end_str}{repeat_str}{if_query_str}' + return out_str diff --git a/mindsdb_sql_parser/ast/mindsdb/create_ml_engine.py b/mindsdb_sql_parser/ast/mindsdb/create_ml_engine.py new file mode 100644 index 0000000..380707b --- /dev/null +++ b/mindsdb_sql_parser/ast/mindsdb/create_ml_engine.py @@ -0,0 +1,42 @@ +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent + + +class CreateMLEngine(ASTNode): + def __init__(self, + name, + handler, + params=None, + if_not_exists=False, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.name = name + self.handler = handler + self.params = params + self.if_not_exists = if_not_exists + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + ind1 = indent(level+1) + + param_str = f'{repr(self.params)},' + + out_str = f'{ind}CreateMLEngine(' \ + f'\n{ind1}if_not_exists={self.if_not_exists}' \ + f'\n{ind1}name={self.name.to_tree()}' \ + f'\n{ind1}handler={self.handler}' \ + f'\n{ind1}using={param_str}' \ + f'\n{ind})' + return out_str + + def get_string(self, *args, **kwargs): + using_str = '' + if self.params is not None: + using_ar = [f'{k}={repr(v)}' for k, v in self.params.items()] + + using_str = 'USING ' + ', '.join(using_ar) + + out_str = f'CREATE ML_ENGINE {"IF NOT EXISTS" if self.if_not_exists else ""} {self.name.to_string()} FROM {self.handler} {using_str}' + + return out_str.strip() + diff --git a/mindsdb_sql_parser/ast/mindsdb/create_predictor.py b/mindsdb_sql_parser/ast/mindsdb/create_predictor.py new file mode 100644 index 0000000..51e8bc6 --- /dev/null +++ b/mindsdb_sql_parser/ast/mindsdb/create_predictor.py @@ -0,0 +1,150 @@ +import json +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent +from mindsdb_sql_parser.ast.select import Identifier +from mindsdb_sql_parser.ast.select.operation import Object + + +class CreatePredictorBase(ASTNode): + def __init__(self, + name, + targets=None, + integration_name=None, + query_str=None, + order_by=None, + group_by=None, + window=None, + horizon=None, + using=None, + is_replace=False, + if_not_exists=False, + task=None, + *args, **kwargs): + super().__init__(*args, **kwargs) + 
self.name = name + self.integration_name = integration_name + self.query_str = query_str + self.targets = targets + self.order_by = order_by + self.group_by = group_by + self.window = window + self.horizon = horizon + self.using = using + self.is_replace = is_replace + self.if_not_exists = if_not_exists + self.task = task + self._action = 'CREATE' + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + ind1 = indent(level+1) + + name_str = f'\n{ind1}name={self.name.to_tree()},' + + if self.integration_name is not None: + integration_name_str = f'\n{ind1}integration_name={self.integration_name.to_tree()},' + else: + integration_name_str = 'None' + + query_str = f'\n{ind1}query={self.query_str},' + + if self.targets is not None: + target_trees = ',\n'.join([t.to_tree(level=level+2) for t in self.targets]) + targets_str = f'\n{ind1}targets=[\n{target_trees}\n{ind1}],' + else: + targets_str = '' + + group_by_str = '' + if self.group_by: + group_by_trees = ',\n'.join([t.to_tree(level=level+2) for t in self.group_by]) + group_by_str = f'\n{ind1}group_by=[\n{group_by_trees}\n{ind1}],' + + order_by_str = '' + if self.order_by: + order_by_trees = ',\n'.join([t.to_tree(level=level + 2) for t in self.order_by]) + order_by_str = f'\n{ind1}order_by=[\n{order_by_trees}\n{ind1}],' + + window_str = f'\n{ind1}window={repr(self.window)},' + horizon_str = f'\n{ind1}horizon={repr(self.horizon)},' + using_str = f'\n{ind1}using={repr(self.using)},' + + if_not_exists_str = f'\n{ind1}if_not_exists={self.if_not_exists},' if self.if_not_exists else '' + or_replace_str = f'\n{ind1}is_replace={self.is_replace},' if self.is_replace else '' + + out_str = f'{ind}{self.__class__.__name__}(' \ + f'{or_replace_str}' \ + f'{if_not_exists_str}' \ + f'{name_str}' \ + f'{integration_name_str}' \ + f'{query_str}' \ + f'{targets_str}' \ + f'{order_by_str}' \ + f'{group_by_str}' \ + f'{window_str}' \ + f'{horizon_str}' \ + f'{using_str}' \ + f'\n{ind})' + return out_str + + def get_string(self, *args, **kwargs): + if self.targets is not None: + targets_str = 'PREDICT ' + ', '.join([out.to_string() for out in self.targets]) + else: + targets_str = '' + order_by_str = f'ORDER BY {", ".join([out.to_string() for out in self.order_by])} ' if self.order_by else '' + group_by_str = f'GROUP BY {", ".join([out.to_string() for out in self.group_by])} ' if self.group_by else '' + window_str = f'WINDOW {self.window} ' if self.window is not None else '' + horizon_str = f'HORIZON {self.horizon} ' if self.horizon is not None else '' + using_str = '' + if self.using: + using_ar = [] + for key, value in self.using.items(): + if isinstance(value, Object): + args = [ + f'{k}={json.dumps(v)}' + for k, v in value.params.items() + ] + args_str = ', '.join(args) + value = f'{value.type}({args_str})' + else: + value = json.dumps(value) + + using_ar.append(f'{Identifier(key).to_string()}={value}') + + using_str = f'USING ' + ', '.join(using_ar) + + query_str = '' + if self.query_str is not None: + integration_name_str = '' + if self.integration_name is not None: + integration_name_str = f' {self.integration_name.to_string()}' + + query_str = f'FROM{integration_name_str} ({self.query_str}) ' + + or_replace_str = ' OR REPLACE' if self.is_replace else '' + if_not_exists_str = 'IF NOT EXISTS ' if self.if_not_exists else '' + object_str = self._object + ' ' if self._object else '' + + out_str = f'{self._action}{or_replace_str} {object_str}{if_not_exists_str}{self.name.to_string()} {query_str}' \ + f'{targets_str} ' \ + f'{order_by_str}' \ + 
f'{group_by_str}' \ + f'{window_str}' \ + f'{horizon_str}' \ + f'{using_str}' + + return out_str.strip() + + +class CreatePredictor(CreatePredictorBase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._object = 'MODEL' + + +# Models by task type +class CreateAnomalyDetectionModel(CreatePredictorBase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._object = 'ANOMALY DETECTION MODEL' + self.task = Identifier('AnomalyDetection') diff --git a/mindsdb_sql_parser/ast/mindsdb/create_view.py b/mindsdb_sql_parser/ast/mindsdb/create_view.py new file mode 100644 index 0000000..795fcde --- /dev/null +++ b/mindsdb_sql_parser/ast/mindsdb/create_view.py @@ -0,0 +1,44 @@ +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent +from mindsdb_sql_parser.ast.select.identifier import Identifier + +class CreateView(ASTNode): + def __init__(self, + name, + query_str, + from_table=None, + if_not_exists=False, + *args, **kwargs): + super().__init__(*args, **kwargs) + # todo remove it + if isinstance(name, Identifier): + name = name.to_string() + self.name = name + self.query_str = query_str + self.from_table = from_table + self.if_not_exists = if_not_exists + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + ind1 = indent(level+1) + name_str = f'\n{ind1}name={repr(self.name)},' + # name_str = f'\n{ind1}name={self.name.to_string()},' + from_table_str = f'\n{ind1}from_table=\n{self.from_table.to_tree(level=level+2)},' if self.from_table else '' + query_str = f'\n{ind1}query="{self.query_str}"' + if_not_exists_str = f'\n{ind1}if_not_exists=True,' if self.if_not_exists else '' + + out_str = f'{ind}CreateView(' \ + f'{if_not_exists_str}' \ + f'{name_str}' \ + f'{query_str}' \ + f'{from_table_str}' \ + f'\n{ind})' + return out_str + + def get_string(self, *args, **kwargs): + from_str = f'FROM {str(self.from_table)} ' if self.from_table else '' + # out_str = f'CREATE VIEW {self.name.to_string()} {from_str}AS ( {self.query_str} )' + out_str = f'CREATE VIEW {"IF NOT EXISTS " if self.if_not_exists else ""}{str(self.name)} {from_str}AS ( {self.query_str} )' + + return out_str + diff --git a/mindsdb_sql_parser/ast/mindsdb/drop_dataset.py b/mindsdb_sql_parser/ast/mindsdb/drop_dataset.py new file mode 100644 index 0000000..ee5ef0c --- /dev/null +++ b/mindsdb_sql_parser/ast/mindsdb/drop_dataset.py @@ -0,0 +1,28 @@ +from mindsdb_sql_parser.ast.drop import Drop +from mindsdb_sql_parser.utils import indent + + +class DropDataset(Drop): + def __init__(self, + name, + if_exists=False, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.name = name + self.if_exists = if_exists + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + ind1 = indent(level+1) + name_str = f'\n{ind1}name={self.name.to_tree()},' + + out_str = f'{ind}DropDataset(' \ + f'{ind1}if_exists={self.if_exists},' \ + f'{name_str}' \ + f'\n{ind})' + return out_str + + def get_string(self, *args, **kwargs): + out_str = f'DROP DATASET {"IF EXISTS " if self.if_exists else ""}{str(self.name)}' + return out_str + diff --git a/mindsdb_sql_parser/ast/mindsdb/drop_datasource.py b/mindsdb_sql_parser/ast/mindsdb/drop_datasource.py new file mode 100644 index 0000000..d7cb824 --- /dev/null +++ b/mindsdb_sql_parser/ast/mindsdb/drop_datasource.py @@ -0,0 +1,28 @@ +from mindsdb_sql_parser.ast.drop import Drop +from mindsdb_sql_parser.utils import indent + + +class DropDatasource(Drop): + def __init__(self, + name, + 
if_exists=False, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.name = name + self.if_exists = if_exists + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + ind1 = indent(level+1) + name_str = f'\n{ind1}name={self.name.to_tree()},' + + out_str = f'{ind}DropDatasource(' \ + f'{ind1}if_exists={self.if_exists},' \ + f'{name_str}' \ + f'\n{ind})' + return out_str + + def get_string(self, *args, **kwargs): + out_str = f'DROP DATASOURCE {"IF EXISTS " if self.if_exists else ""}{str(self.name)}' + return out_str + diff --git a/mindsdb_sql_parser/ast/mindsdb/drop_job.py b/mindsdb_sql_parser/ast/mindsdb/drop_job.py new file mode 100644 index 0000000..0eb89fc --- /dev/null +++ b/mindsdb_sql_parser/ast/mindsdb/drop_job.py @@ -0,0 +1,28 @@ +from mindsdb_sql_parser.ast.drop import Drop +from mindsdb_sql_parser.utils import indent + + +class DropJob(Drop): + def __init__(self, + name, + if_exists=False, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.name = name + self.if_exists = if_exists + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + ind1 = indent(level+1) + name_str = f'\n{ind1}name={self.name.to_tree()},' + + out_str = f'{ind}DropJob(' \ + f'{ind1}if_exists={self.if_exists},' \ + f'{name_str}' \ + f'\n{ind})' + return out_str + + def get_string(self, *args, **kwargs): + out_str = f'DROP JOB {"IF EXISTS " if self.if_exists else ""}{str(self.name)}' + return out_str + diff --git a/mindsdb_sql_parser/ast/mindsdb/drop_ml_engine.py b/mindsdb_sql_parser/ast/mindsdb/drop_ml_engine.py new file mode 100644 index 0000000..d3848b7 --- /dev/null +++ b/mindsdb_sql_parser/ast/mindsdb/drop_ml_engine.py @@ -0,0 +1,28 @@ +from mindsdb_sql_parser.ast.drop import Drop +from mindsdb_sql_parser.utils import indent + + +class DropMLEngine(Drop): + def __init__(self, + name, + if_exists=False, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.name = name + self.if_exists = if_exists + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + ind1 = indent(level+1) + name_str = f'\n{ind1}name={self.name.to_tree()},' + + out_str = f'{ind}DropMLEngine(' \ + f'{ind1}if_exists={self.if_exists},' \ + f'{name_str}' \ + f'\n{ind})' + return out_str + + def get_string(self, *args, **kwargs): + out_str = f'DROP ML_ENGINE {"IF EXISTS " if self.if_exists else ""}{str(self.name)}' + return out_str + diff --git a/mindsdb_sql_parser/ast/mindsdb/drop_predictor.py b/mindsdb_sql_parser/ast/mindsdb/drop_predictor.py new file mode 100644 index 0000000..a95e194 --- /dev/null +++ b/mindsdb_sql_parser/ast/mindsdb/drop_predictor.py @@ -0,0 +1,29 @@ +from mindsdb_sql_parser.ast.drop import Drop +from mindsdb_sql_parser.utils import indent + + +class DropPredictor(Drop): + def __init__(self, + name, + if_exists=False, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.name = name + self.if_exists = if_exists + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + ind1 = indent(level+1) + name_str = f'\n{ind1}name={self.name.to_tree()},' + + out_str = f'{ind}DropPredictor(' \ + f'if_exists={self.if_exists}' \ + f'{name_str}' \ + f'\n{ind})' + return out_str + + def get_string(self, *args, **kwargs): + exists_str = f'IF EXISTS ' if self.if_exists else '' + out_str = f'DROP PREDICTOR {exists_str}{str(self.name)}' + return out_str + diff --git a/mindsdb_sql_parser/ast/mindsdb/evaluate.py b/mindsdb_sql_parser/ast/mindsdb/evaluate.py new file mode 100644 index 0000000..e22f511 --- /dev/null +++ 
b/mindsdb_sql_parser/ast/mindsdb/evaluate.py @@ -0,0 +1,37 @@ +from mindsdb_sql_parser.utils import indent +from mindsdb_sql_parser.ast.base import ASTNode + + +class Evaluate(ASTNode): + def __init__(self, + name, + query_str, + using=None, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.name = name + self.using = using + self.query_str = query_str + self.data = None # filled-in by mindsdb, as parse_sql cannot be used at init time due to circular imports + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + ind1 = indent(level + 1) + name_str = f'\n{ind1}name={self.name.to_string()},' + + query_str = f'\n{ind1}query_str={repr(self.query_str)},' + + out_str = f'{ind}Evaluate(' \ + f'{name_str}' \ + f'{query_str}' \ + f'\n{ind})' + return out_str + + def get_string(self, *args, **kwargs): + inner_query_str = self.query_str + out_str = f'EVALUATE {self.name.to_string()} from ({inner_query_str})' + if self.using is not None: + using_str = ", ".join([f"{k}={v}" for k, v in self.using.items()]) + out_str = f'{out_str} USING {using_str}' + out_str += ';' + return out_str diff --git a/mindsdb_sql_parser/ast/mindsdb/finetune_predictor.py b/mindsdb_sql_parser/ast/mindsdb/finetune_predictor.py new file mode 100644 index 0000000..bc46822 --- /dev/null +++ b/mindsdb_sql_parser/ast/mindsdb/finetune_predictor.py @@ -0,0 +1,8 @@ +from .create_predictor import CreatePredictorBase + + +class FinetunePredictor(CreatePredictorBase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._action = 'FINETUNE' + self._object = '' diff --git a/mindsdb_sql_parser/ast/mindsdb/knowledge_base.py b/mindsdb_sql_parser/ast/mindsdb/knowledge_base.py new file mode 100644 index 0000000..2f03ae6 --- /dev/null +++ b/mindsdb_sql_parser/ast/mindsdb/knowledge_base.py @@ -0,0 +1,107 @@ +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent + + +class CreateKnowledgeBase(ASTNode): + """ + Create a new knowledge base + """ + def __init__( + self, + name, + model=None, + storage=None, + from_select=None, + params=None, + if_not_exists=False, + *args, + **kwargs, + ): + """ + Args: + name: Identifier -- name of the knowledge base + model: Identifier -- name of the model to use + storage: Identifier -- name of the storage to use + from_select: SelectStatement -- select statement to use as the source of the knowledge base + params: dict -- additional parameters to pass to the knowledge base. E.g., chunking strategy, etc. 
+ if_not_exists: bool -- if True, do not raise an error if the knowledge base already exists + """ + super().__init__(*args, **kwargs) + self.name = name + self.model = model + self.storage = storage + self.params = params + self.if_not_exists = if_not_exists + self.from_query = from_select + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + storage_str = f"{ind} storage={self.storage.to_string()},\n" if self.storage else "" + model_str = f"{ind} model={self.model.to_string()},\n" if self.model else "" + out_str = f""" + {ind}CreateKnowledgeBase( + {ind} if_not_exists={self.if_not_exists}, + {ind} name={self.name.to_string()}, + {ind} from_query={self.from_query.to_tree(level=level + 1) if self.from_query else None}, + {model_str}{storage_str}{ind} params={self.params} + {ind}) + """ + return out_str + + def get_string(self, *args, **kwargs): + from_query_str = ( + f"FROM ({self.from_query.get_string()})" if self.from_query else "" + ) + + using_ar = [] + if self.storage: + using_ar.append(f" STORAGE={self.storage.to_string()}") + if self.model: + using_ar.append(f" MODEL={self.model.to_string()}") + + params = self.params.copy() + if params: + using_ar += [f"{k}={repr(v)}" for k, v in params.items()] + if using_ar: + using_str = "USING " + ", ".join(using_ar) + else: + using_str = "" + + out_str = ( + f"CREATE KNOWLEDGE_BASE {'IF NOT EXISTS ' if self.if_not_exists else ''}{self.name.to_string()} " + f"{from_query_str} " + f"{using_str}" + ) + + return out_str + + def __repr__(self) -> str: + return self.to_tree() + + +class DropKnowledgeBase(ASTNode): + """ + Delete a knowledge base + """ + def __init__(self, name, if_exists=False, *args, **kwargs): + """ + Args: + name: Identifier -- name of the knowledge base + if_exists: bool -- if True, do not raise an error if the knowledge base does not exist + """ + super().__init__(*args, **kwargs) + self.name = name + self.if_exists = if_exists + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + out_str = ( + f"{ind}DropKnowledgeBase(" + f"{ind} if_exists={self.if_exists}," + f"name={self.name.to_string()})" + ) + return out_str + + def get_string(self, *args, **kwargs): + out_str = f'DROP KNOWLEDGE_BASE {"IF EXISTS " if self.if_exists else ""}{self.name.to_string()}' + return out_str diff --git a/mindsdb_sql_parser/ast/mindsdb/latest.py b/mindsdb_sql_parser/ast/mindsdb/latest.py new file mode 100644 index 0000000..aefd22e --- /dev/null +++ b/mindsdb_sql_parser/ast/mindsdb/latest.py @@ -0,0 +1,12 @@ +from mindsdb_sql_parser.ast.base import ASTNode + + +class Latest(ASTNode): + def __init__(self, *args, **kwargs): + super().__init__(*args, alias=None, parentheses=False, **kwargs) + + def to_tree(self, *args, level=0, **kwargs): + return '\t'*level + 'Latest()' + + def get_string(self, *args, **kwargs): + return 'LATEST' diff --git a/mindsdb_sql_parser/ast/mindsdb/retrain_predictor.py b/mindsdb_sql_parser/ast/mindsdb/retrain_predictor.py new file mode 100644 index 0000000..410a94c --- /dev/null +++ b/mindsdb_sql_parser/ast/mindsdb/retrain_predictor.py @@ -0,0 +1,8 @@ +from .create_predictor import CreatePredictorBase + + +class RetrainPredictor(CreatePredictorBase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._action = 'RETRAIN' + self._object = '' diff --git a/mindsdb_sql_parser/ast/mindsdb/skills.py b/mindsdb_sql_parser/ast/mindsdb/skills.py new file mode 100644 index 0000000..240e12d --- /dev/null +++ b/mindsdb_sql_parser/ast/mindsdb/skills.py @@ -0,0 +1,94 @@ +from 
mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent + + +class CreateSkill(ASTNode): + """ + Node for creating a new skill + """ + + def __init__(self, name, type, params, if_not_exists=False, *args, **kwargs): + """ + Parameters: + name (Identifier): name of the skill to create + type (str): type of the skill to create + params (dict): USING parameters to create the skill with + if_not_exists (bool): if True, do not raise an error if the skill exists + """ + super().__init__(*args, **kwargs) + self.name = name + self.type = type + self.params = params + self.if_not_exists = if_not_exists + + def to_tree(self, level=0, *args, **kwargs): + ind = indent(level) + out_str = f'{ind}CreateSkill(' \ + f'if_not_exists={self.if_not_exists}' \ + f'name={self.name.to_string()}, ' \ + f'type={self.type}, ' \ + f'params={self.params})' + return out_str + + def get_string(self, *args, **kwargs): + using_ar = [f'type={repr(self.type)}'] + using_ar += [f'{k}={repr(v)}' for k, v in self.params.items()] + using_str = ', '.join(using_ar) + + out_str = f'CREATE SKILL {"IF NOT EXISTS " if self.if_not_exists else ""}{self.name.to_string()} USING {using_str}' + return out_str + + +class UpdateSkill(ASTNode): + """ + Node for updating a skill + """ + + def __init__(self, name, updated_params, *args, **kwargs): + """ + Parameters: + name (Identifier): name of the skill to update + updated_params (dict): new SET parameters of the skill to update + """ + super().__init__(*args, **kwargs) + self.name = name + self.params = updated_params + + def to_tree(self, level=0, *args, **kwargs): + ind = indent(level) + out_str = f'{ind}UpdateSkill(' \ + f'name={self.name.to_string()}, ' \ + f'updated_params={self.params})' + return out_str + + def get_string(self, *args, **kwargs): + set_ar = [f'{k}={repr(v)}' for k, v in self.params.items()] + set_str = ', '.join(set_ar) + + out_str = f'UPDATE SKILL {self.name.to_string()} SET {set_str}' + return out_str + + +class DropSkill(ASTNode): + """ + Node for dropping a skill + """ + + def __init__(self, name, if_exists=False, *args, **kwargs): + """ + Parameters: + name (Identifier): name of the skill to drop + if_exists (bool): if True, do not raise an error if the skill does not exist + """ + super().__init__(*args, **kwargs) + self.name = name + self.if_exists = if_exists + + def to_tree(self, level=0, *args, **kwargs): + ind = indent(level) + out_str = f'{ind}DropSkill(if_exists={self.if_exists}, name={self.name.to_string()})' + return out_str + + def get_string(self, *args, **kwargs): + out_str = f'DROP SKILL {"IF EXISTS " if self.if_exists else ""}{str(self.name.to_string())}' + return out_str diff --git a/mindsdb_sql_parser/ast/mindsdb/trigger.py b/mindsdb_sql_parser/ast/mindsdb/trigger.py new file mode 100644 index 0000000..9b9600e --- /dev/null +++ b/mindsdb_sql_parser/ast/mindsdb/trigger.py @@ -0,0 +1,72 @@ +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.ast.drop import Drop + +from mindsdb_sql_parser.utils import indent + + +class CreateTrigger(ASTNode): + def __init__(self, + name, + table, + query_str, + columns=None, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.name = name + self.table = table + self.query_str = query_str + self.columns = columns + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + ind1 = indent(level+1) + name_str = f'\n{ind1}name={self.name.to_string()},' + + table_str = f'\n{ind1}table={self.table.to_string()},' + + query_str = 
f'\n{ind1}query_str={repr(self.query_str)},' + + columns_str = '' + if self.columns: + columns_str = ', '.join([k.to_string() for k in self.columns]) + columns_str = f'\n{ind1}columns=[{columns_str}],' + + out_str = f'{ind}CreateTrigger(' \ + f'{name_str}' \ + f'{table_str}' \ + f'{columns_str}' \ + f'{query_str}' \ + f'\n{ind})' + return out_str + + def get_string(self, *args, **kwargs): + columns_str = '' + if self.columns: + columns_str = ', '.join([k.to_string() for k in self.columns]) + columns_str = f' columns {columns_str}' + + out_str = f'CREATE TRIGGER {self.name.to_string()} ON {self.table.to_string()}{columns_str} ({self.query_str})' + return out_str + + +class DropTrigger(Drop): + def __init__(self, + name, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.name = name + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + ind1 = indent(level+1) + name_str = f'\n{ind1}name={self.name.to_tree()},' + + out_str = f'{ind}DropTrigger(' \ + f'{name_str}' \ + f'\n{ind})' + return out_str + + def get_string(self, *args, **kwargs): + out_str = f'DROP TRIGGER {str(self.name)}' + return out_str + diff --git a/mindsdb_sql_parser/ast/rollback_transaction.py b/mindsdb_sql_parser/ast/rollback_transaction.py new file mode 100644 index 0000000..29b8661 --- /dev/null +++ b/mindsdb_sql_parser/ast/rollback_transaction.py @@ -0,0 +1,16 @@ +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent + + +class RollbackTransaction(ASTNode): + def __init__(self, + *args, **kwargs): + super().__init__(*args, **kwargs) + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + out_str = f'{ind}RollbackTransaction()' + return out_str + + def get_string(self, *args, **kwargs): + return f'rollback' diff --git a/mindsdb_sql_parser/ast/select/__init__.py b/mindsdb_sql_parser/ast/select/__init__.py new file mode 100644 index 0000000..7dd64b6 --- /dev/null +++ b/mindsdb_sql_parser/ast/select/__init__.py @@ -0,0 +1,16 @@ +from .select import Select +from .common_table_expression import CommonTableExpression +from .union import Union, Except, Intersect +from .constant import Constant, NullConstant, Last +from .star import Star +from .identifier import Identifier +from .join import Join +from .type_cast import TypeCast +from .tuple import Tuple +from .operation import (Operation, BinaryOperation, UnaryOperation, BetweenOperation, + Function, WindowFunction, Object, Interval, Exists, NotExists) +from .order_by import OrderBy +from .parameter import Parameter +from .case import Case +from .native_query import NativeQuery +from .data import Data diff --git a/mindsdb_sql_parser/ast/select/case.py b/mindsdb_sql_parser/ast/select/case.py new file mode 100644 index 0000000..d85b048 --- /dev/null +++ b/mindsdb_sql_parser/ast/select/case.py @@ -0,0 +1,65 @@ +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent + + +class Case(ASTNode): + def __init__(self, rules, default=None, arg=None, *args, **kwargs): + super().__init__(*args, **kwargs) + + # structure: + # [ + # [ condition, result ] + # ] + self.arg = arg + self.rules = rules + self.default = default + + def get_value(self, record): + # TODO get value from record using "case" conditions + ... 
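+    # Illustrative sketch (added comment, not in the original source): how a Case
+    # node renders through the get_string() method defined below, built from this
+    # package's own AST classes; the values are assumed examples.
+    #
+    #   Case(rules=[(BinaryOperation('=', args=[Identifier('a'), Constant(1)]),
+    #                Constant('one'))],
+    #        default=Constant('other')).get_string()
+    #   # -> "CASE WHEN a = 1 THEN 'one' ELSE 'other' END"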
+ + def assert_arguments(self): + pass + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + ind1 = indent(level+1) + + # rules + rules_ar = [] + for condition, result in self.rules: + rules_ar.append( + f'{ind1}{condition.to_string()} => {result.to_string()}' + ) + rules_str = '\n'.join(rules_ar) + default_str = '' + if self.default is not None: + default_str = f'{ind1}default => {self.default.to_string()}\n' + + arg_str = '' + if self.arg is not None: + arg_str = f'{ind1}arg => {self.arg.to_string()}\n' + + return f'{ind}Case(\n' \ + f'{arg_str}'\ + f'{rules_str}\n' \ + f'{default_str}' \ + f'{ind})' + + def get_string(self, *args, alias=True, **kwargs): + # rules + rules_ar = [] + for condition, result in self.rules: + rules_ar.append( + f'WHEN {condition.to_string()} THEN {result.to_string()}' + ) + rules_str = ' '.join(rules_ar) + + default_str = '' + if self.default is not None: + default_str = f' ELSE {self.default.to_string()}' + + arg_str = '' + if self.arg is not None: + arg_str = f'{self.arg.to_string()} ' + return f"CASE {arg_str}{rules_str}{default_str} END" diff --git a/mindsdb_sql_parser/ast/select/common_table_expression.py b/mindsdb_sql_parser/ast/select/common_table_expression.py new file mode 100644 index 0000000..a5f283c --- /dev/null +++ b/mindsdb_sql_parser/ast/select/common_table_expression.py @@ -0,0 +1,26 @@ +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent + + +class CommonTableExpression(ASTNode): + def __init__(self, name, query, columns=None, *args, **kwargs): + super().__init__(*args, **kwargs) + self.name = name + self.columns = columns or [] + self.query = query + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + ind1 = indent(level + 1) + name_str = f'\n{ind1}name={self.name.to_tree()}' + columns_str = f'\n{ind1}columns=[{", ".join([c.to_tree() for c in self.columns])}]' + query_str = f'\n{ind1}query=\n{self.query.to_tree(level=level + 2)}' + + out_str = f'{ind}Join({name_str},{columns_str},{query_str}\n{ind})' + return out_str + + def get_string(self, *args, **kwargs): + name_str = self.name.to_string(alias=False) + columns_str = '' if not self.columns else f'( {", ".join([c.to_string(alias=False) for c in self.columns])} )' + query_str = self.query.to_string() + return f'{name_str}{columns_str} AS ( {query_str} )' diff --git a/mindsdb_sql_parser/ast/select/constant.py b/mindsdb_sql_parser/ast/select/constant.py new file mode 100644 index 0000000..69889a3 --- /dev/null +++ b/mindsdb_sql_parser/ast/select/constant.py @@ -0,0 +1,49 @@ +import datetime as dt +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent + + +class Constant(ASTNode): + def __init__(self, value, with_quotes=True, *args, **kwargs): + super().__init__(*args, **kwargs) + self.value = value + self.with_quotes = with_quotes + + def to_tree(self, *args, level=0, **kwargs): + alias_str = f', alias={self.alias.to_tree()}' if self.alias else '' + return indent(level) + f'Constant(value={repr(self.value)}{alias_str})' + + def get_string(self, *args, **kwargs): + if isinstance(self.value, str) and self.with_quotes: + val = self.value.replace("'", "\\'") + out_str = f"\'{val}\'" + elif isinstance(self.value, bool): + out_str = 'TRUE' if self.value else 'FALSE' + elif isinstance(self.value, (dt.date, dt.datetime, dt.timedelta)): + out_str = "'{}'".format(str(self.value).replace("'", "''")) + else: + out_str = str(self.value) + return out_str + + +class NullConstant(Constant): + 
def __init__(self, *args, **kwargs): + super().__init__(value=None, *args, **kwargs) + + def to_tree(self, *args, level=0, **kwargs): + return '\t'*level + 'NullConstant()' + + def get_string(self, *args, **kwargs): + return 'NULL' + + +class Last(Constant): + def __init__(self, *args, **kwargs): + self.value = 'last' + super().__init__(self.value) + + def to_tree(self, *args, level=0, **kwargs): + return indent(level) + f'{self.__class__.__name__}()' + + def get_string(self, *args, **kwargs): + return self.value \ No newline at end of file diff --git a/mindsdb_sql_parser/ast/select/data.py b/mindsdb_sql_parser/ast/select/data.py new file mode 100644 index 0000000..e5e1c61 --- /dev/null +++ b/mindsdb_sql_parser/ast/select/data.py @@ -0,0 +1,19 @@ +from typing import List + +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent + + +class Data(ASTNode): + + def __init__(self, data: List[dict], *args, **kwargs): + super().__init__(*args, **kwargs) + + self.data = data + + def to_tree(self, *args, level=0, **kwargs): + return indent(level) + \ + f'Data(len={len(self.data)})' + + def get_string(self, *args, **kwargs): + return f'"<{len(self.data)} rows>"' diff --git a/mindsdb_sql_parser/ast/select/identifier.py b/mindsdb_sql_parser/ast/select/identifier.py new file mode 100644 index 0000000..1cac18f --- /dev/null +++ b/mindsdb_sql_parser/ast/select/identifier.py @@ -0,0 +1,92 @@ +import re +from copy import copy, deepcopy + +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent +from mindsdb_sql_parser.ast.select import Star + + +no_wrap_identifier_regex = re.compile(r'[a-zA-Z_][a-zA-Z_0-9]*') +path_str_parts_regex = re.compile(r'(?:(?:(`[^`]+`))|([^.]+))') + + +def path_str_to_parts(path_str: str): + match = re.finditer(path_str_parts_regex, path_str) + parts = [x[0].strip('`') for x in match] + return parts + + +RESERVED_KEYWORDS = { + 'PERSIST', 'IF', 'EXISTS', 'NULLS', 'FIRST', 'LAST', + 'ORDER', 'BY', 'GROUP', 'PARTITION' +} + + +def get_reserved_words(): + from mindsdb_sql_parser.lexer import MindsDBLexer + + reserved = RESERVED_KEYWORDS + for word in MindsDBLexer.tokens: + if '_' not in word: + # exclude combinations + reserved.add(word) + return reserved + + +class Identifier(ASTNode): + def __init__(self, path_str=None, parts=None, *args, **kwargs): + super().__init__(*args, **kwargs) + assert path_str or parts, "Either path_str or parts must be provided for an Identifier" + assert not (path_str and parts), "Provide either path_str or parts, but not both" + if isinstance(path_str, Star) and not parts: + parts = [Star()] + + if path_str and not parts: + parts = path_str_to_parts(path_str) + assert isinstance(parts, list) + self.parts = parts + + @classmethod + def from_path_str(self, value, *args, **kwargs): + parts = path_str_to_parts(value) + return Identifier(parts=parts, *args, **kwargs) + + def parts_to_str(self): + out_parts = [] + reserved_words = get_reserved_words() + for part in self.parts: + if isinstance(part, Star): + part = str(part) + else: + if ( + not no_wrap_identifier_regex.fullmatch(part) + or + part.upper() in reserved_words + ): + part = f'`{part}`' + + out_parts.append(part) + return '.'.join(out_parts) + + def to_tree(self, *args, level=0, **kwargs): + alias_str = f', alias={self.alias.to_tree()}' if self.alias else '' + return indent(level) + f'Identifier(parts={[str(i) for i in self.parts]}{alias_str})' + + def get_string(self, *args, **kwargs): + return self.parts_to_str() + + def 
__copy__(self): + identifier = Identifier(parts=copy(self.parts)) + identifier.alias = deepcopy(self.alias) + identifier.parentheses = self.parentheses + if hasattr(self, 'sub_select'): + identifier.sub_select = deepcopy(self.sub_select) + return identifier + + def __deepcopy__(self, memo): + identifier = Identifier(parts=copy(self.parts)) + identifier.alias = deepcopy(self.alias) + identifier.parentheses = self.parentheses + if hasattr(self, 'sub_select'): + identifier.sub_select = deepcopy(self.sub_select) + return identifier diff --git a/mindsdb_sql_parser/ast/select/join.py b/mindsdb_sql_parser/ast/select/join.py new file mode 100644 index 0000000..d13e645 --- /dev/null +++ b/mindsdb_sql_parser/ast/select/join.py @@ -0,0 +1,30 @@ +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent + + +class Join(ASTNode): + def __init__(self, join_type, left, right, condition=None, implicit=False, *args, **kwargs): + super().__init__(*args, **kwargs) + if join_type is not None: + join_type = join_type.upper() + self.join_type = join_type + self.left = left + self.right = right + self.condition = condition + self.implicit = implicit + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + ind1 = indent(level + 1) + args = f'implicit={repr(self.implicit)}, join_type={repr(self.join_type)}' + left_str = f'\n{ind1}left=\n{self.left.to_tree(level=level+2)}' + right_str = f'\n{ind1}right=\n{self.right.to_tree(level=level+2)}' + condition_str = f'\n{ind1}condition=\n{self.condition.to_tree(level=level+2)}' if self.condition else '' + + out_str = f'{ind}Join({args},{left_str},{right_str},{condition_str}\n{ind})' + return out_str + + def get_string(self, *args, **kwargs): + join_type_str = f' {self.join_type} ' if not self.implicit else ', ' + condition_str = f' ON {self.condition.to_string()}' if self.condition else '' + return f'{self.left.to_string()}{join_type_str}{self.right.to_string()}{condition_str}' diff --git a/mindsdb_sql_parser/ast/select/native_query.py b/mindsdb_sql_parser/ast/select/native_query.py new file mode 100644 index 0000000..3585334 --- /dev/null +++ b/mindsdb_sql_parser/ast/select/native_query.py @@ -0,0 +1,25 @@ +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent + + +class NativeQuery(ASTNode): + """ + Not parsed query to integration + """ + + def __init__(self, integration, query: str, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.integration = integration + self.query = query + + def to_tree(self, *args, level=0, **kwargs): + return indent(level) + \ + f'NativeQuery(integration={self.integration.to_string()}, query="{self.query}")' + + def get_string(self, *args, **kwargs): + # standard native query render is used in create view + return f'{self.integration.to_string()} ({self.query})' + + def __repr__(self): + return f'{self.__class__.__name__}:{self.integration.to_string()} ({self.query})' diff --git a/mindsdb_sql_parser/ast/select/operation.py b/mindsdb_sql_parser/ast/select/operation.py new file mode 100644 index 0000000..f2491d5 --- /dev/null +++ b/mindsdb_sql_parser/ast/select/operation.py @@ -0,0 +1,206 @@ +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.exceptions import ParsingException +from mindsdb_sql_parser.utils import indent + + +class Operation(ASTNode): + def __init__(self, op, args, *args_, **kwargs): + super().__init__(*args_, **kwargs) + + self.op = ' '.join(op.lower().split()) + self.args = list(args) + 
self.assert_arguments() + + def assert_arguments(self): + pass + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + ind1 = indent(level+1) + + arg_trees = [arg.to_tree(level=level+2) for arg in self.args] + arg_trees_str = ",\n".join(arg_trees) + out_str = f'{ind}{self.__class__.__name__}(op={repr(self.op)},\n{ind1}args=(\n{arg_trees_str}\n{ind1})\n{ind})' + return out_str + + def get_string(self, *args, alias=True, **kwargs): + arg_strs = [arg.to_string() for arg in self.args] + args_str = ','.join(arg_strs) + + return f'{self.op}({args_str})' + + +class BetweenOperation(Operation): + def __init__(self, *args, **kwargs): + super().__init__(op='between', *args, **kwargs) + + def get_string(self, *args, **kwargs): + arg_strs = [arg.to_string() for arg in self.args] + return f'{arg_strs[0]} BETWEEN {arg_strs[1]} AND {arg_strs[2]}' + + +class BinaryOperation(Operation): + def get_string(self, *args, **kwargs): + arg_strs = [] + for arg in self.args: + arg_str = arg.to_string() + # if isinstance(arg, BinaryOperation) or isinstance(arg, BetweenOperation): + # # to parens + # arg_str = f'({arg_str})' + arg_strs.append(arg_str) + + return f'{arg_strs[0]} {self.op.upper()} {arg_strs[1]}' + + def assert_arguments(self): + if len(self.args) != 2: + raise ParsingException(f'Expected two arguments for operation "{self.op}"') + + +class UnaryOperation(Operation): + def get_string(self, *args, **kwargs): + return f'{self.op} {self.args[0].to_string()}' + + def assert_arguments(self): + if len(self.args) != 1: + raise ParsingException(f'Expected one argument for operation "{self.op}"') + + +class Function(Operation): + def __init__(self, *args, distinct=False, from_arg=None, namespace=None, **kwargs): + super().__init__(*args, **kwargs) + self.distinct = distinct + self.from_arg = from_arg + self.namespace = namespace + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + ind1 = indent(level+1) + + arg_trees = [arg.to_tree(level=level+2) for arg in self.args] + arg_trees_str = ",\n".join(arg_trees) + alias_str = f'alias={self.alias.to_tree()},' if self.alias else '' + from_str = f'from={self.from_arg.to_tree()}' if self.from_arg else '' + out_str = f'{ind}{self.__class__.__name__}(op={repr(self.op)}, distinct={repr(self.distinct)},{alias_str}\n' \ + f'{ind1}args=[\n' \ + f'{arg_trees_str}\n' \ + f'{ind1}]\n' \ + f'{ind1}{from_str}\n' \ + f'{ind})' + return out_str + + def get_string(self, *args, **kwargs): + args_str = ', '.join([arg.to_string() for arg in self.args]) + distinct_str = 'DISTINCT ' if self.distinct else '' + + from_str = f' FROM {self.from_arg.to_string()}' if self.from_arg else '' + namespace = self.namespace + '.' 
if self.namespace else ''
+        return f'{namespace}{self.op}({distinct_str}{args_str}{from_str})'
+
+
+class WindowFunction(ASTNode):
+    def __init__(self, function, partition=None, order_by=None, alias=None, modifier=None):
+        super().__init__()
+        self.function = function
+        self.partition = partition
+        self.order_by = order_by
+        self.alias = alias
+        self.modifier = modifier
+
+    def to_tree(self, *args, level=0, **kwargs):
+        fnc_str = self.function.to_tree(level=level+2)
+        ind = indent(level)
+        ind1 = indent(level+1)
+        partition_str = ''
+        if self.partition is not None:
+            partition_str = ',\n'.join([arg.to_tree(level=level+2) for arg in self.partition])
+            partition_str = f'\n{ind1}partition=\n{partition_str}'
+
+        order_str = ''
+        if self.order_by is not None:
+            order_str = f'\n{ind1}order_by=\n' + ',\n'.join([arg.to_tree(level=level+2) for arg in self.order_by])
+
+        if self.alias is not None:
+            alias_str = f'\n{ind1}alias=' + self.alias.to_string()
+        else:
+            alias_str = ''
+        return f'{ind}WindowFunction(\n' \
+               f'{ind1}function=\n{fnc_str}' \
+               f'{partition_str}' \
+               f'{order_str}' \
+               f'{alias_str}' \
+               f'\n{ind})'
+
+    def to_string(self, *args, **kwargs):
+        fnc_str = self.function.get_string()
+        partition_str = ''
+        if self.partition is not None:
+            partition_str = 'PARTITION BY ' + ', '.join([arg.to_string() for arg in self.partition])
+
+        order_str = ''
+        if self.order_by is not None:
+            order_str = 'ORDER BY ' + ', '.join([arg.to_string() for arg in self.order_by])
+
+        if self.alias is not None:
+            alias_str = self.alias.to_string()
+        else:
+            alias_str = ''
+        modifier_str = ' ' + self.modifier if self.modifier else ''
+        return f'{fnc_str} over({partition_str} {order_str}{modifier_str}) {alias_str}'
+
+
+class Object(ASTNode):
+    def __init__(self, type, params=None, **kwargs):
+        super().__init__(**kwargs)
+
+        self.type = type
+        self.params = params
+
+    def to_tree(self, *args, level=0, **kwargs):
+        ind = indent(level)
+
+        params = [
+            f'{k}={v}'
+            for k, v in self.params.items()
+        ]
+        params_str = ', '.join(params)
+
+        return f'{ind}Object(type={repr(self.type)}, params={params_str})'
+
+    def to_string(self, *args, **kwargs):
+        return self.to_tree()
+
+    def __repr__(self):
+        return self.to_tree()
+
+
+class Interval(Operation):
+
+    def __init__(self, info):
+        super().__init__(op='interval', args=[info, ])
+
+    def get_string(self, *args, **kwargs):
+        arg = self.args[0]
+        items = arg.split(' ', maxsplit=1)
+        # quote first element
+        items[0] = f"'{items[0]}'"
+        return "INTERVAL " + " ".join(items)
+
+    def to_tree(self, *args, level=0, **kwargs):
+        return self.get_string(*args, **kwargs)
+
+    def assert_arguments(self):
+        if len(self.args) != 1:
+            raise ParsingException(f'Expected one argument for operation "{self.op}"')
+
+
+class Exists(Operation):
+    def __init__(self, query):
+        self.query = query
+        super().__init__(op='exists', args=[query])
+
+
+class NotExists(Operation):
+    def __init__(self, query):
+        self.query = query
+        super().__init__(op='not exists', args=[query])
diff --git a/mindsdb_sql_parser/ast/select/order_by.py b/mindsdb_sql_parser/ast/select/order_by.py
new file mode 100644
index 0000000..305c9b3
--- /dev/null
+++ b/mindsdb_sql_parser/ast/select/order_by.py
@@ -0,0 +1,21 @@
+from mindsdb_sql_parser.ast.base import ASTNode
+from mindsdb_sql_parser.utils import indent
+
+
+class OrderBy(ASTNode):
+    def __init__(self, field, direction='default', nulls='default', *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.field = field
+        self.direction = direction
+        self.nulls 
= nulls + + def to_tree(self, *args, level=0, **kwargs): + return indent(level) + f'OrderBy(field={self.field.to_tree()}, direction={repr(self.direction)}, nulls={repr(self.nulls)})' + + def get_string(self, *args, **kwargs): + out_str = self.field.to_string() + if self.direction != 'default': + out_str += f' {self.direction}' + if self.nulls != 'default': + out_str += f' {self.nulls}' + return out_str diff --git a/mindsdb_sql_parser/ast/select/parameter.py b/mindsdb_sql_parser/ast/select/parameter.py new file mode 100644 index 0000000..893086c --- /dev/null +++ b/mindsdb_sql_parser/ast/select/parameter.py @@ -0,0 +1,16 @@ +from mindsdb_sql_parser.ast.base import ASTNode + + +class Parameter(ASTNode): + def __init__(self, value, *args, **kwargs): + super().__init__(*args, **kwargs) + self.value = value + + def __repr__(self): + return f'Parameter({self.value})' + + def to_tree(self, *args, level=0, **kwargs): + return '\t' * level + f'Parameter({repr(self.value)})' + + def get_string(self, *args, **kwargs): + return ':' + str(self.value) diff --git a/mindsdb_sql_parser/ast/select/select.py b/mindsdb_sql_parser/ast/select/select.py new file mode 100644 index 0000000..289aab6 --- /dev/null +++ b/mindsdb_sql_parser/ast/select/select.py @@ -0,0 +1,162 @@ +import json +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent +from mindsdb_sql_parser.ast.select.operation import Object + +class Select(ASTNode): + + def __init__(self, + targets, + distinct=False, + from_table=None, + where=None, + group_by=None, + having=None, + order_by=None, + limit=None, + offset=None, + cte=None, + mode=None, + modifiers=None, + using=None, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.targets = targets + self.distinct = distinct + self.from_table = from_table + self.where = where + self.group_by = group_by + self.having = having + self.order_by = order_by + self.limit = limit + self.offset = offset + self.cte = cte + self.mode = mode + if modifiers is None: + modifiers = [] + self.modifiers = modifiers + self.using = using + + if self.alias: + self.parentheses = True + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + ind1 = indent(level+1) + + cte_str = '' + if self.cte: + cte_trees = ',\n'.join([t.to_tree(level=level + 2) for t in self.cte]) + cte_str = f'\n{ind1}cte=[\n{cte_trees}\n{ind1}],' + + alias_str = f'\n{ind1}alias={self.alias.to_tree()},' if self.alias else '' + distinct_str = f'\n{ind1}distinct={repr(self.distinct)},' if self.distinct else '' + parentheses_str = f'\n{ind1}parentheses={repr(self.parentheses)},' if self.parentheses else '' + + target_trees = ',\n'.join([t.to_tree(level=level+2) for t in self.targets]) + targets_str = f'\n{ind1}targets=[\n{target_trees}\n{ind1}],' + + from_str = f'\n{ind1}from_table=\n{self.from_table.to_tree(level=level+2)},' if self.from_table else '' + where_str = f'\n{ind1}where=\n{self.where.to_tree(level=level+2)},' if self.where else '' + + group_by_str = '' + if self.group_by: + group_by_trees = ',\n'.join([t.to_tree(level=level+2) for t in self.group_by]) + group_by_str = f'\n{ind1}group_by=[\n{group_by_trees}\n{ind1}],' + + having_str = f'\n{ind1}having=\n{self.having.to_tree(level=level+2)},' if self.having else '' + + order_by_str = '' + if self.order_by: + order_by_trees = ',\n'.join([t.to_tree(level=level + 2) for t in self.order_by]) + order_by_str = f'\n{ind1}order_by=[\n{order_by_trees}\n{ind1}],' + limit_str = f'\n{ind1}limit={self.limit.to_tree(level=0)},' if self.limit else 
'' + offset_str = f'\n{ind1}offset={self.offset.to_tree(level=0)},' if self.offset else '' + mode_str = f'\n{ind1}mode={self.mode},' if self.mode else '' + + using_str = '' + if self.using is not None: + using_str = f'\n{ind1}using={repr(self.using)},' + + out_str = f'{ind}Select(' \ + f'{cte_str}' \ + f'{alias_str}' \ + f'{distinct_str}' \ + f'{parentheses_str}' \ + f'{targets_str}' \ + f'{from_str}' \ + f'{where_str}' \ + f'{group_by_str}' \ + f'{having_str}' \ + f'{order_by_str}' \ + f'{limit_str}' \ + f'{offset_str}' \ + f'{mode_str}' \ + f'{using_str}' \ + f'\n{ind})' + return out_str + + def get_string(self, *args, **kwargs): + + out_str = '' + if self.cte is not None: + cte_str = ', '.join([out.to_string() for out in self.cte]) + out_str += f'WITH {cte_str} ' + + out_str += "SELECT" + + if self.distinct: + out_str += ' DISTINCT' + + targets_str = ', '.join([out.to_string() for out in self.targets]) + out_str += f' {targets_str}' + + if self.from_table is not None: + from_table_str = str(self.from_table) + out_str += f' FROM {from_table_str}' + + if self.where is not None: + out_str += f' WHERE {self.where.to_string()}' + + if self.group_by is not None: + group_by_str = ', '.join([out.to_string() for out in self.group_by]) + out_str += f' GROUP BY {group_by_str}' + + if self.having is not None: + having_str = str(self.having) + out_str += f' HAVING {having_str}' + + if self.order_by is not None: + order_by_str = ', '.join([out.to_string() for out in self.order_by]) + out_str += f' ORDER BY {order_by_str}' + + if self.limit is not None: + out_str += f' LIMIT {self.limit.to_string()}' + + if self.offset is not None: + out_str += f' OFFSET {self.offset.to_string()}' + + if self.mode is not None: + out_str += f' {self.mode}' + + if self.using is not None: + from mindsdb_sql_parser.ast.select.identifier import Identifier + + using_ar = [] + for key, value in self.using.items(): + if isinstance(value, Object): + args = [ + f'{k}={json.dumps(v)}' + for k, v in value.params.items() + ] + args_str = ', '.join(args) + value = f'{value.type}({args_str})' + else: + value = json.dumps(value) + + using_ar.append(f'{Identifier(key).to_string()}={value}') + + out_str += f' USING ' + ', '.join(using_ar) + + return out_str + diff --git a/mindsdb_sql_parser/ast/select/star.py b/mindsdb_sql_parser/ast/select/star.py new file mode 100644 index 0000000..61ec509 --- /dev/null +++ b/mindsdb_sql_parser/ast/select/star.py @@ -0,0 +1,16 @@ +from mindsdb_sql_parser.ast import ASTNode +from mindsdb_sql_parser.utils import indent + + +class Star(ASTNode): + def __init__(self, *args, **kwargs): + if 'alias' in kwargs: + from mindsdb_sql_parser.exceptions import ParsingException + raise ParsingException("Can't alias a star!") + super().__init__(*args, **kwargs) + + def to_tree(self, *args, level=0, **kwargs): + return indent(level) + f'Star()' + + def get_string(self, *args, **kwargs): + return '*' diff --git a/mindsdb_sql_parser/ast/select/tuple.py b/mindsdb_sql_parser/ast/select/tuple.py new file mode 100644 index 0000000..fb18996 --- /dev/null +++ b/mindsdb_sql_parser/ast/select/tuple.py @@ -0,0 +1,21 @@ +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent + + +class Tuple(ASTNode): + def __init__(self, items, *args, **kwargs): + super().__init__(*args, **kwargs) + self.items = items + + def to_tree(self, *args, level=0, **kwargs): + item_trees = ','.join([t.to_tree(level=0) for t in self.items]) + + out_str = indent(level) + f'Tuple(items=({item_trees}))' + return 
out_str
+
+    def get_string(self, *args, **kwargs):
+        item_strs = []
+        for item in self.items:
+            item_strs.append(str(item))
+
+        return f'({", ".join(item_strs)})'
diff --git a/mindsdb_sql_parser/ast/select/type_cast.py b/mindsdb_sql_parser/ast/select/type_cast.py
new file mode 100644
index 0000000..7bb956a
--- /dev/null
+++ b/mindsdb_sql_parser/ast/select/type_cast.py
@@ -0,0 +1,22 @@
+from mindsdb_sql_parser.ast.base import ASTNode
+from mindsdb_sql_parser.utils import indent
+
+
+class TypeCast(ASTNode):
+    def __init__(self, type_name, arg, precision=None, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.type_name = type_name
+        self.arg = arg
+        self.precision = precision
+
+    def to_tree(self, *args, level=0, **kwargs):
+        out_str = indent(level) + f'TypeCast(type_name={repr(self.type_name)}, precision={self.precision}, arg=\n{indent(level+1)}{self.arg.to_tree()})'
+        return out_str
+
+    def get_string(self, *args, **kwargs):
+        type_name = self.type_name
+        if self.precision is not None:
+            precision = map(str, self.precision)
+            type_name += f'({",".join(precision)})'
+        return f'CAST({str(self.arg)} AS {type_name})'
diff --git a/mindsdb_sql_parser/ast/select/union.py b/mindsdb_sql_parser/ast/select/union.py
new file mode 100644
index 0000000..a49ea2a
--- /dev/null
+++ b/mindsdb_sql_parser/ast/select/union.py
@@ -0,0 +1,55 @@
+from mindsdb_sql_parser.ast.base import ASTNode
+from mindsdb_sql_parser.utils import indent
+
+
+class CombiningQuery(ASTNode):
+    operation = None
+
+    def __init__(self,
+                 left,
+                 right,
+                 unique=True,
+                 *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.left = left
+        self.right = right
+        self.unique = unique
+
+        if self.alias:
+            self.parentheses = True
+
+    def to_tree(self, *args, level=0, **kwargs):
+        ind = indent(level)
+        ind1 = indent(level+1)
+
+        left_str = f'\n{ind1}left=\n{self.left.to_tree(level=level + 2)},'
+        right_str = f'\n{ind1}right=\n{self.right.to_tree(level=level + 2)},'
+
+        cls_name = self.__class__.__name__
+        out_str = f'{ind}{cls_name}(unique={repr(self.unique)},' \
+                  f'{left_str}' \
+                  f'{right_str}' \
+                  f'\n{ind})'
+        return out_str
+
+    def get_string(self, *args, **kwargs):
+        left_str = str(self.left)
+        right_str = str(self.right)
+        keyword = self.operation
+        if not self.unique:
+            keyword += ' ALL'
+        out_str = f"""{left_str}\n{keyword}\n{right_str}"""
+
+        return out_str
+
+
+class Union(CombiningQuery):
+    operation = 'UNION'
+
+
+class Intersect(CombiningQuery):
+    operation = 'INTERSECT'
+
+
+class Except(CombiningQuery):
+    operation = 'EXCEPT'
diff --git a/mindsdb_sql_parser/ast/set.py b/mindsdb_sql_parser/ast/set.py
new file mode 100644
index 0000000..00c3409
--- /dev/null
+++ b/mindsdb_sql_parser/ast/set.py
@@ -0,0 +1,130 @@
+from mindsdb_sql_parser.ast.base import ASTNode
+from mindsdb_sql_parser.utils import indent
+
+
+class Set(ASTNode):
+    def __init__(self,
+                 category=None,
+                 name=None,
+                 value=None,
+                 scope=None,
+                 params=None,
+                 set_list=None,
+                 *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # names / charset / transactions
+        self.category = category
+
+        # name for variable assignment. category is None in this case
+        self.name = name
+
+        self.value = value
+        self.params = params or {}
+
+        # global / session / ... 
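+        # e.g. "SET GLOBAL autocommit = 1" is parsed with scope='GLOBAL'
+        # (illustrative example, not in the original source)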
+ self.scope = scope + + # contents all set subcommands + self.set_list = set_list + + + def to_tree(self, *args, level=0, **kwargs): + if self.set_list is not None: + items = [set.render() for set in self.set_list] + else: + items = self.render() + + ind = indent(level) + + return f'{ind}Set(items={items})' + + def get_string(self, *args, **kwargs): + return 'SET ' + self.render() + + def render(self): + if self.set_list is not None: + render_list = [set.render() for set in self.set_list] + return ', '.join(render_list) + + if self.params: + params = [] + for k, v in self.params.items(): + if k.lower() == 'access_mode': + params.append(v) + else: + params.append(f'{k} {v}') + param_str = ' ' + ', '.join(params) + else: + param_str = '' + + if self.name is not None: + # category should be empty + content = f'{self.name.to_string()}={self.value.to_string()}' + elif self.value is not None: + content = f'{self.category} {self.value.to_string()}' + else: + content = f'{self.category}' + + scope = '' + if self.scope is not None: + scope = f'{self.scope} ' + + return f'{scope}{content}{param_str}' + + +# class SetTransaction(ASTNode): +# def __init__(self, +# isolation_level=None, +# access_mode=None, +# scope=None, +# *args, **kwargs): +# super().__init__(*args, **kwargs) +# +# if isolation_level is not None: +# isolation_level = isolation_level.upper() +# if access_mode is not None: +# access_mode = access_mode.upper() +# if scope is not None: +# scope = scope.upper() +# +# self.scope = scope +# self.access_mode = access_mode +# self.isolation_level = isolation_level +# +# def to_tree(self, *args, level=0, **kwargs): +# ind = indent(level) +# if self.scope is None: +# scope_str = '' +# else: +# scope_str = f'scope={self.scope}, ' +# +# properties = [] +# if self.isolation_level is not None: +# properties.append('ISOLATION LEVEL ' + self.isolation_level) +# if self.access_mode is not None: +# properties.append(self.access_mode) +# prop_str = ', '.join(properties) +# +# out_str = f'{ind}SetTransaction(' \ +# f'{scope_str}' \ +# f'properties=[{prop_str}]' \ +# f'\n{ind})' +# return out_str +# +# def get_string(self, *args, **kwargs): +# properties = [] +# if self.isolation_level is not None: +# properties.append('ISOLATION LEVEL ' + self.isolation_level) +# if self.access_mode is not None: +# properties.append(self.access_mode) +# +# prop_str = ', '.join(properties) +# +# if self.scope is None: +# scope_str = '' +# else: +# scope_str = self.scope + ' ' +# +# return f'SET {scope_str}TRANSACTION {prop_str}' + diff --git a/mindsdb_sql_parser/ast/show.py b/mindsdb_sql_parser/ast/show.py new file mode 100644 index 0000000..e25294a --- /dev/null +++ b/mindsdb_sql_parser/ast/show.py @@ -0,0 +1,91 @@ +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent + + +class Show(ASTNode): + def __init__(self, + category, + modes=None, + from_table=None, + in_table=None, + where=None, + name=None, + like=None, + *args_, **kwargs): + super().__init__(*args_, **kwargs) + + if category == 'SLAVE HOSTS': + category = 'REPLICAS' + + self.category = category.upper() + self.modes = modes + self.where = where + self.from_table = from_table + self.in_table = in_table + self.like = like + self.name = name + + + def to_tree(self, *args, level=0, **kwargs): + + ind = indent(level) + ind1 = indent(level+1) + category_str = f'{ind1}category={repr(self.category)},' + from_str = f'\n{ind1}from={self.from_table.to_string()},' if self.from_table else '' + in_str = 
f'\n{ind1}in={self.in_table.to_tree(level=level + 2)},' if self.in_table else '' + where_str = f'\n{ind1}where=\n{self.where.to_tree(level=level+2)},' if self.where else '' + name_str = f'\n{ind1}name={self.name},' if self.name else '' + like_str = f'\n{ind1}like={self.like},' if self.like else '' + modes_str = f'\n{ind1}modes=[{",".join(self.modes)}],' if self.modes else '' + out_str = f'{ind}Show(' \ + f'{category_str}' \ + f'{name_str}' \ + f'{modes_str}' \ + f'{from_str}' \ + f'{in_str}' \ + f'{like_str}' \ + f'{where_str}' \ + f'\n{ind})' + return out_str + + def get_string(self, *args, **kwargs): + + from_str = '' + if self.from_table: + ar = [ + f'FROM {i}' + for i in self.from_table.parts + ] + ar.reverse() + from_str = ' ' + ' '.join(ar) + + in_str = '' + if self.in_table: + ar = [ + f'IN {i}' + for i in self.in_table.parts + ] + ar.reverse() + in_str = ' ' + ' '.join(ar) + + modes_str = f' {" ".join(self.modes)}' if self.modes else '' + like_str = f" LIKE '{self.like}'" if self.like else "" + where_str = f' WHERE {str(self.where)}' if self.where else '' + + # custom commands + if self.category in ('FUNCTION CODE', 'PROCEDURE CODE', 'ENGINE') or self.category.startswith('ENGINE '): + return f'SHOW {self.category} {self.name}' + elif self.category == 'REPLICA STATUS': + channel = '' + if self.name is not None: + channel = f' FOR CHANNEL {self.name}' + return f'SHOW {self.category} {channel}' + + return f'SHOW{modes_str} {self.category}{from_str}{in_str}{like_str}{where_str}' + + + + + + + diff --git a/mindsdb_sql_parser/ast/start_transaction.py b/mindsdb_sql_parser/ast/start_transaction.py new file mode 100644 index 0000000..bc171e7 --- /dev/null +++ b/mindsdb_sql_parser/ast/start_transaction.py @@ -0,0 +1,16 @@ +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent + + +class StartTransaction(ASTNode): + def __init__(self, + *args, **kwargs): + super().__init__(*args, **kwargs) + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + out_str = f'{ind}StartTransaction()' + return out_str + + def get_string(self, *args, **kwargs): + return f'start transaction' diff --git a/mindsdb_sql_parser/ast/update.py b/mindsdb_sql_parser/ast/update.py new file mode 100644 index 0000000..212d3dc --- /dev/null +++ b/mindsdb_sql_parser/ast/update.py @@ -0,0 +1,90 @@ +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent + + +class Update(ASTNode): + def __init__(self, + table, + update_columns=None, + keys=None, + from_select=None, + from_select_alias=None, + where=None, + *args, **kwargs): + super().__init__(*args, **kwargs) + + self.table = table + # list[Identifier] + self.keys = keys + # dict: {str: Identifier} + self.update_columns = update_columns + self.where = where + self.from_select = from_select + self.from_select_alias = from_select_alias + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + ind1 = indent(level + 1) + + updated_str = '' + if self.update_columns is not None: + updated_ar = [ + f'{k}={v.to_string()}' + for k, v in self.update_columns.items() + ] + updated_str = ', '.join(updated_ar) + updated_str = f'{ind1}update_columns={updated_str}\n' + + keys_str = '' + if self.keys is not None: + keys_ar = [k.to_string() for k in self.keys] + keys_str = ', '.join(keys_ar) + keys_str = f'{ind1}keys={keys_str}\n' + + where_str = '' + if self.where is not None: + where_str = ind1 + self.where.to_tree() + + if self.from_select is not None: + from_select_str = 
f'{ind1}from_select=\n{self.from_select.to_tree(level=level+2)}\n' + if self.from_select_alias is not None: + from_select_str += f'{ind1}from_select_alias=\n{self.from_select_alias.to_tree(level=level+2)}\n' + + else: + from_select_str = '' + + out_str = f'{ind}Update(table={self.table.to_tree()}\n' \ + f'{keys_str}' \ + f'{updated_str}' \ + f'{where_str}' \ + f'{from_select_str}' \ + f'{ind})\n' + return out_str + + def get_string(self, *args, **kwargs): + update_str = '' + if self.update_columns is not None: + update_ar = [ + f'{k}={v.to_string()}' + for k, v in self.update_columns.items() + ] + update_str = ' set ' + ', '.join(update_ar) + + keys_str = '' + if self.keys is not None: + keys_ar = [k.to_string() for k in self.keys] + keys_str = ' on ' + ', '.join(keys_ar) + + if self.from_select is not None: + alias_str = '' + if self.from_select_alias is not None: + alias_str = ' as ' + self.from_select_alias.to_string() + from_select_str = f' from ({self.from_select.to_string()}){alias_str}' + else: + from_select_str = '' + + where_str = '' + if self.where is not None: + where_str = ' where ' + self.where.to_string() + + return f'update {self.table.to_string()}{keys_str}{update_str}{from_select_str}{where_str}' diff --git a/mindsdb_sql_parser/ast/use.py b/mindsdb_sql_parser/ast/use.py new file mode 100644 index 0000000..31e65f9 --- /dev/null +++ b/mindsdb_sql_parser/ast/use.py @@ -0,0 +1,23 @@ +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent + + +class Use(ASTNode): + def __init__(self, + value, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.value = value + + def to_tree(self, *args, level=0, **kwargs): + ind = indent(level) + value_str = f'value={self.value.to_tree(level=level+2)},' + + out_str = f'{ind}Use(' \ + f'{value_str}' \ + f'\n{ind})' + return out_str + + def get_string(self, *args, **kwargs): + return f'USE {str(self.value)}' + diff --git a/mindsdb_sql_parser/ast/variable.py b/mindsdb_sql_parser/ast/variable.py new file mode 100644 index 0000000..134789d --- /dev/null +++ b/mindsdb_sql_parser/ast/variable.py @@ -0,0 +1,17 @@ +from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb_sql_parser.utils import indent + + +class Variable(ASTNode): + def __init__(self, value, is_system_var=False, *args, **kwargs): + super().__init__(*args, **kwargs) + self.value = value + self.is_system_var = is_system_var + + def to_tree(self, *args, level=0, **kwargs): + alias_str = f', alias={self.alias.to_tree()}' if self.alias else '' + return indent(level) + f'Variable(value={repr(self.value)}{alias_str}, is_system_var={repr(self.is_system_var)})' + + def get_string(self, *args, **kwargs): + return ('@@' if self.is_system_var else '@') + f'{str(self.value)}' + diff --git a/mindsdb_sql_parser/exceptions.py b/mindsdb_sql_parser/exceptions.py new file mode 100644 index 0000000..e61ac72 --- /dev/null +++ b/mindsdb_sql_parser/exceptions.py @@ -0,0 +1,7 @@ +class MindsdbSQLException(Exception): + pass + + +class ParsingException(MindsdbSQLException): + pass + diff --git a/mindsdb_sql_parser/lexer.py b/mindsdb_sql_parser/lexer.py new file mode 100644 index 0000000..9232efb --- /dev/null +++ b/mindsdb_sql_parser/lexer.py @@ -0,0 +1,389 @@ +import re +from sly import Lexer +from sly.lex import LexError + +""" +Unfortunately we can't inherit from base SQLLexer, because the order of rules is important. +If we do, like in MySQL lexer, the new rules like `DATASOURCE = r'\bDATASOURCE\b'` are added to the end of the rule list. 
+Then, for an input `DATASOURCE`, the last matched regexp is `STRING`, and the token is incorrectly classified +as a string. +""" +class MindsDBLexer(Lexer): + reflags = re.IGNORECASE + ignore = ' \t\r' + ignore_multi_comment = r'/\*[\s\S]*?\*/' + ignore_line_comment = r'--[^\n]*' + + tokens = { + USE, DROP, CREATE, DESCRIBE, RETRAIN, REPLACE, + + # Misc + SET, START, TRANSACTION, COMMIT, ROLLBACK, ALTER, EXPLAIN, + ISOLATION, LEVEL, REPEATABLE, READ, WRITE, UNCOMMITTED, COMMITTED, + SERIALIZABLE, ONLY, CONVERT, BEGIN, + + # Mindsdb special + + PREDICTOR, PREDICTORS, DATASOURCE, INTEGRATION, INTEGRATIONS,DATASOURCES, + STREAM, STREAMS, PUBLICATION, PUBLICATIONS, VIEW, VIEWS, DATASETS, DATASET, + MODEL, MODELS, ML_ENGINE, ML_ENGINES, HANDLERS, + FINETUNE, EVALUATE, + LATEST, LAST, HORIZON, USING, + ENGINE, TRAIN, PREDICT, PARAMETERS, JOB, CHATBOT, EVERY,PROJECT, + ANOMALY, DETECTION, + KNOWLEDGE_BASE, KNOWLEDGE_BASES, + SKILL, + AGENT, + + # SHOW/DDL Keywords + + SHOW, SCHEMAS, SCHEMA, DATABASES, DATABASE, TABLES, TABLE, FULL, EXTENDED, PROCESSLIST, + MUTEX, CODE, SLAVE, REPLICA, REPLICAS, CHANNEL, TRIGGERS, TRIGGER, KEYS, STORAGE, LOGS, BINARY, + MASTER, PRIVILEGES, PROFILES, HOSTS, OPEN, INDEXES, + VARIABLES, SESSION, STATUS, PRIMARY_KEY, DEFAULT, + GLOBAL, PROCEDURE, FUNCTION, INDEX, WARNINGS, + ENGINES, CHARSET, COLLATION, PLUGINS, CHARACTER, + PERSIST, PERSIST_ONLY, + EXISTS, NOT_EXISTS, IF, COLUMNS, FIELDS, COLLATE, SEARCH_PATH, + VARIABLE, SYSTEM_VARIABLE, + + # SELECT Keywords + WITH, SELECT, DISTINCT, FROM, WHERE, AS, + LIMIT, OFFSET, ASC, DESC, NULLS_FIRST, NULLS_LAST, + GROUP_BY, HAVING, ORDER_BY, + STAR, FOR, UPDATE, + + JOIN, INNER, OUTER, CROSS, LEFT, RIGHT, ON, + + UNION, ALL, INTERSECT, EXCEPT, + + # CASE + CASE, ELSE, END, THEN, WHEN, + + # DML + INSERT, DELETE, INTO, VALUES, + + # Special + DOT, COMMA, LPAREN, RPAREN, PARAMETER, + # json + LBRACE, RBRACE, LBRACKET, RBRACKET, COLON, SEMICOLON, + + # Operators + PLUS, MINUS, MATCH, NOT_MATCH, DIVIDE, MODULO, + EQUALS, NEQUALS, GREATER, GEQ, LESS, LEQ, + AND, OR, NOT, IS, IS_NOT, TYPECAST, + IN, NOT_IN, LIKE, NOT_LIKE, CONCAT, BETWEEN, WINDOW, OVER, PARTITION_BY, + JSON_GET, JSON_GET_STR, INTERVAL, + + # Data types + CAST, ID, INTEGER, DATE, FLOAT, QUOTE_STRING, DQUOTE_STRING, NULL, TRUE, FALSE, + + } + + RETRAIN = r'\bRETRAIN\b' + FINETUNE = r'\bFINETUNE\b' + # Custom commands + + USE = r'\bUSE\b' + ENGINE = r'\bENGINE\b' + TRAIN = r'\bTRAIN\b' + PREDICT = r'\bPREDICT\b' + DROP = r'\bDROP\b' + PARAMETERS = r'\bPARAMETERS\b' + HORIZON = r'\bHORIZON\b' + USING = r'\bUSING\b' + VIEW = r'\bVIEW\b' + VIEWS = r'\bVIEWS\b' + STREAM = r'\bSTREAM\b' + STREAMS = r'\bSTREAMS\b' + PREDICTOR = r'\bPREDICTOR\b' + PREDICTORS = r'\bPREDICTORS\b' + DATASOURCE = r'\bDATASOURCE\b' + INTEGRATION = r'\bINTEGRATION\b' + INTEGRATIONS = r'\bINTEGRATIONS\b' + DATASOURCES = r'\bDATASOURCES\b' + PUBLICATION = r'\bPUBLICATION\b' + PUBLICATIONS = r'\bPUBLICATIONS\b' + DATASETS = r'\bDATASETS\b' + DATASET = r'\bDATASET\b' + LATEST = r'\bLATEST\b' + LAST = r'\bLAST\b' + MODEL = r'\bMODEL\b' + MODELS = r'\bMODELS\b' + ML_ENGINE = r'\bML_ENGINE\b' + ML_ENGINES = r'\bML_ENGINES\b' + HANDLERS = r'\bHANDLERS\b' + JOB = r'\bJOB\b' + CHATBOT = r'\bCHATBOT\b' + EVERY = r'\bEVERY\b' + PROJECT = r'\bPROJECT\b' + EVALUATE = r'\bEVALUATE\b' + + # Typed models + ANOMALY = r'\bANOMALY\b' + DETECTION = r'\bDETECTION\b' + + KNOWLEDGE_BASE = r'\bKNOWLEDGE[_|\s]BASE\b' + KNOWLEDGE_BASES = r'\bKNOWLEDGE[_|\s]BASES\b' + SKILL = r'\bSKILL\b' + AGENT = r'\bAGENT\b' + + # Misc 
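+    # Note (added comment, not in the original source): every keyword pattern in
+    # this lexer is wrapped in \b word boundaries, so a longer identifier such as
+    # "SETTINGS" falls through to the ID rule instead of matching the SET keyword.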
+ SET = r'\bSET\b' + START = r'\bSTART\b' + TRANSACTION = r'\bTRANSACTION\b' + COMMIT = r'\bCOMMIT\b' + ROLLBACK = r'\bROLLBACK\b' + EXPLAIN = r'\bEXPLAIN\b' + ALTER = r'\bALTER\b' + ISOLATION = r'\bISOLATION\b' + LEVEL = r'\bLEVEL\b' + REPEATABLE = r'\bREPEATABLE\b' + READ = r'\bREAD\b' + WRITE = r'\bWRITE\b' + UNCOMMITTED = r'\bUNCOMMITTED\b' + COMMITTED = r'\bCOMMITTED\b' + SERIALIZABLE = r'\bSERIALIZABLE\b' + ONLY = r'\bONLY\b' + CONVERT = r'\bCONVERT\b' + DESCRIBE = r'\bDESCRIBE\b' + BEGIN = r'\bBEGIN\b' + DATE = r'\bDATE\b' + + # SHOW + SHOW = r'\bSHOW\b' + SCHEMAS = r'\bSCHEMAS\b' + SCHEMA = r'\bSCHEMA\b' + DATABASES = r'\bDATABASES\b' + DATABASE = r'\bDATABASE\b' + TABLES = r'\bTABLES\b' + TABLE = r'\bTABLE\b' + FULL = r'\bFULL\b' + VARIABLES = r'\bVARIABLES\b' + SESSION = r'\bSESSION\b' + STATUS = r'\bSTATUS\b' + GLOBAL = r'\bGLOBAL\b' + PROCEDURE = r'\bPROCEDURE\b' + PRIMARY_KEY = r'\bPRIMARY[_|\s]KEY\b' + DEFAULT = r'\bDEFAULT\b' + FUNCTION = r'\bFUNCTION\b' + INDEX = r'\bINDEX\b' + CREATE = r'\bCREATE\b' + WARNINGS = r'\bWARNINGS\b' + ENGINES = r'\bENGINES\b' + CHARSET = r'\bCHARSET\b' + CHARACTER = r'\bCHARACTER\b' + COLLATION = r'\bCOLLATION\b' + PLUGINS = r'\bPLUGINS\b' + PERSIST = r'\bPERSIST\b' + PERSIST_ONLY = r'\bPERSIST_ONLY\b' + EXISTS = r'\bEXISTS\b' + NOT_EXISTS = r'\bNOT[\s]+EXISTS\b' + IF = r'\bIF\b' + COLUMNS = r'\bCOLUMNS\b' + FIELDS = r'\bFIELDS\b' + EXTENDED = r'\bEXTENDED\b' + PROCESSLIST = r'\bPROCESSLIST\b' + MUTEX = r'\bMUTEX\b' + CODE = r'\bCODE\b' + SLAVE = r'\bSLAVE\b' + REPLICA = r'\bREPLICA\b' + REPLICAS = r'\bREPLICAS\b' + CHANNEL = r'\bCHANNEL\b' + TRIGGERS = r'\bTRIGGERS\b' + TRIGGER = r'\bTRIGGER\b' + KEYS = r'\bKEYS\b' + STORAGE = r'\bSTORAGE\b' + LOGS = r'\bLOGS\b' + BINARY = r'\bBINARY\b' + MASTER = r'\bMASTER\b' + PRIVILEGES = r'\bPRIVILEGES\b' + PROFILES = r'\bPROFILES\b' + HOSTS = r'\bHOSTS\b' + OPEN = r'\bOPEN\b' + INDEXES = r'\bINDEXES\b' + REPLACE = r'\bREPLACE\b' + COLLATE = r'\bCOLLATE\b' + SEARCH_PATH = r'\bSEARCH_PATH\b' + + # SELECT + + ON = r'\bON\b' + ASC = r'\bASC\b' + DESC = r'\bDESC\b' + NULLS_FIRST = r'\bNULLS FIRST\b' + NULLS_LAST = r'\bNULLS LAST\b' + WITH = r'\bWITH\b' + SELECT = r'\bSELECT\b' + DISTINCT = r'\bDISTINCT\b' + FROM = r'\bFROM\b' + AS = r'\bAS\b' + WHERE = r'\bWHERE\b' + LIMIT = r'\bLIMIT\b' + OFFSET = r'\bOFFSET\b' + GROUP_BY = r'\bGROUP BY\b' + HAVING = r'\bHAVING\b' + ORDER_BY = r'\bORDER BY\b' + STAR = r'\*' + FOR = r'\bFOR\b' + UPDATE = r'\bUPDATE\b' + + JOIN = r'\bJOIN\b' + INNER = r'\bINNER\b' + OUTER = r'\bOUTER\b' + CROSS = r'\bCROSS\b' + LEFT = r'\bLEFT\b' + RIGHT = r'\bRIGHT\b' + + # UNION + + UNION = r'\bUNION\b' + INTERSECT = r'\bINTERSECT\b' + EXCEPT = r'\bEXCEPT\b' + ALL = r'\bALL\b' + + # CASE + CASE = r'\bCASE\b' + ELSE = r'\bELSE\b' + END = r'\bEND\b' + THEN = r'\bTHEN\b' + WHEN = r'\bWHEN\b' + + # DML + INSERT = r'\bINSERT\b' + DELETE = r'\bDELETE\b' + INTO = r'\bINTO\b' + VALUES = r'\bVALUES\b' + + # Special + TYPECAST = r'\:\:' + DOT = r'\.' + COMMA = r',' + LPAREN = r'\(' + RPAREN = r'\)' + PARAMETER = r'\?' 
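+    # e.g. the "?" in "select * from t where id = ?" lexes as a single PARAMETER
+    # token (illustrative note, not in the original source)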
+ # json + LBRACE = r'\{' + RBRACE = r'\}' + LBRACKET = r'\[' + RBRACKET = r'\]' + COLON = r'\:' + SEMICOLON = r'\;' + + # Operators + JSON_GET = r'->' + JSON_GET_STR = r'->>' + PLUS = r'\+' + MINUS = r'-' + MATCH = r'~' + NOT_MATCH = r'!~' + DIVIDE = r'/' + MODULO = r'%' + EQUALS = r'=' + NEQUALS = r'(!=|<>)' + GEQ = r'>=' + GREATER = r'>' + LEQ = r'<=' + LESS = r'<' + AND = r'\bAND\b' + OR = r'\bOR\b' + IS_NOT = r'\bIS[\s]+NOT\b' + NOT_LIKE = r'\bNOT[\s]+LIKE\b' + NOT_IN = r'\bNOT[\s]+IN\b' + NOT = r'\bNOT\b' + IS = r'\bIS\b' + LIKE = r'\bLIKE\b' + IN = r'\bIN\b' + CAST = r'\bCAST\b' + CONCAT = r'\|\|' + BETWEEN = r'\bBETWEEN\b' + INTERVAL = r'\bINTERVAL\b' + WINDOW = r'\bWINDOW\b' + OVER = r'\bOVER\b' + PARTITION_BY = r'\bPARTITION BY\b' + + # Data types + NULL = r'\bNULL\b' + TRUE = r'\bTRUE\b' + FALSE = r'\bFALSE\b' + + @_(r'(?:([a-zA-Z_$0-9]*[a-zA-Z_$]+[a-zA-Z_$0-9]*)|(?:`([^`]+)`))') + def ID(self, t): + return t + + @_(r'\d+\.\d+') + def FLOAT(self, t): + return t + + @_(r'\d+') + def INTEGER(self, t): + return t + + @_(r"'(?:\\.|[^'])*(?:''(?:\\.|[^'])*)*'") + def QUOTE_STRING(self, t): + t.value = t.value.replace('\\"', '"').replace("\\'", "'").replace("''", "'") + return t + + @_(r'"(?:\\.|[^"])*"') + def DQUOTE_STRING(self, t): + t.value = t.value.replace('\\"', '"').replace("\\'", "'") + return t + + @_(r'\n+') + def ignore_newline(self, t): + self.lineno += len(t.value) + + @_(r'@[a-zA-Z_.$]+', + r"@'[a-zA-Z_.$][^']*'", + r"@`[a-zA-Z_.$][^`]*`", + r'@"[a-zA-Z_.$][^"]*"' + ) + def VARIABLE(self, t): + t.value = t.value.lstrip('@') + + if t.value[0] == '"': + t.value = t.value.strip('\"') + elif t.value[0] == "'": + t.value = t.value.strip('\'') + elif t.value[0] == "`": + t.value = t.value.strip('`') + return t + + @_(r'@@[a-zA-Z_.$]+', + r"@@'[a-zA-Z_.$][^']*'", + r"@@`[a-zA-Z_.$][^`]*`", + r'@@"[a-zA-Z_.$][^"]*"' + ) + def SYSTEM_VARIABLE(self, t): + t.value = t.value.lstrip('@') + + if t.value[0] == '"': + t.value = t.value.strip('\"') + elif t.value[0] == "'": + t.value = t.value.strip('\'') + elif t.value[0] == "`": + t.value = t.value.strip('`') + return t + + def error(self, t): + + # convert to lines + lines = [] + shift = 0 + error_line = 0 + error_index = 0 + for i, line in enumerate(self.text.split('\n')): + if 0 <= t.index - shift < len(line): + error_line = i + error_index = t.index - shift + lines.append(line) + shift += len(line) + 1 + + msgs = [f'Illegal character {t.value[0]!r}:'] + # show error code + for line in lines[error_line - 1: error_line + 1]: + msgs.append('>' + line) + + msgs.append('-' * (error_index + 1) + '^') + + raise LexError('\n'.join(msgs), t.value, self.index) diff --git a/mindsdb_sql_parser/logger.py b/mindsdb_sql_parser/logger.py new file mode 100644 index 0000000..5e3adae --- /dev/null +++ b/mindsdb_sql_parser/logger.py @@ -0,0 +1,26 @@ +import logging +import os +import sys +from sly.yacc import SlyLogger + + +class ParserLogger(SlyLogger): + def __init__(self, f=sys.stderr, logger=None): + super().__init__(f) + self.logger = logging.getLogger(__name__) + self.logger.setLevel(os.environ.get('MINDSDB_SQL_LOGLEVEL', 'ERROR')) + + def debug(self, msg, *args, **kwargs): + self.logger.debug(msg, *args) + + def info(self, msg, *args, **kwargs): + self.logger.info(msg, *args) + + def warning(self, msg, *args, **kwargs): + self.logger.warning(msg, *args) + + def error(self, msg, *args, **kwargs): + self.logger.error(msg, *args) + + def critical(self, msg, *args, **kwargs): + self.logger.critical(msg, *args) diff --git 
a/mindsdb_sql_parser/parser.py b/mindsdb_sql_parser/parser.py
new file mode 100644
index 0000000..44aa2dc
--- /dev/null
+++ b/mindsdb_sql_parser/parser.py
@@ -0,0 +1,1953 @@
+from sly import Parser
+from mindsdb_sql_parser.ast import *
+from mindsdb_sql_parser.ast.drop import DropDatabase, DropView
+from mindsdb_sql_parser.ast.mindsdb.agents import CreateAgent, DropAgent, UpdateAgent
+from mindsdb_sql_parser.ast.mindsdb.drop_datasource import DropDatasource
+from mindsdb_sql_parser.ast.mindsdb.drop_predictor import DropPredictor
+from mindsdb_sql_parser.ast.mindsdb.drop_dataset import DropDataset
+from mindsdb_sql_parser.ast.mindsdb.drop_ml_engine import DropMLEngine
+from mindsdb_sql_parser.ast.mindsdb.create_predictor import CreatePredictor, CreateAnomalyDetectionModel
+from mindsdb_sql_parser.ast.mindsdb.create_database import CreateDatabase
+from mindsdb_sql_parser.ast.mindsdb.create_ml_engine import CreateMLEngine
+from mindsdb_sql_parser.ast.mindsdb.create_view import CreateView
+from mindsdb_sql_parser.ast.mindsdb.create_job import CreateJob
+from mindsdb_sql_parser.ast.mindsdb.chatbot import CreateChatBot, UpdateChatBot, DropChatBot
+from mindsdb_sql_parser.ast.mindsdb.drop_job import DropJob
+from mindsdb_sql_parser.ast.mindsdb.trigger import CreateTrigger, DropTrigger
+from mindsdb_sql_parser.ast.mindsdb.latest import Latest
+from mindsdb_sql_parser.ast.mindsdb.evaluate import Evaluate
+from mindsdb_sql_parser.ast.mindsdb.knowledge_base import CreateKnowledgeBase, DropKnowledgeBase
+from mindsdb_sql_parser.ast.mindsdb.skills import CreateSkill, DropSkill, UpdateSkill
+from mindsdb_sql_parser.exceptions import ParsingException
+from mindsdb_sql_parser.ast.mindsdb.retrain_predictor import RetrainPredictor
+from mindsdb_sql_parser.ast.mindsdb.finetune_predictor import FinetunePredictor
+from mindsdb_sql_parser.utils import ensure_select_keyword_order, JoinType, tokens_to_string
+from mindsdb_sql_parser.logger import ParserLogger
+
+from mindsdb_sql_parser.lexer import MindsDBLexer
+
+all_tokens_list = MindsDBLexer.tokens.copy()
+all_tokens_list.remove('RPAREN')
+all_tokens_list.remove('LPAREN')
+
+"""
+Unfortunately the rules are not inherited from the base SQLParser, because that just doesn't work with Sly due to metaclass magic. 
+""" + + +class MindsDBParser(Parser): + log = ParserLogger() + tokens = MindsDBLexer.tokens + + precedence = ( + ('left', OR), + ('left', AND), + ('right', UNOT), + ('left', EQUALS, NEQUALS), + ('nonassoc', LESS, LEQ, GREATER, GEQ, IN, NOT_IN, BETWEEN, IS, IS_NOT, NOT_LIKE, LIKE), + ('left', JSON_GET), + ('left', PLUS, MINUS), + ('left', STAR, DIVIDE, TYPECAST, MODULO), + ('right', UMINUS), # Unary minus operator, unary not + + ) + + # Top-level statements + @_('show', + 'start_transaction', + 'commit_transaction', + 'rollback_transaction', + 'alter_table', + 'explain', + 'set', + 'use', + 'describe', + 'create_predictor', + 'create_integration', + 'create_view', + 'create_anomaly_detection_model', + 'drop_predictor', + 'drop_datasource', + 'drop_dataset', + 'select', + 'insert', + 'union', + 'update', + 'delete', + 'evaluate', + 'drop_database', + 'drop_view', + 'drop_table', + 'create_table', + 'create_job', + 'drop_job', + 'create_chat_bot', + 'drop_chat_bot', + 'update_chat_bot', + 'create_trigger', + 'drop_trigger', + 'create_kb', + 'drop_kb', + 'create_skill', + 'drop_skill', + 'update_skill', + 'create_agent', + 'drop_agent', + 'update_agent' + ) + def query(self, p): + return p[0] + + # -- Knowledge Base -- + @_( + 'CREATE KNOWLEDGE_BASE if_not_exists_or_empty identifier USING kw_parameter_list', + 'CREATE KNOWLEDGE_BASE if_not_exists_or_empty identifier', + # from select + 'CREATE KNOWLEDGE_BASE if_not_exists_or_empty identifier FROM LPAREN select RPAREN USING kw_parameter_list', + 'CREATE KNOWLEDGE_BASE if_not_exists_or_empty identifier FROM LPAREN select RPAREN', + ) + def create_kb(self, p): + params = getattr(p, 'kw_parameter_list', {}) + from_query = getattr(p, 'select', None) + name = p.identifier + # check model and storage are in params + params = {k.lower(): v for k, v in params.items()} # case insensitive + model = params.pop('model', None) + storage = params.pop('storage', None) + + if isinstance(model, str): + # convert to identifier + storage = Identifier(storage) + + if isinstance(model, str): + # convert to identifier + model = Identifier(model) + + if_not_exists = p.if_not_exists_or_empty + + return CreateKnowledgeBase( + name=name, + model=model, + storage=storage, + from_select=from_query, + params=params, + if_not_exists=if_not_exists + ) + + @_('DROP KNOWLEDGE_BASE if_exists_or_empty identifier') + def drop_kb(self, p): + return DropKnowledgeBase(name=p.identifier, if_exists=p.if_exists_or_empty) + + # -- Skills -- + @_('CREATE SKILL if_not_exists_or_empty identifier USING kw_parameter_list') + def create_skill(self, p): + params = p.kw_parameter_list + + return CreateSkill( + name=p.identifier, + type=params.pop('type'), + params=params, + if_not_exists=p.if_not_exists_or_empty + ) + + @_('DROP SKILL if_exists_or_empty identifier') + def drop_skill(self, p): + return DropSkill(name=p.identifier, if_exists=p.if_exists_or_empty) + + @_('UPDATE SKILL identifier SET kw_parameter_list') + def update_skill(self, p): + return UpdateSkill(name=p.identifier, updated_params=p.kw_parameter_list) + + # -- Agent -- + @_('CREATE AGENT if_not_exists_or_empty identifier USING kw_parameter_list') + def create_agent(self, p): + params = p.kw_parameter_list + + return CreateAgent( + name=p.identifier, + model=params.pop('model', None), + params=params, + if_not_exists=p.if_not_exists_or_empty + ) + + @_('DROP AGENT if_exists_or_empty identifier') + def drop_agent(self, p): + return DropAgent(name=p.identifier, if_exists=p.if_exists_or_empty) + + @_('UPDATE AGENT identifier SET 
kw_parameter_list') + def update_agent(self, p): + return UpdateAgent(name=p.identifier, updated_params=p.kw_parameter_list) + + # -- ChatBot -- + @_('CREATE CHATBOT identifier USING kw_parameter_list') + def create_chat_bot(self, p): + params = p.kw_parameter_list + + database = Identifier(params.pop('database')) + model_param = params.pop('model', None) + agent_param = params.pop('agent', None) + model = Identifier( + model_param) if model_param is not None else None + agent = Identifier( + agent_param) if agent_param is not None else None + return CreateChatBot( + name=p.identifier, + database=database, + model=model, + agent=agent, + params=params + ) + + @_('UPDATE CHATBOT identifier SET kw_parameter_list') + def update_chat_bot(self, p): + return UpdateChatBot(name=p.identifier, updated_params=p.kw_parameter_list) + + @_('DROP CHATBOT identifier') + def drop_chat_bot(self, p): + return DropChatBot(name=p.identifier) + + # -- triggers -- + @_('CREATE TRIGGER identifier ON identifier LPAREN raw_query RPAREN') + @_('CREATE TRIGGER identifier ON identifier COLUMNS column_list LPAREN raw_query RPAREN') + def create_trigger(self, p): + query_str = tokens_to_string(p.raw_query) + + columns = None + if hasattr(p, 'column_list'): + columns = [Identifier(i) for i in p.column_list] + + return CreateTrigger( + name=p.identifier0, + table=p.identifier1, + query_str=query_str, + columns=columns + ) + + @_('DROP TRIGGER identifier') + def drop_trigger(self, p): + return DropTrigger(name=p.identifier) + + # -- Jobs -- + @_('CREATE JOB if_not_exists_or_empty identifier LPAREN raw_query RPAREN job_schedule', + 'CREATE JOB if_not_exists_or_empty identifier AS LPAREN raw_query RPAREN job_schedule', + 'CREATE JOB if_not_exists_or_empty identifier LPAREN raw_query RPAREN job_schedule IF LPAREN raw_query RPAREN', + 'CREATE JOB if_not_exists_or_empty identifier AS LPAREN raw_query RPAREN job_schedule IF LPAREN raw_query RPAREN', + 'CREATE JOB if_not_exists_or_empty identifier LPAREN raw_query RPAREN', + 'CREATE JOB if_not_exists_or_empty identifier AS LPAREN raw_query RPAREN', + 'CREATE JOB if_not_exists_or_empty identifier LPAREN raw_query RPAREN IF LPAREN raw_query RPAREN', + 'CREATE JOB if_not_exists_or_empty identifier AS LPAREN raw_query RPAREN IF LPAREN raw_query RPAREN' + ) + def create_job(self, p): + if hasattr(p, 'raw_query0'): + query_str = tokens_to_string(p.raw_query0) + if_query_str = tokens_to_string(p.raw_query1) + else: + query_str = tokens_to_string(p.raw_query) + if_query_str = None + + job_schedule = getattr(p, 'job_schedule', {}) + + start_str = None + if 'START' in job_schedule: + start_str = job_schedule.pop('START') + + end_str = None + if 'END' in job_schedule: + end_str = job_schedule.pop('END') + + repeat_str = None + if 'EVERY' in job_schedule: + repeat_str = job_schedule.pop('EVERY') + + if len(job_schedule) > 0: + raise ParsingException(f'Unexpected params: {list(job_schedule.keys())}') + + return CreateJob( + name=p.identifier, + query_str=query_str, + if_query_str=if_query_str, + start_str=start_str, + end_str=end_str, + repeat_str=repeat_str, + if_not_exists=p.if_not_exists_or_empty + ) + + @_('START string', + 'START id', + 'END string', + 'EVERY string', + 'EVERY id', + 'EVERY integer id', + 'job_schedule job_schedule') + def job_schedule(self, p): + + if isinstance(p[0], dict): + schedule = p[0] + for k in p[1].keys(): + if k in p[0]: + raise ParsingException(f'Duplicated param: {k}') + + schedule.update(p[1]) + return schedule + + param = p[0].upper() + value = p[1] + if 
param == 'EVERY': + # 'integer + id' mode + if hasattr(p, 'integer'): + value = f'{p[1]} {p[2]}' + + schedule = {param: value} + return schedule + + @_('DROP JOB if_exists_or_empty identifier') + def drop_job(self, p): + return DropJob(name=p.identifier, if_exists=p.if_exists_or_empty) + + # Explain + @_('EXPLAIN identifier') + def explain(self, p): + return Explain(target=p.identifier) + + # Alter table + @_('ALTER TABLE identifier id id') + def alter_table(self, p): + return AlterTable(target=p.identifier, + arg=' '.join([p.id0, p.id1])) + + # DROP VIEW + @_('DROP VIEW if_exists_or_empty identifier') + def drop_view(self, p): + return DropView([p.identifier], if_exists=p.if_exists_or_empty) + + @_('DROP VIEW if_exists_or_empty enumeration') + def drop_view(self, p): + return DropView(p.enumeration, if_exists=p.if_exists_or_empty) + + # DROP DATABASE + @_('DROP DATABASE if_exists_or_empty identifier', + 'DROP PROJECT if_exists_or_empty identifier', + 'DROP SCHEMA if_exists_or_empty identifier') + def drop_database(self, p): + return DropDatabase(name=p.identifier, if_exists=p.if_exists_or_empty) + + # Transactions + + @_('START TRANSACTION', + 'BEGIN') + def start_transaction(self, p): + # https://dev.mysql.com/doc/refman/8.0/en/commit.html + return StartTransaction() + + @_('COMMIT') + def commit_transaction(self, p): + return CommitTransaction() + + @_('ROLLBACK') + def rollback_transaction(self, p): + return RollbackTransaction() + + # --- Set --- + @_('SET set_item_list') + def set(self, p): + set_list = p[1] + if len(set_list) == 1: + return set_list[0] + return Set(set_list=set_list) + + @_('set_item', + 'set_item_list COMMA set_item') + def set_item_list(self, p): + arr = getattr(p, 'set_item_list', []) + arr.append(p.set_item) + return arr + + # set names + @_('id id', + 'id constant', + 'id identifier', + 'id id COLLATE constant', + 'id id COLLATE id', + 'id constant COLLATE constant', + 'id constant COLLATE id') + def set_item(self, p): + category = p[0] + + if isinstance(p[1], (Constant, Identifier)): + value = p[1] + else: + # is id + value = Constant(p[1], with_quotes=False) + + params = {} + if hasattr(p, 'COLLATE'): + if category.lower() != 'names': + raise ParsingException(f'Expected "SET names", got "SET {category}"') + + if isinstance(p[3], Constant): + val = p[3] + else: + val = Constant(p[3], with_quotes=False) + params['COLLATE'] = val + + return Set(category=category, value=value, params=params) + + # set charset + @_('charset constant', + 'charset id') + def set_item(self, p): + if hasattr(p, 'id'): + arg = Constant(p.id, with_quotes=False) + else: + arg = p.constant + return Set(category='CHARSET', value=arg) + + @_('CHARACTER SET', + 'CHARSET', + ) + def charset(self, p): + if hasattr(p, 'SET'): + return f'{p[0]} {p[1]}' + return p[0] + + # set transaction + @_('set_scope TRANSACTION transact_property_list', + 'TRANSACTION transact_property_list') + def set_item(self, p): + isolation_level = None + access_mode = None + transact_scope = getattr(p, 'set_scope', None) + for prop in p.transact_property_list: + if prop['type'] == 'iso_level': + isolation_level = prop['value'] + else: + access_mode = prop['value'] + + params = {} + if isolation_level is not None: + params['isolation level'] = isolation_level + if access_mode is not None: + params['access_mode'] = access_mode + + return Set( + category='TRANSACTION', + scope=transact_scope, + params=params + ) +
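+ # Illustrative example (derived from the rules above): parsing + # "SET SESSION TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ WRITE" should yield + # Set(category='TRANSACTION', scope='SESSION', + # params={'isolation level': 'REPEATABLE READ', 'access_mode': 'READ WRITE'}) + @_('transact_property_list COMMA transact_property') + def transact_property_list(self, p): + return 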
p.transact_property_list + [p.transact_property] + + @_('transact_property') + def transact_property_list(self, p): + return [p[0]] + + @_('ISOLATION LEVEL transact_level', + 'transact_access_mode') + def transact_property(self, p): + if hasattr(p, 'transact_level'): + return {'type':'iso_level', 'value':p.transact_level} + else: + return {'type':'access_mode', 'value':p.transact_access_mode} + + @_('REPEATABLE READ', + 'READ COMMITTED', + 'READ UNCOMMITTED', + 'SERIALIZABLE') + def transact_level(self, p): + return ' '.join([x for x in p]) + + @_('READ WRITE', + 'READ ONLY') + def transact_access_mode(self, p): + return ' '.join([x for x in p]) + + @_('identifier EQUALS expr', + 'set_scope identifier EQUALS expr', + 'variable EQUALS expr', + 'set_scope variable EQUALS expr') + def set_item(self, p): + + scope = None + name = p[0] + if hasattr(p, 'set_scope'): + scope = p.set_scope + name=p[1] + + return Set(name=name, value=p.expr, scope=scope) + + @_('GLOBAL', + 'PERSIST', + 'PERSIST_ONLY', + 'SESSION', + ) + def set_scope(self, p): + return p[0] + + # --- Show --- + @_('show WHERE expr') + def show(self, p): + command = p.show + command.where = p.expr + return command + + @_('show LIKE string') + def show(self, p): + command = p.show + command.like = p.string + return command + + @_('show FROM identifier') + def show(self, p): + command = p.show + value0 = command.from_table + value1 = p.identifier + if value0 is not None: + value1.parts = value1.parts + value0.parts + + command.from_table = value1 + return command + + @_('show IN identifier') + def show(self, p): + command = p.show + value0 = command.in_table + value1 = p.identifier + if value0 is not None: + value1.parts = value1.parts + value0.parts + + command.in_table = value1 + return command + + @_('SHOW show_category', + 'SHOW show_modifier_list show_category') + def show(self, p): + modes = getattr(p, 'show_modifier_list', None) + return Show( + category=p.show_category, + modes=modes + ) + + @_( + 'id', + 'id id', + ) + def show_category(self, p): + if hasattr(p, 'id'): + return p.id + return f"{p.id0} {p.id1}" + + # custom show commands + + @_('SHOW id id identifier') + def show(self, p): + category = p[1] + ' ' + p[2] + + if p[1].lower() == 'engine': + name = p.identifier.parts[0] + else: + name = p.identifier.to_string() + return Show( + category=category, + name=name + ) + + @_('SHOW REPLICA STATUS FOR CHANNEL id', + 'SHOW SLAVE STATUS FOR CHANNEL id', + 'SHOW REPLICA STATUS', + 'SHOW SLAVE STATUS', ) + def show(self, p): + name = getattr(p, 'id', None) + return Show( + category='REPLICA STATUS', # slave = replica + name=name + ) + + @_('show_modifier', + 'show_modifier_list show_modifier') + def show_modifier_list(self, p): + if hasattr(p, 'empty'): + return None + params = getattr(p, 'show_modifier_list', []) + params.append(p.show_modifier) + return params + + @_('EXTENDED', + 'FULL') + def show_modifier(self, p): + return p[0] + + # DELETE + @_('DELETE FROM identifier WHERE expr', + 'DELETE FROM identifier') + def delete(self, p): + where = getattr(p, 'expr', None) + + if where is not None and not isinstance(where, Operation): + raise ParsingException( + f"WHERE must contain an operation that evaluates to a boolean, got: {str(where)}") + + return Delete(table=p.identifier, where=where) + + # UPDATE + @_('UPDATE identifier SET update_parameter_list FROM LPAREN select RPAREN AS id WHERE expr', + 'UPDATE identifier SET update_parameter_list WHERE expr', + 'UPDATE identifier SET update_parameter_list') + def update(self, 
p): + where = getattr(p, 'expr', None) + from_select = getattr(p, 'select', None) + from_select_alias = getattr(p, 'id', None) + if from_select_alias is not None: + from_select_alias = Identifier(from_select_alias) + return Update(table=p.identifier, + update_columns=p.update_parameter_list, + from_select=from_select, + from_select_alias=from_select_alias, + where=where) + + # UPDATE + @_('UPDATE identifier ON ordering_terms FROM LPAREN select RPAREN') + def update(self, p): + keys = [i.field for i in p.ordering_terms] + return Update(table=p.identifier, + keys=keys, + from_select=p.select) + + # INSERT + @_('INSERT INTO identifier LPAREN column_list RPAREN select', + 'INSERT INTO identifier LPAREN column_list RPAREN union', + 'INSERT INTO identifier select', + 'INSERT INTO identifier union') + def insert(self, p): + columns = getattr(p, 'column_list', None) + query = p.select if hasattr(p, 'select') else p.union + return Insert(table=p.identifier, columns=columns, from_select=query) + + @_('INSERT INTO identifier LPAREN column_list RPAREN VALUES expr_list_set', + 'INSERT INTO identifier VALUES expr_list_set') + def insert(self, p): + columns = getattr(p, 'column_list', None) + return Insert(table=p.identifier, columns=columns, values=p.expr_list_set) + + @_('expr_list_set COMMA expr_list_set') + def expr_list_set(self, p): + return p.expr_list_set0 + p.expr_list_set1 + + @_('LPAREN expr_list RPAREN') + def expr_list_set(self, p): + return [p.expr_list] + + # DESCRIBE + + @_('DESCRIBE identifier') + def describe(self, p): + return Describe(value=p.identifier) + + @_('DESCRIBE JOB identifier', + 'DESCRIBE SKILL identifier', + 'DESCRIBE CHATBOT identifier', + 'DESCRIBE TRIGGER identifier', + 'DESCRIBE KNOWLEDGE_BASE identifier', + 'DESCRIBE PROJECT identifier', + 'DESCRIBE ML_ENGINE identifier', + 'DESCRIBE identifier identifier', + ) + def describe(self, p): + if isinstance(p[1], Identifier): + type = p[1].parts[-1] + else: + type = p[1] + type = type.replace(' ', '_') + return Describe(value=p[2], type=type) + + + # USE + + @_('USE identifier') + def use(self, p): + return Use(value=p.identifier) + + # CREATE VIEW + @_('CREATE VIEW if_not_exists_or_empty identifier create_view_from_table_or_nothing AS LPAREN raw_query RPAREN', + 'CREATE VIEW if_not_exists_or_empty identifier create_view_from_table_or_nothing LPAREN raw_query RPAREN') + def create_view(self, p): + query_str = tokens_to_string(p.raw_query) + + return CreateView(name=p.identifier, + from_table=p.create_view_from_table_or_nothing, + query_str=query_str, + if_not_exists=p.if_not_exists_or_empty) + + @_('FROM identifier') + def create_view_from_table_or_nothing(self, p): + return p.identifier + + @_('empty') + def create_view_from_table_or_nothing(self, p): + pass + + # DROP PREDICTOR + @_('DROP PREDICTOR if_exists_or_empty identifier', + 'DROP MODEL if_exists_or_empty identifier') + def drop_predictor(self, p): + return DropPredictor(p.identifier, if_exists=p.if_exists_or_empty) + + # DROP DATASOURCE + @_('DROP DATASOURCE if_exists_or_empty identifier') + def drop_datasource(self, p): + return DropDatasource(p.identifier, if_exists=p.if_exists_or_empty) + + # DROP DATASET + @_('DROP DATASET if_exists_or_empty identifier') + def drop_dataset(self, p): + return DropDataset(p.identifier, if_exists=p.if_exists_or_empty) + + # DROP TABLE + @_('DROP TABLE if_exists_or_empty identifier') + def drop_table(self, p): + return DropTables(tables=[p.identifier], if_exists=p.if_exists_or_empty) + + # create table + @_('id id', + 'id id 
DEFAULT id', + 'id id PRIMARY_KEY', + 'id id LPAREN INTEGER RPAREN', + 'id id LPAREN INTEGER RPAREN DEFAULT id', + 'PRIMARY_KEY LPAREN column_list RPAREN', + ) + def table_column(self, p): + default = None + if hasattr(p, 'DEFAULT'): + # get last element + default = p[len(p) - 1] + + is_primary_key = False + if hasattr(p, 'column_list'): + # is list of primary keys + return p.column_list + + elif hasattr(p, 'PRIMARY_KEY'): + is_primary_key = True + + return TableColumn( + name=p[0], + type=p[1], + length=getattr(p, 'INTEGER', None), + default=default, + is_primary_key=is_primary_key + ) + + @_('table_column NULL', + 'table_column NOT NULL') + def table_column(self, p): + nullable = True + if hasattr(p, 'NOT'): + nullable = False + p.table_column.nullable = nullable + return p.table_column + + @_('table_column', + 'table_column_list COMMA table_column') + def table_column_list(self, p): + items = getattr(p, 'table_column_list', []) + items.append(p.table_column) + return items + + @_('CREATE replace_or_empty TABLE if_not_exists_or_empty identifier LPAREN table_column_list RPAREN') + def create_table(self, p): + table_columns = {} + primary_keys = [] + for item in p.table_column_list: + if isinstance(item, TableColumn): + table_columns[item.name] = item + else: + primary_keys = item + for col_name in primary_keys: + if col_name in table_columns: + table_columns[col_name].is_primary_key = True + + return CreateTable( + name=p.identifier, + columns=list(table_columns.values()), + is_replace=getattr(p, 'replace_or_empty', False), + if_not_exists=getattr(p, 'if_not_exists_or_empty', False) + ) + + @_( + 'CREATE replace_or_empty TABLE if_not_exists_or_empty identifier select', + 'CREATE replace_or_empty TABLE if_not_exists_or_empty identifier LPAREN select RPAREN', + ) + def create_table(self, p): + is_replace = getattr(p, 'replace_or_empty', False) + + return CreateTable( + name=p.identifier, + is_replace=is_replace, + from_select=p.select, + if_not_exists=getattr(p, 'if_not_exists_or_empty', False) + ) + + # create predictor + + @_('create_predictor USING kw_parameter_list') + def create_predictor(self, p): + p.create_predictor.using = p.kw_parameter_list + return p.create_predictor + + @_('create_predictor HORIZON integer') + def create_predictor(self, p): + p.create_predictor.horizon = p.integer + return p.create_predictor + + @_('create_predictor WINDOW integer') + def create_predictor(self, p): + p.create_predictor.window = p.integer + return p.create_predictor + + @_('create_predictor GROUP_BY expr_list') + def create_predictor(self, p): + group_by = p.expr_list + if not isinstance(group_by, list): + group_by = [group_by] + + p.create_predictor.group_by = group_by + return p.create_predictor + + @_('create_predictor ORDER_BY ordering_terms') + def create_predictor(self, p): + p.create_predictor.order_by = p.ordering_terms + return p.create_predictor + + @_('CREATE replace_or_empty PREDICTOR if_not_exists_or_empty identifier FROM identifier LPAREN raw_query RPAREN PREDICT result_columns', + 'CREATE replace_or_empty PREDICTOR if_not_exists_or_empty identifier PREDICT result_columns', + 'CREATE replace_or_empty MODEL if_not_exists_or_empty identifier FROM identifier LPAREN raw_query RPAREN PREDICT result_columns', + 'CREATE replace_or_empty MODEL if_not_exists_or_empty identifier FROM LPAREN raw_query RPAREN PREDICT result_columns', + 'CREATE replace_or_empty MODEL if_not_exists_or_empty identifier PREDICT result_columns' + ) + def create_predictor(self, p): + query_str = None + if hasattr(p, 
'raw_query'): + query_str = tokens_to_string(p.raw_query) + + if hasattr(p, 'identifier'): + # single identifier field + name = p.identifier + else: + name = p.identifier0 + + return CreatePredictor( + name=name, + integration_name=getattr(p, 'identifier1', None), + query_str=query_str, + targets=p.result_columns, + if_not_exists=p.if_not_exists_or_empty, + is_replace=p.replace_or_empty + ) + + # Typed models + ## Anomaly detection + @_( + 'CREATE ANOMALY DETECTION MODEL identifier', # for methods that do not require training (e.g. TimeGPT) + 'CREATE ANOMALY DETECTION MODEL identifier FROM identifier LPAREN raw_query RPAREN', + 'CREATE ANOMALY DETECTION MODEL identifier PREDICT result_columns', + 'CREATE ANOMALY DETECTION MODEL identifier PREDICT result_columns FROM identifier LPAREN raw_query RPAREN', + 'CREATE ANOMALY DETECTION MODEL identifier FROM identifier LPAREN raw_query RPAREN PREDICT result_columns', + # TODO add IF_NOT_EXISTS elegantly (should be low level BNF expansion) + ) + def create_anomaly_detection_model(self, p): + + query_str = None + if hasattr(p, 'raw_query'): + query_str = tokens_to_string(p.raw_query) + + if hasattr(p, 'identifier'): + # single identifier field + name = p.identifier + else: + name = p.identifier0 + + return CreateAnomalyDetectionModel( + name=name, + targets=getattr(p, 'result_columns', None), + integration_name=getattr(p, 'identifier1', None), + query_str=query_str, + if_not_exists=False + ) + + @_('create_anomaly_detection_model USING kw_parameter_list') + def create_anomaly_detection_model(self, p): + p.create_anomaly_detection_model.using = p.kw_parameter_list + return p.create_anomaly_detection_model + + # RETRAIN PREDICTOR + + @_('RETRAIN identifier', + 'RETRAIN identifier PREDICT result_columns', + 'RETRAIN identifier FROM LPAREN raw_query RPAREN', + 'RETRAIN identifier FROM LPAREN raw_query RPAREN PREDICT result_columns', + 'RETRAIN identifier FROM identifier LPAREN raw_query RPAREN', + 'RETRAIN identifier FROM identifier LPAREN raw_query RPAREN PREDICT result_columns', + 'RETRAIN MODEL identifier', + 'RETRAIN MODEL identifier PREDICT result_columns', + 'RETRAIN MODEL identifier FROM LPAREN raw_query RPAREN', + 'RETRAIN MODEL identifier FROM identifier LPAREN raw_query RPAREN', + 'RETRAIN MODEL identifier FROM LPAREN raw_query RPAREN PREDICT result_columns', + 'RETRAIN MODEL identifier FROM identifier LPAREN raw_query RPAREN PREDICT result_columns') + def create_predictor(self, p): + query_str = None + if hasattr(p, 'raw_query'): + query_str = tokens_to_string(p.raw_query) + + if hasattr(p, 'identifier'): + # single identifier field + name = p.identifier + else: + name = p.identifier0 + + return RetrainPredictor( + name=name, + integration_name=getattr(p, 'identifier1', None), + query_str=query_str, + targets=getattr(p, 'result_columns', None) + ) + + @_('FINETUNE identifier FROM identifier LPAREN raw_query RPAREN', + 'FINETUNE identifier FROM LPAREN raw_query RPAREN', + 'FINETUNE MODEL identifier FROM identifier LPAREN raw_query RPAREN', + 'FINETUNE MODEL identifier FROM LPAREN raw_query RPAREN') + def create_predictor(self, p): + query_str = None + if hasattr(p, 'raw_query'): + query_str = tokens_to_string(p.raw_query) + + if hasattr(p, 'identifier'): + # single identifier field + name = p.identifier + else: + name = p.identifier0 + + return FinetunePredictor( + name=name, + integration_name=getattr(p, 'identifier1', None), + query_str=query_str, + ) + + @_('EVALUATE identifier FROM LPAREN raw_query RPAREN', + 'EVALUATE identifier FROM 
LPAREN raw_query RPAREN USING kw_parameter_list', ) + def evaluate(self, p): + if hasattr(p, 'identifier'): + # single identifier field + name = p.identifier + else: + name = p.identifier0 + + if hasattr(p, 'USING'): + using = p.kw_parameter_list + else: + using = None + + return Evaluate( + name=name, + query_str=tokens_to_string(p.raw_query), + using=using + ) + + # ------------ + + # ML ENGINE + # CREATE + @_('CREATE ML_ENGINE if_not_exists_or_empty identifier FROM id USING kw_parameter_list', + 'CREATE ML_ENGINE if_not_exists_or_empty identifier FROM id') + def create_integration(self, p): + return CreateMLEngine(name=p.identifier, + handler=p.id, + params=getattr(p, 'kw_parameter_list', None), + if_not_exists=p.if_not_exists_or_empty) + + # DROP + @_('DROP ML_ENGINE if_exists_or_empty identifier') + def create_integration(self, p): + return DropMLEngine(name=p.identifier, if_exists=p.if_exists_or_empty) + + # CREATE INTEGRATION + @_('CREATE replace_or_empty database_engine', + 'CREATE replace_or_empty database_engine COMMA PARAMETERS EQUALS json', + 'CREATE replace_or_empty database_engine COMMA PARAMETERS json', + 'CREATE replace_or_empty database_engine PARAMETERS EQUALS json', + 'CREATE replace_or_empty database_engine PARAMETERS json', + ) + def create_integration(self, p): + is_replace = getattr(p, 'replace_or_empty', False) + + parameters = None + if hasattr(p, 'json'): + parameters = p.json + + return CreateDatabase(name=p.database_engine['identifier'], + engine=p.database_engine['engine'], + is_replace=is_replace, + parameters=parameters, + if_not_exists=p.database_engine['if_not_exists']) + + @_('DATABASE if_not_exists_or_empty identifier', + 'DATABASE if_not_exists_or_empty identifier ENGINE string', + 'DATABASE if_not_exists_or_empty identifier ENGINE EQUALS string', + 'DATABASE if_not_exists_or_empty identifier WITH ENGINE string', + 'DATABASE if_not_exists_or_empty identifier WITH ENGINE EQUALS string', + 'DATABASE if_not_exists_or_empty identifier USING ENGINE EQUALS string', + 'PROJECT if_not_exists_or_empty identifier') + def database_engine(self, p): + engine = None + if hasattr(p, 'string'): + engine = p.string + return {'identifier':p.identifier, 'engine':engine, 'if_not_exists':p.if_not_exists_or_empty} + + # Combining + @_('select UNION select', + 'union UNION select', + 'select UNION ALL select', + 'union UNION ALL select') + def union(self, p): + unique = not hasattr(p, 'ALL') + return Union(left=p[0], right=p[2] if unique else p[3], unique=unique) + + @_('select INTERSECT select', + 'union INTERSECT select', + 'select INTERSECT ALL select', + 'union INTERSECT ALL select') + def union(self, p): + unique = not hasattr(p, 'ALL') + return Intersect(left=p[0], right=p[2] if unique else p[3], unique=unique) + @_('select EXCEPT select', + 'union EXCEPT select', + 'select EXCEPT ALL select', + 'union EXCEPT ALL select') + def union(self, p): + unique = not hasattr(p, 'ALL') + return Except(left=p[0], right=p[2] if unique else p[3], unique=unique) + + # tableau + @_('LPAREN select RPAREN') + @_('LPAREN union RPAREN') + def select(self, p): + return p[1] + + # WITH + @_('ctes select') + def select(self, p): + select = p.select + select.cte = p.ctes + return select + + @_('ctes COMMA identifier cte_columns_or_nothing AS LPAREN select RPAREN') + def ctes(self, p): + ctes = p.ctes + ctes = ctes + [ + CommonTableExpression( + name=p.identifier, + columns=p.cte_columns_or_nothing, + query=p.select) + ] + return ctes + + @_('WITH identifier cte_columns_or_nothing AS LPAREN 
select RPAREN', + 'WITH identifier cte_columns_or_nothing AS LPAREN union RPAREN') + def ctes(self, p): + return [ + CommonTableExpression( + name=p.identifier, + columns=p.cte_columns_or_nothing, + query=p[5]) + ] + + @_('empty') + def cte_columns_or_nothing(self, p): + pass + + @_('LPAREN enumeration RPAREN') + def cte_columns_or_nothing(self, p): + return p.enumeration + + # SELECT + @_('select FOR UPDATE') + def select(self, p): + select = p.select + ensure_select_keyword_order(select, 'MODE') + select.mode = 'FOR UPDATE' + return select + + @_('select OFFSET constant') + def select(self, p): + select = p.select + if select.offset is not None: + raise ParsingException(f'OFFSET already specified for this query') + ensure_select_keyword_order(select, 'OFFSET') + if not isinstance(p.constant.value, int): + raise ParsingException(f'OFFSET must be an integer value, got: {p.constant.value}') + + select.offset = p.constant + return select + + @_('select LIMIT constant COMMA constant') + def select(self, p): + select = p.select + ensure_select_keyword_order(select, 'LIMIT') + if not isinstance(p.constant0.value, int) or not isinstance(p.constant1.value, int): + raise ParsingException(f'LIMIT must have integer arguments, got: {p.constant0.value}, {p.constant1.value}') + select.offset = p.constant0 + select.limit = p.constant1 + return select + + @_('select LIMIT constant') + def select(self, p): + select = p.select + ensure_select_keyword_order(select, 'LIMIT') + if not isinstance(p.constant.value, int): + raise ParsingException(f'LIMIT must be an integer value, got: {p.constant.value}') + select.limit = p.constant + return select + + @_('select ORDER_BY ordering_terms') + def select(self, p): + select = p.select + ensure_select_keyword_order(select, 'ORDER BY') + select.order_by = p.ordering_terms + return select + + @_('ordering_terms COMMA ordering_term') + def ordering_terms(self, p): + terms = p.ordering_terms + terms.append(p.ordering_term) + return terms + + @_('ordering_term') + def ordering_terms(self, p): + return [p.ordering_term] + + @_('ordering_term NULLS_FIRST') + def ordering_term(self, p): + p.ordering_term.nulls = p.NULLS_FIRST + return p.ordering_term + + @_('ordering_term NULLS_LAST') + def ordering_term(self, p): + p.ordering_term.nulls = p.NULLS_LAST + return p.ordering_term + + @_('ordering_term DESC') + def ordering_term(self, p): + item = p.ordering_term + item.direction = 'DESC' + return item + + @_('ordering_term ASC') + def ordering_term(self, p): + item = p.ordering_term + item.direction = 'ASC' + return item + + @_('expr') + def ordering_term(self, p): + return OrderBy(field=p.expr, direction='default') + + @_('select USING kw_parameter_list') + def select(self, p): + p.select.using = p.kw_parameter_list + return p.select + + @_('select HAVING expr') + def select(self, p): + select = p.select + ensure_select_keyword_order(select, 'HAVING') + having = p.expr + if not isinstance(having, Operation): + raise ParsingException( + f"HAVING must contain an operation that evaluates to a boolean, got: {str(having)}") + select.having = having + return select + + @_('select GROUP_BY expr_list') + def select(self, p): + select = p.select + ensure_select_keyword_order(select, 'GROUP BY') + group_by = p.expr_list + if not isinstance(group_by, list): + group_by = [group_by] + + select.group_by = group_by + return select + + @_('select WHERE expr') + def select(self, p): + select = p.select + ensure_select_keyword_order(select, 'WHERE') + where_expr = p.expr + if not 
isinstance(where_expr, Operation): + raise ParsingException( + f"WHERE must contain an operation that evaluates to a boolean, got: {str(where_expr)}") + select.where = where_expr + return select + + @_('select FROM from_table_aliased', + 'select FROM join_tables_implicit', + 'select FROM join_tables') + def select(self, p): + select = p.select + ensure_select_keyword_order(select, 'FROM') + select.from_table = p[2] + return select + + # --- join --- + @_('from_table_aliased join_clause from_table_aliased', + 'join_tables join_clause from_table_aliased') + def join_tables(self, p): + return Join(left=p[0], + right=p[2], + join_type=p.join_clause) + + @_('from_table_aliased join_clause from_table_aliased ON expr', + 'join_tables join_clause from_table_aliased ON expr') + def join_tables(self, p): + return Join(left=p[0], + right=p[2], + join_type=p.join_clause, + condition=p.expr) + + @_('from_table_aliased COMMA from_table_aliased', + 'join_tables_implicit COMMA from_table_aliased') + def join_tables_implicit(self, p): + return Join(left=p[0], + right=p[2], + join_type=JoinType.INNER_JOIN, + implicit=True) + + @_('from_table AS identifier', + 'from_table identifier', + 'from_table AS dquote_string', + 'from_table dquote_string', + 'from_table') + def from_table_aliased(self, p): + entity = p.from_table + if hasattr(p, 'identifier'): + entity.alias = p.identifier + if hasattr(p, 'dquote_string'): + entity.alias = Identifier(p.dquote_string) + return entity + + # native query + @_('identifier LPAREN raw_query RPAREN') + def from_table(self, p): + query = NativeQuery( + integration=p.identifier, + query=tokens_to_string(p.raw_query) + ) + return query + + @_('LPAREN query RPAREN') + @_('LPAREN query RPAREN AS id') + @_('LPAREN query RPAREN AS id LPAREN column_list RPAREN') + def from_table(self, p): + query = p.query + query.parentheses = True + if hasattr(p, 'id'): + query.alias = Identifier(parts=[p.id]) + if hasattr(p, 'column_list'): + for i, col in enumerate(p.column_list): + if i >= len(query.targets): + break + query.targets[i].alias = Identifier(parts=[col]) + return query + + # keywords for table + @_('PLUGINS', + 'ENGINES') + def from_table(self, p): + return Identifier.from_path_str(p[0]) + + @_('identifier') + def from_table(self, p): + return p.identifier + + @_('parameter') + def from_table(self, p): + return p.parameter + + @_('JOIN', + 'LEFT JOIN', + 'RIGHT JOIN', + 'INNER JOIN', + 'FULL JOIN', + 'CROSS JOIN', + 'OUTER JOIN', + 'LEFT OUTER JOIN', + 'FULL OUTER JOIN', + ) + def join_clause(self, p): + return ' '.join([x for x in p]) + + @_('SELECT DISTINCT result_columns') + def select(self, p): + targets = p.result_columns + return Select(targets=targets, distinct=True) + + @_('SELECT result_columns') + def select(self, p): + targets = p.result_columns + return Select(targets=targets) + + @_('result_columns COMMA result_column') + def result_columns(self, p): + p.result_columns.append(p.result_column) + return p.result_columns + + @_('result_column') + def result_columns(self, p): + return [p.result_column] + + @_('result_column AS identifier', + 'result_column identifier', + 'result_column AS dquote_string', + 'result_column dquote_string', + 'result_column AS quote_string', + 'result_column quote_string') + def result_column(self, p): + col = p.result_column + # if col.alias: + # raise ParsingException(f'Attempt to provide two aliases for {str(col)}') + if hasattr(p, 'dquote_string'): + alias = Identifier(p.dquote_string) + elif hasattr(p, 'quote_string'): + alias = 
Identifier(p.quote_string) + else: + alias = p.identifier + col.alias = alias + return col + + @_('LPAREN select RPAREN') + def result_column(self, p): + select = p.select + select.parentheses = True + return select + + @_('star') + def result_column(self, p): + return p.star + + @_('expr', + 'function', + 'window_function') + def result_column(self, p): + return p[0] + + @_('column_list COMMA id', + 'id') + def column_list(self, p): + column_list = getattr(p, 'column_list', []) + column_list.append(p.id) + return column_list + + # case + @_('CASE case_conditions ELSE expr END', + 'CASE case_conditions END') + def case(self, p): + return Case(rules=p.case_conditions, default=getattr(p, 'expr', None)) + + @_('CASE expr case_conditions ELSE expr END', + 'CASE expr case_conditions END') + def case(self, p): + if hasattr(p, 'expr'): + arg, default = p.expr, None + else: + arg, default = p.expr0, p.expr1 + return Case(rules=p.case_conditions, default=default, arg=arg) + + @_('case_condition', + 'case_conditions case_condition') + def case_conditions(self, p): + arr = getattr(p, 'case_conditions', []) + arr.append(p.case_condition) + return arr + + @_('WHEN expr THEN expr') + def case_condition(self, p): + return [p.expr0, p.expr1] + + # Window function + @_('expr OVER LPAREN window RPAREN', + 'expr OVER LPAREN window id BETWEEN id id AND id id RPAREN') + def window_function(self, p): + + modifier = None + if hasattr(p, 'BETWEEN'): + modifier = f'{p.id0} BETWEEN {p.id1} {p.id2} AND {p.id3} {p.id4}' + return WindowFunction( + function=p.expr, + order_by=p.window.get('order_by'), + partition=p.window.get('partition'), + modifier=modifier, + ) + + @_('window PARTITION_BY expr_list') + def window(self, p): + window = p.window + part_by = p.expr_list + if not isinstance(part_by, list): + part_by = [part_by] + + window['partition'] = part_by + return window + + @_('window ORDER_BY ordering_terms') + def window(self, p): + window = p.window + window['order_by'] = p.ordering_terms + return window + + @_('empty') + def window(self, p): + return {} + + # OPERATIONS + + @_('LPAREN select RPAREN') + def expr(self, p): + select = p.select + select.parentheses = True + return select + + @_('LPAREN expr RPAREN') + def expr(self, p): + if isinstance(p.expr, ASTNode): + p.expr.parentheses = True + return p.expr + + @_('identifier LPAREN expr FROM expr RPAREN') + def function(self, p): + return Function(op=p[0].parts[0], args=[p.expr0], from_arg=p.expr1) + + @_('identifier LPAREN expr FROM expr FOR expr RPAREN') + def function(self, p): + return Function(op=p[0].parts[0], args=[p.expr0, p.expr1, p.expr2]) + + @_('DATABASE LPAREN RPAREN') + def function(self, p): + return Function(op=p.DATABASE, args=[]) + + @_('identifier LPAREN DISTINCT expr_list RPAREN') + def function(self, p): + return Function(op=p[0].parts[0], distinct=True, args=p.expr_list) + + @_( + 'function_name LPAREN expr_list_or_nothing RPAREN', + 'identifier LPAREN expr_list_or_nothing RPAREN', + 'identifier LPAREN star RPAREN') + def function(self, p): + if hasattr(p, 'star'): + args = [p.star] + else: + args = p.expr_list_or_nothing + if not args: + args = [] + for i, arg in enumerate(args): + if ( + isinstance(arg, Identifier) + and len(arg.parts) == 1 + and arg.parts[0].lower() == 'last' + ): + args[i] = Last() + + namespace = None + if hasattr(p, 'identifier'): + if len(p.identifier.parts) > 1: + namespace = p.identifier.parts[0] + name = p.identifier.parts[-1] + else: + name = p.function_name + return Function(op=name, args=args, 
namespace=namespace) + + @_('INTERVAL string') + @_('INTERVAL string id') + @_('INTERVAL integer id') + def expr(self, p): + if hasattr(p, 'id'): + if hasattr(p, 'integer'): + info = f'{p.integer} {p.id}' + else: + info = f'{p.string} {p.id}' + else: + info = p.string + return Interval(info) + + @_('EXISTS LPAREN select RPAREN') + def function(self, p): + return Exists(p.select) + + @_('NOT_EXISTS LPAREN select RPAREN') + def function(self, p): + return NotExists(p.select) + + + # arguments are optional in functions, so that things like `select database()` are possible + @_('expr BETWEEN expr AND expr') + def expr(self, p): + return BetweenOperation(args=(p.expr0, p.expr1, p.expr2)) + + @_('expr_list') + def expr_list_or_nothing(self, p): + return p.expr_list + + @_('empty') + def expr_list_or_nothing(self, p): + pass + + @_('CAST LPAREN expr AS id LPAREN integer RPAREN RPAREN') + @_('CAST LPAREN expr AS id LPAREN integer COMMA integer RPAREN RPAREN') + def expr(self, p): + if hasattr(p, 'integer'): + precision=[p.integer] + else: + precision=[p.integer0, p.integer1] + return TypeCast(arg=p.expr, type_name=str(p.id), precision=precision) + + @_('CAST LPAREN expr AS id RPAREN') + def expr(self, p): + return TypeCast(arg=p.expr, type_name=str(p.id)) + + @_('CONVERT LPAREN expr COMMA id RPAREN', + 'CONVERT LPAREN expr USING id RPAREN') + def expr(self, p): + return TypeCast(arg=p.expr, type_name=str(p.id)) + + @_('DATE string') + def expr(self, p): + return TypeCast(arg=Constant(p.string), type_name=p.DATE) + + @_('expr TYPECAST id') + @_('expr TYPECAST DATE') + def expr(self, p): + return TypeCast(arg=p.expr, type_name=p[2]) + + @_('enumeration') + def expr_list(self, p): + return p.enumeration + + @_('expr') + def expr_list(self, p): + return [p.expr] + + @_('LPAREN enumeration RPAREN') + def expr(self, p): + tup = Tuple(items=p.enumeration) + return tup + + @_('STAR') + def star(self, p): + return Star() + + @_('expr PLUS expr', + 'expr MINUS expr', + 'expr MATCH expr', + 'expr NOT_MATCH expr', + 'expr STAR expr', + 'expr DIVIDE expr', + 'expr MODULO expr', + 'expr EQUALS expr', + 'expr NEQUALS expr', + 'expr GEQ expr', + 'expr GREATER expr', + 'expr GEQ LAST', + 'expr GREATER LAST', + 'expr LEQ expr', + 'expr LESS expr', + 'expr AND expr', + 'expr OR expr', + 'expr IS_NOT expr', + 'expr NOT expr', + 'expr IS expr', + 'expr LIKE expr', + 'expr NOT_LIKE expr', + 'expr CONCAT expr', + 'expr JSON_GET constant', + 'expr JSON_GET_STR constant', + 'expr NOT_IN expr', + 'expr IN expr',) + def expr(self, p): + if hasattr(p, 'LAST'): + arg1 = Last() + else: + arg1 = p[2] + return BinaryOperation(op=p[1], args=(p[0], arg1)) + + @_('MINUS expr %prec UMINUS', + 'NOT expr %prec UNOT', ) + def expr(self, p): + return UnaryOperation(op=p[0], args=(p.expr,)) + + @_('MINUS constant %prec UMINUS') + def constant(self, p): + return Constant(-p.constant.value) + + # update fields list + @_('update_parameter', + 'update_parameter_list COMMA update_parameter') + def update_parameter_list(self, p): + params = getattr(p, 'update_parameter_list', {}) + params.update(p.update_parameter) + return params + + @_('id EQUALS expr') + def update_parameter(self, p): + return {p.id:p.expr} + + # EXPRESSIONS + + @_('enumeration COMMA expr') + def enumeration(self, p): + return p.enumeration + [p.expr] + + @_('expr COMMA expr') + def enumeration(self, p): + return [p.expr0, p.expr1] + + @_('identifier', + 'parameter', + 'constant', + 'latest', + 'case', + 'function') + def expr(self, p): + return p[0] + + @_('LATEST') + def 
latest(self, p): + return Latest() + + @_('NULL') + def constant(self, p): + return NullConstant() + + @_('TRUE') + def constant(self, p): + return Constant(value=True) + + @_('FALSE') + def constant(self, p): + return Constant(value=False) + + @_('integer') + def constant(self, p): + return Constant(value=int(p.integer)) + + @_('float') + def constant(self, p): + return Constant(value=float(p.float)) + + @_('string') + def constant(self, p): + return Constant(value=str(p[0])) + + # param list + + @_('id LPAREN kw_parameter_list RPAREN') + def object(self, p): + return Object(type=p.id, params=p.kw_parameter_list) + + @_('kw_parameter', + 'kw_parameter_list COMMA kw_parameter') + def kw_parameter_list(self, p): + params = getattr(p, 'kw_parameter_list', {}) + params.update(p.kw_parameter) + return params + + @_('identifier EQUALS object', + 'identifier EQUALS json_value', + 'identifier EQUALS identifier') + def kw_parameter(self, p): + key = getattr(p, 'identifier', None) or getattr(p, 'identifier0', None) + assert key is not None + key = '.'.join(key.parts) + return {key:p[2]} + + # json + + @_('LBRACE json_element_list RBRACE', + 'LBRACE RBRACE') + def json(self, p): + params = getattr(p, 'json_element_list', {}) + return params + + @_('json_element', + 'json_element_list COMMA json_element') + def json_element_list(self, p): + params = getattr(p, 'json_element_list', {}) + params.update(p.json_element) + return params + + @_('string COLON json_value') + def json_element(self, p): + return {p.string:p.json_value} + + # json_array + + @_('LBRACKET json_array_list RBRACKET', + 'LBRACKET RBRACKET') + def json_array(self, p): + arr = getattr(p, 'json_array_list', []) + return arr + + @_('json_value', + 'json_array_list COMMA json_value') + def json_array_list(self, p): + arr = getattr(p, 'json_array_list', []) + arr.append(p.json_value) + return arr + + @_('float', + 'string', + 'integer', + 'NULL', + 'TRUE', + 'FALSE', + 'json_array', + 'json') + def json_value(self, p): + + if hasattr(p, 'NULL'): + return None + elif hasattr(p, 'TRUE'): + return True + elif hasattr(p, 'FALSE'): + return False + return p[0] + + @_('identifier DOT identifier', + 'identifier DOT integer', + 'identifier DOT dquote_string', + 'identifier DOT star') + def identifier(self, p): + node = p[0] + if isinstance(p[2], Star): + node.parts.append(p[2]) + elif isinstance(p[2], int): + node.parts.append(str(p[2])) + elif isinstance(p[2], str): + node.parts.append(p[2]) + else: + node.parts += p[2].parts + return node + + @_('quote_string', + 'dquote_string') + def string(self, p): + return p[0] + + @_('id', 'dquote_string') + def identifier(self, p): + value = p[0] + return Identifier.from_path_str(value) + + @_('PARAMETER') + def parameter(self, p): + return Parameter(value=p.PARAMETER) + + @_('id', + 'FULL', + 'RIGHT', + 'LEFT') + def function_name(self, p): + return p[0] + + # convert to types + @_('ID', + 'BEGIN', + 'CAST', + 'DATE', + 'CHANNEL', + 'CHARSET', + 'CODE', + 'COLLATION', + 'COLUMNS', + 'COMMIT', + 'COMMITTED', + 'DATASET', + 'DATASETS', + 'DATABASE', + 'DATABASES', + 'DATASOURCE', + 'DATASOURCES', + 'ENGINE', + 'ENGINES', + 'EXTENDED', + 'FIELDS', + 'GLOBAL', + 'HORIZON', + 'HOSTS', + 'INDEXES', + 'INDEX', + 'INTEGRATION', + 'INTEGRATIONS', + 'INTERVAL', + 'ISOLATION', + 'KEYS', + 'LATEST', + 'LAST', + 'LEVEL', + 'LOGS', + 'MASTER', + 'MUTEX', + 'OFFSET', + 'ONLY', + 'OPEN', + 'PARAMETERS', + 'PERSIST', + 'PLUGINS', + 'PREDICT', + 'PREDICTOR', + 'PREDICTORS', + 'PRIVILEGES', + 'PROCESSLIST', + 'PROFILES', 
+ 'PUBLICATION', + 'PUBLICATIONS', + 'REPEATABLE', + 'REPLACE', + 'REPLICA', + 'REPLICAS', + 'RETRAIN', + 'ROLLBACK', + 'SERIALIZABLE', + 'SESSION', + 'SCHEMA', + 'SLAVE', + 'START', + 'STATUS', + 'STORAGE', + 'STREAM', + 'STREAMS', + 'TABLES', + 'TABLE', + 'TRAIN', + 'TRANSACTION', + 'TRIGGERS', + 'UNCOMMITTED', + 'VARIABLES', + 'VIEW', + 'VIEWS', + 'WARNINGS', + 'MODEL', + 'DEFAULT', + 'MODELS', + 'AGENT', + 'SCHEMAS', + 'FUNCTION', + 'charset', + 'PROCEDURE', + 'ML_ENGINES', + 'HANDLERS', + 'BINARY', + 'KNOWLEDGE_BASES', + 'ALL', + 'CREATE', + ) + def id(self, p): + return p[0] + + @_('FLOAT') + def float(self, p): + return float(p[0]) + + @_('INTEGER') + def integer(self, p): + return int(p[0]) + + @_('QUOTE_STRING') + def quote_string(self, p): + return p[0].strip('\'') + + @_('DQUOTE_STRING') + def dquote_string(self, p): + return p[0].strip('\"') + + # for raw query + + @_('LPAREN raw_query RPAREN') + def raw_query(self, p): + return [p._slice[0]] + p[1] + [p._slice[2]] + + @_('raw_query LPAREN RPAREN') + def raw_query(self, p): + return p[0] + [p._slice[1], p._slice[2]] + + @_('raw_query raw_query') + def raw_query(self, p): + return p[0] + p[1] + + @_('variable') + def table_or_subquery(self, p): + return p.variable + + @_('variable') + def expr(self, p): + return p.variable + + @_('SYSTEM_VARIABLE') + def variable(self, p): + return Variable(value=p.SYSTEM_VARIABLE, is_system_var=True) + + @_('VARIABLE') + def variable(self, p): + return Variable(value=p.VARIABLE) + + @_( + 'OR REPLACE', + 'empty' + ) + def replace_or_empty(self, p): + if hasattr(p, 'REPLACE'): + return True + return False + + @_( + 'IF NOT_EXISTS', + 'empty' + ) + def if_not_exists_or_empty(self, p): + if hasattr(p, 'NOT_EXISTS'): + return True + return False + + @_( + 'IF EXISTS', + 'empty' + ) + def if_exists_or_empty(self, p): + if hasattr(p, 'EXISTS'): + return True + return False + + @_(*all_tokens_list) + def raw_query(self, p): + return p._slice + + @_('') + def empty(self, p): + pass + + def error(self, p, expected_tokens=None): + + if not hasattr(self, 'used_tokens'): + # fallback mode in case a different sly version is installed instead of the vendored one + if p: + raise ParsingException(f"Syntax error at token {p.type}: \"{p.value}\"") + else: + raise ParsingException("Syntax error at EOF") + + # save error info for future usage + self.error_info = dict( + tokens=self.used_tokens.copy() + list(self.tokens), + bad_token=p, + expected_tokens=expected_tokens + ) + # don't raise exception + return diff --git a/mindsdb_sql_parser/utils.py b/mindsdb_sql_parser/utils.py new file mode 100644 index 0000000..8420249 --- /dev/null +++ b/mindsdb_sql_parser/utils.py @@ -0,0 +1,91 @@ +from mindsdb_sql_parser.exceptions import ParsingException + + +def indent(level): + return ' ' * level + + +def ensure_select_keyword_order(select, operation): + op_to_attr = { + 'FROM': select.from_table, + 'WHERE': select.where, + 'GROUP BY': select.group_by, + 'HAVING': select.having, + 'ORDER BY': select.order_by, + 'LIMIT': select.limit, + 'OFFSET': select.offset, + 'MODE': select.mode, + } + + requirements = { + 'WHERE': ['FROM'], + 'GROUP BY': ['FROM'], + 'ORDER BY': ['FROM'], + # 'HAVING': ['GROUP BY'], + } + + precedence = ['FROM', 'WHERE', 'GROUP BY', 'HAVING', 'ORDER BY', 'LIMIT', 'OFFSET', 'MODE'] +
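+ # Illustrative examples (derived from this function's checks): "SELECT a WHERE a=1" + # (no FROM) fails the 'requires' check a few lines below, while + # "SELECT a FROM t LIMIT 1 LIMIT 2" trips this duplicate-clause check: + if op_to_attr[operation]: + raise ParsingException(f"Duplicate {operation} clause. 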
Only one {operation} allowed per SELECT.") + + op_requires = requirements.get(operation, []) + + for req in op_requires: + if not op_to_attr[req]: + raise ParsingException(f"{operation} requires {req}") + + op_precedence_pos = precedence.index(operation) + + for next_op in precedence[op_precedence_pos:]: + if op_to_attr[next_op]: + raise ParsingException(f"{operation} must go before {next_op}") + + +class JoinType: + JOIN = 'JOIN' + INNER_JOIN = 'INNER JOIN' + OUTER_JOIN = 'OUTER JOIN' + CROSS_JOIN = 'CROSS JOIN' + LEFT_JOIN = 'LEFT JOIN' + RIGHT_JOIN = 'RIGHT JOIN' + FULL_JOIN = 'FULL JOIN' + + +def to_single_line(text): + text = '\t'.join([line.strip() for line in text.split('\n')]) + text = text.replace('\t', ' ') + text = ' '.join(text.split()) + return text + + +def tokens_to_string(tokens): + # converts a list of tokens (produced by the lexer) back to the original string + + line_num = tokens[0].lineno + shift = tokens[0].index + last_pos = 0 + content, line = '', '' + + for token in tokens: + if token.lineno != line_num: + # go to new line + content += line + '\n' + line = '' + line_num = token.lineno + + # because sly tokens store only an absolute position index: + # memorize the last token index to shift the next line + shift = last_pos + 1 + + # filling space between tokens + line += ' '*(token.index - shift - len(line)) + + # add token + line += token.value + + last_pos = token.index + len(token.value) + + # last line + content += line + return content diff --git a/requirements_test.txt b/requirements_test.txt new file mode 100644 index 0000000..bfd12b3 --- /dev/null +++ b/requirements_test.txt @@ -0,0 +1 @@ +pytest>=5.4.3 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..63a7c93 --- /dev/null +++ b/setup.py @@ -0,0 +1,22 @@ +import setuptools + +about = {} +with open("mindsdb_sql_parser/__about__.py") as fp: + exec(fp.read(), about) + +setuptools.setup( + name=about['__title__'], + version=about['__version__'], + url=about['__github__'], + download_url=about['__pypi__'], + license=about['__license__'], + author=about['__author__'], + author_email=about['__email__'], + description=about['__description__'], + packages=setuptools.find_packages(exclude=('tests*',)), + classifiers=[ + "Programming Language :: Python :: 3", + "Operating System :: OS Independent", + ], + python_requires=">=3.6" +) diff --git a/sly/LICENSE b/sly/LICENSE new file mode 100644 index 0000000..1db3028 --- /dev/null +++ b/sly/LICENSE @@ -0,0 +1,39 @@ +SLY (Sly Lex-Yacc) + +Copyright (C) 2016-2022 +David M. Beazley (Dabeaz LLC) +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +* Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. +* Neither the name of the David Beazley or Dabeaz LLC may be used to + endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + + + + + + + diff --git a/sly/__init__.py b/sly/__init__.py new file mode 100644 index 0000000..8a62454 --- /dev/null +++ b/sly/__init__.py @@ -0,0 +1,6 @@ + +from .lex import * +from .yacc import * + +__version__ = "0.5" +__all__ = [ *lex.__all__, *yacc.__all__ ] diff --git a/sly/ast.py b/sly/ast.py new file mode 100644 index 0000000..0a8f8c3 --- /dev/null +++ b/sly/ast.py @@ -0,0 +1,25 @@ +# sly/ast.py +import sys + +class AST0(object): + + @classmethod + def __init_subclass__(cls, **kwargs): + mod = sys.modules[cls.__module__] + if not hasattr(cls, '__annotations__'): + return + + hints = list(cls.__annotations__.items()) + + def __init__(self, *args, **kwargs): + if len(hints) != len(args): + raise TypeError(f'Expected {len(hints)} arguments') + for arg, (name, val) in zip(args, hints): + if isinstance(val, str): + val = getattr(mod, val) + if not isinstance(arg, val): + raise TypeError(f'{name} argument must be {val}') + setattr(self, name, arg) + + cls.__init__ = __init__ + diff --git a/sly/lex.py b/sly/lex.py new file mode 100644 index 0000000..2eb0af3 --- /dev/null +++ b/sly/lex.py @@ -0,0 +1,462 @@ +# ----------------------------------------------------------------------------- +# sly: lex.py +# +# Copyright (C) 2016 - 2018 +# David M. Beazley (Dabeaz LLC) +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of the David Beazley or Dabeaz LLC may be used to +# endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# ----------------------------------------------------------------------------- + +__all__ = ['Lexer', 'LexerStateChange', 'Token'] + +import re +import copy + +class LexError(Exception): + ''' + Exception raised if an invalid character is encountered and no default + error handler function is defined. The .text attribute of the exception + contains all remaining untokenized text. The .error_index is the index + location of the error. + ''' + def __init__(self, message, text, error_index): + self.args = (message,) + self.text = text + self.error_index = error_index + +class PatternError(Exception): + ''' + Exception raised if there's some kind of problem with the specified + regex patterns in the lexer. + ''' + pass + +class LexerBuildError(Exception): + ''' + Exception raised if there's some sort of problem building the lexer. + ''' + pass + +class LexerStateChange(Exception): + ''' + Exception raised to force a lexing state change + ''' + def __init__(self, newstate, tok=None): + self.newstate = newstate + self.tok = tok + +class Token(object): + ''' + Representation of a single token. + ''' + __slots__ = ('type', 'value', 'lineno', 'index', 'end') + def __repr__(self): + return f'Token(type={self.type!r}, value={self.value!r}, lineno={self.lineno}, index={self.index}, end={self.end})' + +class TokenStr(str): + @staticmethod + def __new__(cls, value, key=None, remap=None): + self = super().__new__(cls, value) + self.key = key + self.remap = remap + return self + + # Implementation of TOKEN[value] = NEWTOKEN + def __setitem__(self, key, value): + if self.remap is not None: + self.remap[self.key, key] = value + + # Implementation of del TOKEN[value] + def __delitem__(self, key): + if self.remap is not None: + self.remap[self.key, key] = self.key + +class _Before: + def __init__(self, tok, pattern): + self.tok = tok + self.pattern = pattern + +class LexerMetaDict(dict): + ''' + Special dictionary that prohibits duplicate definitions in lexer specifications. 
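+ It also records before() placements, deletions and token remappings for later use by LexerMeta.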
+ ''' + def __init__(self): + self.before = { } + self.delete = [ ] + self.remap = { } + + def __setitem__(self, key, value): + if isinstance(value, str): + value = TokenStr(value, key, self.remap) + + if isinstance(value, _Before): + self.before[key] = value.tok + value = TokenStr(value.pattern, key, self.remap) + + if key in self and not isinstance(value, property): + prior = self[key] + if isinstance(prior, str): + if callable(value): + value.pattern = prior + else: + raise AttributeError(f'Name {key} redefined') + + super().__setitem__(key, value) + + def __delitem__(self, key): + self.delete.append(key) + if key not in self and key.isupper(): + pass + else: + return super().__delitem__(key) + + def __getitem__(self, key): + if key not in self and key.split('ignore_')[-1].isupper() and key[:1] != '_': + return TokenStr(key, key, self.remap) + else: + return super().__getitem__(key) + +class LexerMeta(type): + ''' + Metaclass for collecting lexing rules + ''' + @classmethod + def __prepare__(meta, name, bases): + d = LexerMetaDict() + + def _(pattern, *extra): + patterns = [pattern, *extra] + def decorate(func): + pattern = '|'.join(f'({pat})' for pat in patterns ) + if hasattr(func, 'pattern'): + func.pattern = pattern + '|' + func.pattern + else: + func.pattern = pattern + return func + return decorate + + d['_'] = _ + d['before'] = _Before + return d + + def __new__(meta, clsname, bases, attributes): + del attributes['_'] + del attributes['before'] + + # Create attributes for use in the actual class body + cls_attributes = { str(key): str(val) if isinstance(val, TokenStr) else val + for key, val in attributes.items() } + cls = super().__new__(meta, clsname, bases, cls_attributes) + + # Attach various metadata to the class + cls._attributes = dict(attributes) + cls._remap = attributes.remap + cls._before = attributes.before + cls._delete = attributes.delete + cls._build() + return cls + +class Lexer(metaclass=LexerMeta): + # These attributes may be defined in subclasses + tokens = set() + literals = set() + ignore = '' + reflags = 0 + regex_module = re + + _token_names = set() + _token_funcs = {} + _ignored_tokens = set() + _remapping = {} + _delete = {} + _remap = {} + + # Internal attributes + __state_stack = None + __set_state = None + + @classmethod + def _collect_rules(cls): + # Collect all of the rules from class definitions that look like token + # information. There are a few things that govern this: + # + # 1. Any definition of the form NAME = str is a token if NAME is + # is defined in the tokens set. + # + # 2. Any definition of the form ignore_NAME = str is a rule for an ignored + # token. + # + # 3. Any function defined with a 'pattern' attribute is treated as a rule. + # Such functions can be created with the @_ decorator or by defining + # function with the same name as a previously defined string. + # + # This function is responsible for keeping rules in order. + + # Collect all previous rules from base classes + rules = [] + + for base in cls.__bases__: + if isinstance(base, LexerMeta): + rules.extend(base._rules) + + # Dictionary of previous rules + existing = dict(rules) + + for key, value in cls._attributes.items(): + if (key in cls._token_names) or key.startswith('ignore_') or hasattr(value, 'pattern'): + if callable(value) and not hasattr(value, 'pattern'): + raise LexerBuildError(f"function {value} doesn't have a regex pattern") + + if key in existing: + # The definition matches something that already existed in the base class. 
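+ # (e.g. a subclass redefining a token that a base Lexer already declared).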
+ # We replace it, but keep the original ordering + n = rules.index((key, existing[key])) + rules[n] = (key, value) + existing[key] = value + + elif isinstance(value, TokenStr) and key in cls._before: + before = cls._before[key] + if before in existing: + # Position the token before another specified token + n = rules.index((before, existing[before])) + rules.insert(n, (key, value)) + else: + # Put at the end of the rule list + rules.append((key, value)) + existing[key] = value + else: + rules.append((key, value)) + existing[key] = value + + elif isinstance(value, str) and not key.startswith('_') and key not in {'ignore', 'literals'}: + raise LexerBuildError(f'{key} does not match a name in tokens') + + # Apply deletion rules + rules = [ (key, value) for key, value in rules if key not in cls._delete ] + cls._rules = rules + + @classmethod + def _build(cls): + ''' + Build the lexer object from the collected tokens and regular expressions. + Validate the rules to make sure they look sane. + ''' + if 'tokens' not in vars(cls): + raise LexerBuildError(f'{cls.__qualname__} class does not define a tokens attribute') + + # Pull definitions created for any parent classes + cls._token_names = cls._token_names | set(cls.tokens) + cls._ignored_tokens = set(cls._ignored_tokens) + cls._token_funcs = dict(cls._token_funcs) + cls._remapping = dict(cls._remapping) + + for (key, val), newtok in cls._remap.items(): + if key not in cls._remapping: + cls._remapping[key] = {} + cls._remapping[key][val] = newtok + + remapped_toks = set() + for d in cls._remapping.values(): + remapped_toks.update(d.values()) + + undefined = remapped_toks - set(cls._token_names) + if undefined: + missing = ', '.join(undefined) + raise LexerBuildError(f'{missing} not included in token(s)') + + cls._collect_rules() + + parts = [] + for tokname, value in cls._rules: + if tokname.startswith('ignore_'): + tokname = tokname[7:] + cls._ignored_tokens.add(tokname) + + if isinstance(value, str): + pattern = value + + elif callable(value): + cls._token_funcs[tokname] = value + pattern = getattr(value, 'pattern') + else: + continue + + # Form the regular expression component + part = f'(?P<{tokname}>{pattern})' + + # Make sure the individual regex compiles properly + try: + cpat = cls.regex_module.compile(part, cls.reflags) + except Exception as e: + raise PatternError(f'Invalid regex for token {tokname}') from e + + # Verify that the pattern doesn't match the empty string + if cpat.match(''): + raise PatternError(f'Regex for token {tokname} matches empty input') + + parts.append(part) + + if not parts: + return + + # Form the master regular expression + #previous = ('|' + cls._master_re.pattern) if cls._master_re else '' + # cls._master_re = cls.regex_module.compile('|'.join(parts) + previous, cls.reflags) + cls._master_re = cls.regex_module.compile('|'.join(parts), cls.reflags) + + # Verify that that ignore and literals specifiers match the input type + if not isinstance(cls.ignore, str): + raise LexerBuildError('ignore specifier must be a string') + + if not all(isinstance(lit, str) for lit in cls.literals): + raise LexerBuildError('literals must be specified as strings') + + def begin(self, cls): + ''' + Begin a new lexer state + ''' + assert isinstance(cls, LexerMeta), "state must be a subclass of Lexer" + if self.__set_state: + self.__set_state(cls) + self.__class__ = cls + + def push_state(self, cls): + ''' + Push a new lexer state onto the stack + ''' + if self.__state_stack is None: + self.__state_stack = [] + 
self.__state_stack.append(type(self)) + self.begin(cls) + + def pop_state(self): + ''' + Pop a lexer state from the stack + ''' + self.begin(self.__state_stack.pop()) + + def tokenize(self, text, lineno=1, index=0): + _ignored_tokens = _master_re = _ignore = _token_funcs = _literals = _remapping = None + + # --- Support for state changes + def _set_state(cls): + nonlocal _ignored_tokens, _master_re, _ignore, _token_funcs, _literals, _remapping + _ignored_tokens = cls._ignored_tokens + _master_re = cls._master_re + _ignore = cls.ignore + _token_funcs = cls._token_funcs + _literals = cls.literals + _remapping = cls._remapping + + self.__set_state = _set_state + _set_state(type(self)) + + # --- Support for backtracking + _mark_stack = [] + def _mark(): + _mark_stack.append((type(self), index, lineno)) + self.mark = _mark + + def _accept(): + _mark_stack.pop() + self.accept = _accept + + def _reject(): + nonlocal index, lineno + cls, index, lineno = _mark_stack[-1] + _set_state(cls) + self.reject = _reject + + + # --- Main tokenization function + self.text = text + try: + while True: + try: + if text[index] in _ignore: + index += 1 + continue + except IndexError: + return + + tok = Token() + tok.lineno = lineno + tok.index = index + m = _master_re.match(text, index) + if m: + tok.end = index = m.end() + tok.value = m.group() + tok.type = m.lastgroup + + if tok.type in _remapping: + tok.type = _remapping[tok.type].get(tok.value, tok.type) + + if tok.type in _token_funcs: + self.index = index + self.lineno = lineno + tok = _token_funcs[tok.type](self, tok) + index = self.index + lineno = self.lineno + if not tok: + continue + + if tok.type in _ignored_tokens: + continue + + yield tok + + else: + # No match, see if the character is in literals + if text[index] in _literals: + tok.value = text[index] + tok.end = index + 1 + tok.type = tok.value + index += 1 + yield tok + else: + # A lexing error + self.index = index + self.lineno = lineno + tok.type = 'ERROR' + tok.value = text[index:] + tok = self.error(tok) + if tok is not None: + tok.end = self.index + yield tok + + index = self.index + lineno = self.lineno + + # Set the final state of the lexer before exiting (even if exception) + finally: + self.text = text + self.index = index + self.lineno = lineno + + # Default implementations of the error handler. May be changed in subclasses + def error(self, t): + raise LexError(f'Illegal character {t.value[0]!r} at index {self.index}', t.value, self.index) diff --git a/sly/yacc.py b/sly/yacc.py new file mode 100644 index 0000000..aafd6af --- /dev/null +++ b/sly/yacc.py @@ -0,0 +1,2263 @@ +# ----------------------------------------------------------------------------- +# sly: yacc.py +# +# Copyright (C) 2016-2018 +# David M. Beazley (Dabeaz LLC) +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of the David Beazley or Dabeaz LLC may be used to +# endorse or promote products derived from this software without +# specific prior written permission. 
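That `error()` hook and the `tokenize()` generator conclude `sly/lex.py`. Before the `yacc.py` listing continues, a hedged sketch of how they are driven, reusing the hypothetical `CalcLexer` from the earlier example:

```python
lexer = CalcLexer()
for tok in lexer.tokenize('3 + 4 - 5'):
    # Each Token carries type/value plus position info (lineno, index, end).
    print(tok.type, tok.value, tok.index)

class ForgivingLexer(CalcLexer):
    def error(self, t):
        # The default error() raises LexError; advancing self.index instead
        # skips the offending character and resumes tokenizing.
        print(f'Skipping illegal character {t.value[0]!r}')
        self.index += 1
```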
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# ----------------------------------------------------------------------------- + +import sys +import inspect +from collections import OrderedDict, defaultdict, Counter + +__all__ = [ 'Parser' ] + +class YaccError(Exception): + ''' + Exception raised for yacc-related build errors. + ''' + pass + +#----------------------------------------------------------------------------- +# === User configurable parameters === +# +# Change these to modify the default behavior of yacc (if you wish). +# Move these parameters to the Yacc class itself. +#----------------------------------------------------------------------------- + +ERROR_COUNT = 3 # Number of symbols that must be shifted to leave recovery mode +MAXINT = sys.maxsize + +# This object is a stand-in for a logging object created by the +# logging module. SLY will use this by default to create things +# such as the parser.out file. If a user wants more detailed +# information, they can create their own logging object and pass +# it into SLY. + +class SlyLogger(object): + def __init__(self, f): + self.f = f + + def debug(self, msg, *args, **kwargs): + self.f.write((msg % args) + '\n') + + info = debug + + def warning(self, msg, *args, **kwargs): + self.f.write('WARNING: ' + (msg % args) + '\n') + + def error(self, msg, *args, **kwargs): + self.f.write('ERROR: ' + (msg % args) + '\n') + + critical = debug + + +# ---------------------------------------------------------------------- +# This class is used to hold non-terminal grammar symbols during parsing. +# It normally has the following attributes set: +# .type = Grammar symbol type +# .value = Symbol value +# .lineno = Starting line number +# .index = Starting lex position +# ---------------------------------------------------------------------- + +class YaccSymbol: + def __str__(self): + return self.type + + def __repr__(self): + return str(self) + +# ---------------------------------------------------------------------- +# This class is a wrapper around the objects actually passed to each +# grammar rule. Index lookup and assignment actually assign the +# .value attribute of the underlying YaccSymbol object. +# The lineno() method returns the line number of a given +# item (or 0 if not defined). 
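`SlyLogger` above is only a thin shim over a writable stream; any object with `debug`/`info`/`warning`/`error` methods can stand in for it. A tiny illustration:

```python
import sys

log = SlyLogger(sys.stderr)
log.warning('%d shift/reduce conflicts', 2)   # -> "WARNING: 2 shift/reduce conflicts"
log.error('%s is undefined', 'expr')          # -> "ERROR: expr is undefined"
```

`Parser` subclasses point their class-level `log` attribute at such an object, as seen further down in the `Parser` class.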
+# ---------------------------------------------------------------------- + +class YaccProduction: + __slots__ = ('_slice', '_namemap', '_stack') + def __init__(self, s, stack=None): + self._slice = s + self._namemap = { } + self._stack = stack + + def __getitem__(self, n): + if n >= 0: + return self._slice[n].value + else: + return self._stack[n].value + + def __setitem__(self, n, v): + if n >= 0: + self._slice[n].value = v + else: + self._stack[n].value = v + + def __len__(self): + return len(self._slice) + + @property + def lineno(self): + for tok in self._slice: + lineno = getattr(tok, 'lineno', None) + if lineno: + return lineno + raise AttributeError('No line number found') + + @property + def index(self): + for tok in self._slice: + index = getattr(tok, 'index', None) + if index is not None: + return index + raise AttributeError('No index attribute found') + + @property + def end(self): + result = None + for tok in self._slice: + r = getattr(tok, 'end', None) + if r: + result = r + return result + + def __getattr__(self, name): + if name in self._namemap: + return self._namemap[name](self._slice) + else: + nameset = '{' + ', '.join(self._namemap) + '}' + raise AttributeError(f'No symbol {name}. Must be one of {nameset}.') + + def __setattr__(self, name, value): + if name[:1] == '_': + super().__setattr__(name, value) + else: + raise AttributeError(f"Can't reassign the value of attribute {name!r}") + +# ----------------------------------------------------------------------------- +# === Grammar Representation === +# +# The following functions, classes, and variables are used to represent and +# manipulate the rules that make up a grammar. +# ----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- +# class Production: +# +# This class stores the raw information about a single production or grammar rule. +# A grammar rule refers to a specification such as this: +# +# expr : expr PLUS term +# +# Here are the basic attributes defined on all productions +# +# name - Name of the production. For example 'expr' +# prod - A list of symbols on the right side ['expr','PLUS','term'] +# prec - Production precedence level +# number - Production number. +# func - Function that executes on reduce +# file - File where production function is defined +# lineno - Line number where production function is defined +# +# The following attributes are defined or optional. 
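`YaccProduction` above is the `p` object every grammar-rule method receives: attribute access resolves through `_namemap`, non-negative indices address the matched slice, and negative indices reach back into the parser stack. A hypothetical rule showing the two access styles:

```python
class CalcParser(Parser):
    tokens = CalcLexer.tokens   # hypothetical lexer from the earlier sketch

    @_('expr PLUS term')
    def expr(self, p):
        # p.expr / p.term go through _namemap; p[0], p[1], p[2] address the
        # same symbols by position.  Duplicated names become expr0, expr1, ...
        return p.expr + p.term
```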
+# +# len - Length of the production (number of symbols on right hand side) +# usyms - Set of unique symbols found in the production +# ----------------------------------------------------------------------------- + +class Production(object): + reduced = 0 + def __init__(self, number, name, prod, precedence=('right', 0), func=None, file='', line=0): + self.name = name + self.prod = tuple(prod) + self.number = number + self.func = func + self.file = file + self.line = line + self.prec = precedence + + # Internal settings used during table construction + self.len = len(self.prod) # Length of the production + + # Create a list of unique production symbols used in the production + self.usyms = [] + symmap = defaultdict(list) + for n, s in enumerate(self.prod): + symmap[s].append(n) + if s not in self.usyms: + self.usyms.append(s) + + # Create a name mapping + # First determine (in advance) if there are duplicate names + namecount = defaultdict(int) + for key in self.prod: + namecount[key] += 1 + if key in _name_aliases: + for key in _name_aliases[key]: + namecount[key] += 1 + + # Now, walk through the names and generate accessor functions + nameuse = defaultdict(int) + namemap = { } + for index, key in enumerate(self.prod): + if namecount[key] > 1: + k = f'{key}{nameuse[key]}' + nameuse[key] += 1 + else: + k = key + namemap[k] = lambda s,i=index: s[i].value + if key in _name_aliases: + for n, alias in enumerate(_name_aliases[key]): + if namecount[alias] > 1: + k = f'{alias}{nameuse[alias]}' + nameuse[alias] += 1 + else: + k = alias + # The value is either a list (for repetition) or a tuple for optional + namemap[k] = lambda s,i=index,n=n: ([x[n] for x in s[i].value]) if isinstance(s[i].value, list) else s[i].value[n] + + self.namemap = namemap + + # List of all LR items for the production + self.lr_items = [] + self.lr_next = None + + def __str__(self): + if self.prod: + s = '%s -> %s' % (self.name, ' '.join(self.prod)) + else: + s = f'{self.name} -> ' + + if self.prec[1]: + s += ' [precedence=%s, level=%d]' % self.prec + + return s + + def __repr__(self): + return f'Production({self})' + + def __len__(self): + return len(self.prod) + + def __nonzero__(self): + raise RuntimeError('Used') + return 1 + + def __getitem__(self, index): + return self.prod[index] + + # Return the nth lr_item from the production (or None if at the end) + def lr_item(self, n): + if n > len(self.prod): + return None + p = LRItem(self, n) + # Precompute the list of productions immediately following. + try: + p.lr_after = Prodnames[p.prod[n+1]] + except (IndexError, KeyError): + p.lr_after = [] + try: + p.lr_before = p.prod[n-1] + except IndexError: + p.lr_before = None + return p + +# ----------------------------------------------------------------------------- +# class LRItem +# +# This class represents a specific stage of parsing a production rule. For +# example: +# +# expr : expr . PLUS term +# +# In the above, the "." represents the current location of the parse. Here +# basic attributes: +# +# name - Name of the production. For example 'expr' +# prod - A list of symbols on the right side ['expr','.', 'PLUS','term'] +# number - Production number. +# +# lr_next Next LR item. Example, if we are ' expr -> expr . PLUS term' +# then lr_next refers to 'expr -> expr PLUS . term' +# lr_index - LR item index (location of the ".") in the prod list. 
+# lookaheads - LALR lookahead symbols for this item +# len - Length of the production (number of symbols on right hand side) +# lr_after - List of all productions that immediately follow +# lr_before - Grammar symbol immediately before +# ----------------------------------------------------------------------------- + +class LRItem(object): + def __init__(self, p, n): + self.name = p.name + self.prod = list(p.prod) + self.number = p.number + self.lr_index = n + self.lookaheads = {} + self.prod.insert(n, '.') + self.prod = tuple(self.prod) + self.len = len(self.prod) + self.usyms = p.usyms + + def __str__(self): + if self.prod: + s = '%s -> %s' % (self.name, ' '.join(self.prod)) + else: + s = f'{self.name} -> ' + return s + + def __repr__(self): + return f'LRItem({self})' + +# ----------------------------------------------------------------------------- +# rightmost_terminal() +# +# Return the rightmost terminal from a list of symbols. Used in add_production() +# ----------------------------------------------------------------------------- +def rightmost_terminal(symbols, terminals): + i = len(symbols) - 1 + while i >= 0: + if symbols[i] in terminals: + return symbols[i] + i -= 1 + return None + +# ----------------------------------------------------------------------------- +# === GRAMMAR CLASS === +# +# The following class represents the contents of the specified grammar along +# with various computed properties such as first sets, follow sets, LR items, etc. +# This data is used for critical parts of the table generation process later. +# ----------------------------------------------------------------------------- + +class GrammarError(YaccError): + pass + +class Grammar(object): + def __init__(self, terminals): + self.Productions = [None] # A list of all of the productions. The first + # entry is always reserved for the purpose of + # building an augmented grammar + + self.Prodnames = {} # A dictionary mapping the names of nonterminals to a list of all + # productions of that nonterminal. + + self.Prodmap = {} # A dictionary that is only used to detect duplicate + # productions. + + self.Terminals = {} # A dictionary mapping the names of terminal symbols to a + # list of the rules where they are used. + + for term in terminals: + self.Terminals[term] = [] + + self.Terminals['error'] = [] + + self.Nonterminals = {} # A dictionary mapping names of nonterminals to a list + # of rule numbers where they are used. + + self.First = {} # A dictionary of precomputed FIRST(x) symbols + + self.Follow = {} # A dictionary of precomputed FOLLOW(x) symbols + + self.Precedence = {} # Precedence rules for each terminal. Contains tuples of the + # form ('right',level) or ('nonassoc', level) or ('left',level) + + self.UsedPrecedence = set() # Precedence rules that were actually used by the grammer. + # This is only used to provide error checking and to generate + # a warning about unused precedence rules. + + self.Start = None # Starting symbol for the grammar + + + def __len__(self): + return len(self.Productions) + + def __getitem__(self, index): + return self.Productions[index] + + # ----------------------------------------------------------------------------- + # set_precedence() + # + # Sets the precedence for a given terminal. assoc is the associativity such as + # 'left','right', or 'nonassoc'. level is a numeric level. 
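`set_precedence()` (defined just below) is fed from the `precedence` table a `Parser` subclass declares; each row becomes one level, numbered from 1 upward, so later rows bind more tightly. A sketch with hypothetical token names:

```python
class CalcParser(Parser):
    tokens = CalcLexer.tokens

    precedence = (
        ('left', PLUS, MINUS),       # level 1: binds loosest
        ('left', TIMES, DIVIDE),     # level 2
        ('right', UMINUS),           # level 3: binds tightest
    )
```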
+ # + # ----------------------------------------------------------------------------- + + def set_precedence(self, term, assoc, level): + assert self.Productions == [None], 'Must call set_precedence() before add_production()' + if term in self.Precedence: + raise GrammarError(f'Precedence already specified for terminal {term!r}') + if assoc not in ['left', 'right', 'nonassoc']: + raise GrammarError(f"Associativity of {term!r} must be one of 'left','right', or 'nonassoc'") + self.Precedence[term] = (assoc, level) + + # ----------------------------------------------------------------------------- + # add_production() + # + # Given an action function, this function assembles a production rule and + # computes its precedence level. + # + # The production rule is supplied as a list of symbols. For example, + # a rule such as 'expr : expr PLUS term' has a production name of 'expr' and + # symbols ['expr','PLUS','term']. + # + # Precedence is determined by the precedence of the right-most non-terminal + # or the precedence of a terminal specified by %prec. + # + # A variety of error checks are performed to make sure production symbols + # are valid and that %prec is used correctly. + # ----------------------------------------------------------------------------- + + def add_production(self, prodname, syms, func=None, file='', line=0): + + if prodname in self.Terminals: + raise GrammarError(f'{file}:{line}: Illegal rule name {prodname!r}. Already defined as a token') + if prodname == 'error': + raise GrammarError(f'{file}:{line}: Illegal rule name {prodname!r}. error is a reserved word') + + # Look for literal tokens + for n, s in enumerate(syms): + if s[0] in "'\"" and s[0] == s[-1]: + c = s[1:-1] + if (len(c) != 1): + raise GrammarError(f'{file}:{line}: Literal token {s} in rule {prodname!r} may only be a single character') + if c not in self.Terminals: + self.Terminals[c] = [] + syms[n] = c + continue + + # Determine the precedence level + if '%prec' in syms: + if syms[-1] == '%prec': + raise GrammarError(f'{file}:{line}: Syntax error. Nothing follows %%prec') + if syms[-2] != '%prec': + raise GrammarError(f'{file}:{line}: Syntax error. %prec can only appear at the end of a grammar rule') + precname = syms[-1] + prodprec = self.Precedence.get(precname) + if not prodprec: + raise GrammarError(f'{file}:{line}: Nothing known about the precedence of {precname!r}') + else: + self.UsedPrecedence.add(precname) + del syms[-2:] # Drop %prec from the rule + else: + # If no %prec, precedence is determined by the rightmost terminal symbol + precname = rightmost_terminal(syms, self.Terminals) + prodprec = self.Precedence.get(precname, ('right', 0)) + + # See if the rule is already in the rulemap + map = '%s -> %s' % (prodname, syms) + if map in self.Prodmap: + m = self.Prodmap[map] + raise GrammarError(f'{file}:{line}: Duplicate rule {m}. ' + + f'Previous definition at {m.file}:{m.line}') + + # From this point on, everything is valid. 
Create a new Production instance + pnumber = len(self.Productions) + if prodname not in self.Nonterminals: + self.Nonterminals[prodname] = [] + + # Add the production number to Terminals and Nonterminals + for t in syms: + if t in self.Terminals: + self.Terminals[t].append(pnumber) + else: + if t not in self.Nonterminals: + self.Nonterminals[t] = [] + self.Nonterminals[t].append(pnumber) + + # Create a production and add it to the list of productions + p = Production(pnumber, prodname, syms, prodprec, func, file, line) + self.Productions.append(p) + self.Prodmap[map] = p + + # Add to the global productions list + try: + self.Prodnames[prodname].append(p) + except KeyError: + self.Prodnames[prodname] = [p] + + # ----------------------------------------------------------------------------- + # set_start() + # + # Sets the starting symbol and creates the augmented grammar. Production + # rule 0 is S' -> start where start is the start symbol. + # ----------------------------------------------------------------------------- + + def set_start(self, start=None): + if callable(start): + start = start.__name__ + + if not start: + start = self.Productions[1].name + + if start not in self.Nonterminals: + raise GrammarError(f'start symbol {start} undefined') + self.Productions[0] = Production(0, "S'", [start]) + self.Nonterminals[start].append(0) + self.Start = start + + # ----------------------------------------------------------------------------- + # find_unreachable() + # + # Find all of the nonterminal symbols that can't be reached from the starting + # symbol. Returns a list of nonterminals that can't be reached. + # ----------------------------------------------------------------------------- + + def find_unreachable(self): + + # Mark all symbols that are reachable from a symbol s + def mark_reachable_from(s): + if s in reachable: + return + reachable.add(s) + for p in self.Prodnames.get(s, []): + for r in p.prod: + mark_reachable_from(r) + + reachable = set() + mark_reachable_from(self.Productions[0].prod[0]) + return [s for s in self.Nonterminals if s not in reachable] + + # ----------------------------------------------------------------------------- + # infinite_cycles() + # + # This function looks at the various parsing rules and tries to detect + # infinite recursion cycles (grammar rules where there is no possible way + # to derive a string of only terminals). + # ----------------------------------------------------------------------------- + + def infinite_cycles(self): + terminates = {} + + # Terminals: + for t in self.Terminals: + terminates[t] = True + + terminates['$end'] = True + + # Nonterminals: + + # Initialize to false: + for n in self.Nonterminals: + terminates[n] = False + + # Then propagate termination until no change: + while True: + some_change = False + for (n, pl) in self.Prodnames.items(): + # Nonterminal n terminates iff any of its productions terminates. + for p in pl: + # Production p terminates iff all of its rhs symbols terminate. + for s in p.prod: + if not terminates[s]: + # The symbol s does not terminate, + # so production p does not terminate. + p_terminates = False + break + else: + # didn't break from the loop, + # so every symbol s terminates + # so production p terminates. + p_terminates = True + + if p_terminates: + # symbol n terminates! + if not terminates[n]: + terminates[n] = True + some_change = True + # Don't need to consider any more productions for this n. 
+                        break
+
+            if not some_change:
+                break
+
+        infinite = []
+        for (s, term) in terminates.items():
+            if not term:
+                if s not in self.Prodnames and s not in self.Terminals and s != 'error':
+                    # s is used-but-not-defined, and we've already warned of that,
+                    # so it would be overkill to say that it's also non-terminating.
+                    pass
+                else:
+                    infinite.append(s)
+
+        return infinite
+
+    # -----------------------------------------------------------------------------
+    # undefined_symbols()
+    #
+    # Find all symbols that were used in the grammar, but not defined as tokens or
+    # grammar rules.  Returns a list of tuples (sym, prod) where sym is the symbol
+    # and prod is the production where the symbol was used.
+    # -----------------------------------------------------------------------------
+    def undefined_symbols(self):
+        result = []
+        for p in self.Productions:
+            if not p:
+                continue
+
+            for s in p.prod:
+                if s not in self.Prodnames and s not in self.Terminals and s != 'error':
+                    result.append((s, p))
+        return result
+
+    # -----------------------------------------------------------------------------
+    # unused_terminals()
+    #
+    # Find all terminals that were defined, but not used by the grammar.  Returns
+    # a list of the unused terminal names.
+    # -----------------------------------------------------------------------------
+    def unused_terminals(self):
+        unused_tok = []
+        for s, v in self.Terminals.items():
+            if s != 'error' and not v:
+                unused_tok.append(s)
+
+        return unused_tok
+
+    # ------------------------------------------------------------------------------
+    # unused_rules()
+    #
+    # Find all grammar rules that were defined, but not used (possibly unreachable).
+    # Returns a list of productions.
+    # ------------------------------------------------------------------------------
+
+    def unused_rules(self):
+        unused_prod = []
+        for s, v in self.Nonterminals.items():
+            if not v:
+                p = self.Prodnames[s][0]
+                unused_prod.append(p)
+        return unused_prod
+
+    # -----------------------------------------------------------------------------
+    # unused_precedence()
+    #
+    # Returns a list of tuples (term, precedence) corresponding to precedence
+    # rules that were never used by the grammar.  term is the name of the terminal
+    # on which precedence was applied and precedence is a string such as 'left' or
+    # 'right' corresponding to the type of precedence.
+    # -----------------------------------------------------------------------------
+
+    def unused_precedence(self):
+        unused = []
+        for termname in self.Precedence:
+            if not (termname in self.Terminals or termname in self.UsedPrecedence):
+                unused.append((termname, self.Precedence[termname][0]))
+
+        return unused
+
+    # -------------------------------------------------------------------------
+    # _first()
+    #
+    # Compute the value of FIRST1(beta) where beta is a tuple of symbols.
+    #
+    # During execution of compute_first(), the result may be incomplete.
+    # Afterward (e.g., when called from compute_follow()), it will be complete.
+    # -------------------------------------------------------------------------
+    def _first(self, beta):
+
+        # We are computing First(x1,x2,x3,...,xn)
+        result = []
+        for x in beta:
+            x_produces_empty = False
+
+            # Add all the non-empty symbols of First[x] to the result.
+            for f in self.First[x]:
+                if f == '':
+                    x_produces_empty = True
+                else:
+                    if f not in result:
+                        result.append(f)
+
+            if x_produces_empty:
+                # We have to consider the next x in beta,
+                # i.e. stay in the loop.
+                pass
+            else:
+                # We don't have to consider any further symbols in beta.
+ break + else: + # There was no 'break' from the loop, + # so x_produces_empty was true for all x in beta, + # so beta produces empty as well. + result.append('') + + return result + + # ------------------------------------------------------------------------- + # compute_first() + # + # Compute the value of FIRST1(X) for all symbols + # ------------------------------------------------------------------------- + def compute_first(self): + if self.First: + return self.First + + # Terminals: + for t in self.Terminals: + self.First[t] = [t] + + self.First['$end'] = ['$end'] + + # Nonterminals: + + # Initialize to the empty set: + for n in self.Nonterminals: + self.First[n] = [] + + # Then propagate symbols until no change: + while True: + some_change = False + for n in self.Nonterminals: + for p in self.Prodnames[n]: + for f in self._first(p.prod): + if f not in self.First[n]: + self.First[n].append(f) + some_change = True + if not some_change: + break + + return self.First + + # --------------------------------------------------------------------- + # compute_follow() + # + # Computes all of the follow sets for every non-terminal symbol. The + # follow set is the set of all symbols that might follow a given + # non-terminal. See the Dragon book, 2nd Ed. p. 189. + # --------------------------------------------------------------------- + def compute_follow(self, start=None): + # If already computed, return the result + if self.Follow: + return self.Follow + + # If first sets not computed yet, do that first. + if not self.First: + self.compute_first() + + # Add '$end' to the follow list of the start symbol + for k in self.Nonterminals: + self.Follow[k] = [] + + if not start: + start = self.Productions[1].name + + self.Follow[start] = ['$end'] + + while True: + didadd = False + for p in self.Productions[1:]: + # Here is the production set + for i, B in enumerate(p.prod): + if B in self.Nonterminals: + # Okay. We got a non-terminal in a production + fst = self._first(p.prod[i+1:]) + hasempty = False + for f in fst: + if f != '' and f not in self.Follow[B]: + self.Follow[B].append(f) + didadd = True + if f == '': + hasempty = True + if hasempty or i == (len(p.prod)-1): + # Add elements of follow(a) to follow(b) + for f in self.Follow[p.name]: + if f not in self.Follow[B]: + self.Follow[B].append(f) + didadd = True + if not didadd: + break + return self.Follow + + + # ----------------------------------------------------------------------------- + # build_lritems() + # + # This function walks the list of productions and builds a complete set of the + # LR items. The LR items are stored in two ways: First, they are uniquely + # numbered and placed in the list _lritems. Second, a linked list of LR items + # is built for each production. For example: + # + # E -> E PLUS E + # + # Creates the list + # + # [E -> . E PLUS E, E -> E . PLUS E, E -> E PLUS . E, E -> E PLUS E . 
] + # ----------------------------------------------------------------------------- + + def build_lritems(self): + for p in self.Productions: + lastlri = p + i = 0 + lr_items = [] + while True: + if i > len(p): + lri = None + else: + lri = LRItem(p, i) + # Precompute the list of productions immediately following + try: + lri.lr_after = self.Prodnames[lri.prod[i+1]] + except (IndexError, KeyError): + lri.lr_after = [] + try: + lri.lr_before = lri.prod[i-1] + except IndexError: + lri.lr_before = None + + lastlri.lr_next = lri + if not lri: + break + lr_items.append(lri) + lastlri = lri + i += 1 + p.lr_items = lr_items + + + # ---------------------------------------------------------------------- + # Debugging output. Printing the grammar will produce a detailed + # description along with some diagnostics. + # ---------------------------------------------------------------------- + def __str__(self): + out = [] + out.append('Grammar:\n') + for n, p in enumerate(self.Productions): + out.append(f'Rule {n:<5d} {p}') + + unused_terminals = self.unused_terminals() + if unused_terminals: + out.append('\nUnused terminals:\n') + for term in unused_terminals: + out.append(f' {term}') + + out.append('\nTerminals, with rules where they appear:\n') + for term in sorted(self.Terminals): + out.append('%-20s : %s' % (term, ' '.join(str(s) for s in self.Terminals[term]))) + + out.append('\nNonterminals, with rules where they appear:\n') + for nonterm in sorted(self.Nonterminals): + out.append('%-20s : %s' % (nonterm, ' '.join(str(s) for s in self.Nonterminals[nonterm]))) + + out.append('') + return '\n'.join(out) + +# ----------------------------------------------------------------------------- +# === LR Generator === +# +# The following classes and functions are used to generate LR parsing tables on +# a grammar. +# ----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- +# digraph() +# traverse() +# +# The following two functions are used to compute set valued functions +# of the form: +# +# F(x) = F'(x) U U{F(y) | x R y} +# +# This is used to compute the values of Read() sets as well as FOLLOW sets +# in LALR(1) generation. +# +# Inputs: X - An input set +# R - A relation +# FP - Set-valued function +# ------------------------------------------------------------------------------ + +def digraph(X, R, FP): + N = {} + for x in X: + N[x] = 0 + stack = [] + F = {} + for x in X: + if N[x] == 0: + traverse(x, N, stack, F, X, R, FP) + return F + +def traverse(x, N, stack, F, X, R, FP): + stack.append(x) + d = len(stack) + N[x] = d + F[x] = FP(x) # F(X) <- F'(x) + + rel = R(x) # Get y's related to x + for y in rel: + if N[y] == 0: + traverse(y, N, stack, F, X, R, FP) + N[x] = min(N[x], N[y]) + for a in F.get(y, []): + if a not in F[x]: + F[x].append(a) + if N[x] == d: + N[stack[-1]] = MAXINT + F[stack[-1]] = F[x] + element = stack.pop() + while element != x: + N[stack[-1]] = MAXINT + F[stack[-1]] = F[x] + element = stack.pop() + +class LALRError(YaccError): + pass + +# ----------------------------------------------------------------------------- +# == LRGeneratedTable == +# +# This class implements the LR table generation algorithm. 
There are no
+# public methods.
+# -----------------------------------------------------------------------------
+
+class LRTable(object):
+    def __init__(self, grammar):
+        self.grammar = grammar
+
+        # Internal attributes
+        self.lr_action = {}                  # Action table
+        self.lr_goto = {}                    # Goto table
+        self.lr_productions = grammar.Productions   # Copy of grammar Production array
+        self.lr_goto_cache = {}              # Cache of computed gotos
+        self.lr0_cidhash = {}                # Cache of closures
+        self._add_count = 0                  # Internal counter used to detect cycles
+
+        # Diagnostic information filled in by the table generator
+        self.state_descriptions = OrderedDict()
+        self.sr_conflict = 0
+        self.rr_conflict = 0
+        self.conflicts = []                  # List of conflicts
+
+        self.sr_conflicts = []
+        self.rr_conflicts = []
+
+        # Build the tables
+        self.grammar.build_lritems()
+        self.grammar.compute_first()
+        self.grammar.compute_follow()
+        self.lr_parse_table()
+
+        # Build default states
+        # This identifies parser states where there is only one possible reduction action.
+        # For such states, the parser can choose to make a rule reduction without consuming
+        # the next look-ahead token.  This delayed invocation of the tokenizer can be useful in
+        # certain kinds of advanced parsing situations where the lexer and parser interact with
+        # each other or change states (i.e., manipulation of scope, lexer states, etc.).
+        #
+        # See:  http://www.gnu.org/software/bison/manual/html_node/Default-Reductions.html#Default-Reductions
+        self.defaulted_states = {}
+        for state, actions in self.lr_action.items():
+            rules = list(actions.values())
+            if len(rules) == 1 and rules[0] < 0:
+                self.defaulted_states[state] = rules[0]
+
+    # Compute the LR(0) closure operation on I, where I is a set of LR(0) items.
+    def lr0_closure(self, I):
+        self._add_count += 1
+
+        # Add everything in I to J
+        J = I[:]
+        didadd = True
+        while didadd:
+            didadd = False
+            for j in J:
+                for x in j.lr_after:
+                    if getattr(x, 'lr0_added', 0) == self._add_count:
+                        continue
+                    # Add B --> .G to J
+                    J.append(x.lr_next)
+                    x.lr0_added = self._add_count
+                    didadd = True
+
+        return J
+
+    # Compute the LR(0) goto function goto(I,X) where I is a set
+    # of LR(0) items and X is a grammar symbol.  This function is written
+    # in a way that guarantees uniqueness of the generated goto sets
+    # (i.e. the same goto set will never be returned as two different Python
+    # objects).  With uniqueness, we can later do fast set comparisons using
+    # id(obj) instead of element-wise comparison.
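Because `LRTable` does all of its work in `__init__`, the generator can be exercised directly on a hand-assembled `Grammar`; this is essentially what `Parser._build` does internally, so treat it as an internals sketch rather than normal usage:

```python
g = Grammar(['NUM', 'PLUS'])
g.add_production('expr', ['expr', 'PLUS', 'term'])
g.add_production('expr', ['term'])
g.add_production('term', ['NUM'])
g.set_start('expr')

table = LRTable(g)         # runs build_lritems/compute_first/compute_follow/lr_parse_table
print(table.lr_action[0])  # ACTION entries for state 0
print(table)               # state-by-state description, including any conflicts
```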
+ + def lr0_goto(self, I, x): + # First we look for a previously cached entry + g = self.lr_goto_cache.get((id(I), x)) + if g: + return g + + # Now we generate the goto set in a way that guarantees uniqueness + # of the result + + s = self.lr_goto_cache.get(x) + if not s: + s = {} + self.lr_goto_cache[x] = s + + gs = [] + for p in I: + n = p.lr_next + if n and n.lr_before == x: + s1 = s.get(id(n)) + if not s1: + s1 = {} + s[id(n)] = s1 + gs.append(n) + s = s1 + g = s.get('$end') + if not g: + if gs: + g = self.lr0_closure(gs) + s['$end'] = g + else: + s['$end'] = gs + self.lr_goto_cache[(id(I), x)] = g + return g + + # Compute the LR(0) sets of item function + def lr0_items(self): + C = [self.lr0_closure([self.grammar.Productions[0].lr_next])] + i = 0 + for I in C: + self.lr0_cidhash[id(I)] = i + i += 1 + + # Loop over the items in C and each grammar symbols + i = 0 + while i < len(C): + I = C[i] + i += 1 + + # Collect all of the symbols that could possibly be in the goto(I,X) sets + asyms = {} + for ii in I: + for s in ii.usyms: + asyms[s] = None + + for x in asyms: + g = self.lr0_goto(I, x) + if not g or id(g) in self.lr0_cidhash: + continue + self.lr0_cidhash[id(g)] = len(C) + C.append(g) + + return C + + # ----------------------------------------------------------------------------- + # ==== LALR(1) Parsing ==== + # + # LALR(1) parsing is almost exactly the same as SLR except that instead of + # relying upon Follow() sets when performing reductions, a more selective + # lookahead set that incorporates the state of the LR(0) machine is utilized. + # Thus, we mainly just have to focus on calculating the lookahead sets. + # + # The method used here is due to DeRemer and Pennelo (1982). + # + # DeRemer, F. L., and T. J. Pennelo: "Efficient Computation of LALR(1) + # Lookahead Sets", ACM Transactions on Programming Languages and Systems, + # Vol. 4, No. 4, Oct. 1982, pp. 615-649 + # + # Further details can also be found in: + # + # J. Tremblay and P. Sorenson, "The Theory and Practice of Compiler Writing", + # McGraw-Hill Book Company, (1985). + # + # ----------------------------------------------------------------------------- + + # ----------------------------------------------------------------------------- + # compute_nullable_nonterminals() + # + # Creates a dictionary containing all of the non-terminals that might produce + # an empty production. + # ----------------------------------------------------------------------------- + + def compute_nullable_nonterminals(self): + nullable = set() + num_nullable = 0 + while True: + for p in self.grammar.Productions[1:]: + if p.len == 0: + nullable.add(p.name) + continue + for t in p.prod: + if t not in nullable: + break + else: + nullable.add(p.name) + if len(nullable) == num_nullable: + break + num_nullable = len(nullable) + return nullable + + # ----------------------------------------------------------------------------- + # find_nonterminal_trans(C) + # + # Given a set of LR(0) items, this functions finds all of the non-terminal + # transitions. These are transitions in which a dot appears immediately before + # a non-terminal. Returns a list of tuples of the form (state,N) where state + # is the state number and N is the nonterminal symbol. + # + # The input C is the set of LR(0) items. 
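To see `compute_nullable_nonterminals()` in action, consider a grammar with an empty production (the usual optional-list shape), driven directly as in the previous sketch:

```python
g = Grammar(['NAME', 'COMMA'])
g.add_production('args', [])                          # args -> (empty)
g.add_production('args', ['args', 'COMMA', 'NAME'])
g.set_start('args')

t = LRTable(g)
print(t.compute_nullable_nonterminals())              # expected: {'args'}
```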
+ # ----------------------------------------------------------------------------- + + def find_nonterminal_transitions(self, C): + trans = [] + for stateno, state in enumerate(C): + for p in state: + if p.lr_index < p.len - 1: + t = (stateno, p.prod[p.lr_index+1]) + if t[1] in self.grammar.Nonterminals: + if t not in trans: + trans.append(t) + return trans + + # ----------------------------------------------------------------------------- + # dr_relation() + # + # Computes the DR(p,A) relationships for non-terminal transitions. The input + # is a tuple (state,N) where state is a number and N is a nonterminal symbol. + # + # Returns a list of terminals. + # ----------------------------------------------------------------------------- + + def dr_relation(self, C, trans, nullable): + dr_set = {} + state, N = trans + terms = [] + + g = self.lr0_goto(C[state], N) + for p in g: + if p.lr_index < p.len - 1: + a = p.prod[p.lr_index+1] + if a in self.grammar.Terminals: + if a not in terms: + terms.append(a) + + # This extra bit is to handle the start state + if state == 0 and N == self.grammar.Productions[0].prod[0]: + terms.append('$end') + + return terms + + # ----------------------------------------------------------------------------- + # reads_relation() + # + # Computes the READS() relation (p,A) READS (t,C). + # ----------------------------------------------------------------------------- + + def reads_relation(self, C, trans, empty): + # Look for empty transitions + rel = [] + state, N = trans + + g = self.lr0_goto(C[state], N) + j = self.lr0_cidhash.get(id(g), -1) + for p in g: + if p.lr_index < p.len - 1: + a = p.prod[p.lr_index + 1] + if a in empty: + rel.append((j, a)) + + return rel + + # ----------------------------------------------------------------------------- + # compute_lookback_includes() + # + # Determines the lookback and includes relations + # + # LOOKBACK: + # + # This relation is determined by running the LR(0) state machine forward. + # For example, starting with a production "N : . A B C", we run it forward + # to obtain "N : A B C ." We then build a relationship between this final + # state and the starting state. These relationships are stored in a dictionary + # lookdict. + # + # INCLUDES: + # + # Computes the INCLUDE() relation (p,A) INCLUDES (p',B). + # + # This relation is used to determine non-terminal transitions that occur + # inside of other non-terminal transition states. (p,A) INCLUDES (p', B) + # if the following holds: + # + # B -> LAT, where T -> epsilon and p' -L-> p + # + # L is essentially a prefix (which may be empty), T is a suffix that must be + # able to derive an empty string. State p' must lead to state p with the string L. + # + # ----------------------------------------------------------------------------- + + def compute_lookback_includes(self, C, trans, nullable): + lookdict = {} # Dictionary of lookback relations + includedict = {} # Dictionary of include relations + + # Make a dictionary of non-terminal transitions + dtrans = {} + for t in trans: + dtrans[t] = 1 + + # Loop over all transitions and compute lookbacks and includes + for state, N in trans: + lookb = [] + includes = [] + for p in C[state]: + if p.name != N: + continue + + # Okay, we have a name match. We now follow the production all the way + # through the state machine until we get the . 
on the right hand side + + lr_index = p.lr_index + j = state + while lr_index < p.len - 1: + lr_index = lr_index + 1 + t = p.prod[lr_index] + + # Check to see if this symbol and state are a non-terminal transition + if (j, t) in dtrans: + # Yes. Okay, there is some chance that this is an includes relation + # the only way to know for certain is whether the rest of the + # production derives empty + + li = lr_index + 1 + while li < p.len: + if p.prod[li] in self.grammar.Terminals: + break # No forget it + if p.prod[li] not in nullable: + break + li = li + 1 + else: + # Appears to be a relation between (j,t) and (state,N) + includes.append((j, t)) + + g = self.lr0_goto(C[j], t) # Go to next set + j = self.lr0_cidhash.get(id(g), -1) # Go to next state + + # When we get here, j is the final state, now we have to locate the production + for r in C[j]: + if r.name != p.name: + continue + if r.len != p.len: + continue + i = 0 + # This look is comparing a production ". A B C" with "A B C ." + while i < r.lr_index: + if r.prod[i] != p.prod[i+1]: + break + i = i + 1 + else: + lookb.append((j, r)) + for i in includes: + if i not in includedict: + includedict[i] = [] + includedict[i].append((state, N)) + lookdict[(state, N)] = lookb + + return lookdict, includedict + + # ----------------------------------------------------------------------------- + # compute_read_sets() + # + # Given a set of LR(0) items, this function computes the read sets. + # + # Inputs: C = Set of LR(0) items + # ntrans = Set of nonterminal transitions + # nullable = Set of empty transitions + # + # Returns a set containing the read sets + # ----------------------------------------------------------------------------- + + def compute_read_sets(self, C, ntrans, nullable): + FP = lambda x: self.dr_relation(C, x, nullable) + R = lambda x: self.reads_relation(C, x, nullable) + F = digraph(ntrans, R, FP) + return F + + # ----------------------------------------------------------------------------- + # compute_follow_sets() + # + # Given a set of LR(0) items, a set of non-terminal transitions, a readset, + # and an include set, this function computes the follow sets + # + # Follow(p,A) = Read(p,A) U U {Follow(p',B) | (p,A) INCLUDES (p',B)} + # + # Inputs: + # ntrans = Set of nonterminal transitions + # readsets = Readset (previously computed) + # inclsets = Include sets (previously computed) + # + # Returns a set containing the follow sets + # ----------------------------------------------------------------------------- + + def compute_follow_sets(self, ntrans, readsets, inclsets): + FP = lambda x: readsets[x] + R = lambda x: inclsets.get(x, []) + F = digraph(ntrans, R, FP) + return F + + # ----------------------------------------------------------------------------- + # add_lookaheads() + # + # Attaches the lookahead symbols to grammar rules. 
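Both `compute_read_sets()` and `compute_follow_sets()` above reduce to the `digraph()` traversal defined earlier, which computes F(x) = FP(x) ∪ ⋃{ F(y) | x R y }. A self-contained toy run on a made-up relation (not parser data) shows the transitive propagation:

```python
X  = ['a', 'b', 'c']
R  = lambda x: {'a': ['b'], 'b': ['c'], 'c': []}[x]   # the x R y edges
FP = lambda x: {'a': [1], 'b': [2], 'c': [3]}[x]      # the base sets F'(x)

print(digraph(X, R, FP))   # {'a': [1, 2, 3], 'b': [2, 3], 'c': [3]}
```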
+ # + # Inputs: lookbacks - Set of lookback relations + # followset - Computed follow set + # + # This function directly attaches the lookaheads to productions contained + # in the lookbacks set + # ----------------------------------------------------------------------------- + + def add_lookaheads(self, lookbacks, followset): + for trans, lb in lookbacks.items(): + # Loop over productions in lookback + for state, p in lb: + if state not in p.lookaheads: + p.lookaheads[state] = set() + + f = followset.get(trans, []) + for a in f: + p.lookaheads[state].add(a) + + # ----------------------------------------------------------------------------- + # add_lalr_lookaheads() + # + # This function does all of the work of adding lookahead information for use + # with LALR parsing + # ----------------------------------------------------------------------------- + + def add_lalr_lookaheads(self, C): + # Determine all of the nullable nonterminals + nullable = self.compute_nullable_nonterminals() + + # Find all non-terminal transitions + trans = self.find_nonterminal_transitions(C) + + # Compute read sets + readsets = self.compute_read_sets(C, trans, nullable) + + # Compute lookback/includes relations + lookd, included = self.compute_lookback_includes(C, trans, nullable) + + # Compute LALR FOLLOW sets + followsets = self.compute_follow_sets(trans, readsets, included) + + # Add all of the lookaheads + self.add_lookaheads(lookd, followsets) + + # ----------------------------------------------------------------------------- + # lr_parse_table() + # + # This function constructs the final LALR parse table. Touch this code and die. + # ----------------------------------------------------------------------------- + def lr_parse_table(self): + Productions = self.grammar.Productions + Precedence = self.grammar.Precedence + goto = self.lr_goto # Goto array + action = self.lr_action # Action array + + actionp = {} # Action production array (temporary) + + # Step 1: Construct C = { I0, I1, ... IN}, collection of LR(0) items + # This determines the number of states + + C = self.lr0_items() + self.add_lalr_lookaheads(C) + + # Build the parser table, state by state + for st, I in enumerate(C): + descrip = [] + # Loop over each production in I + actlist = [] # List of actions + st_action = {} + st_actionp = {} + st_goto = {} + + descrip.append(f'\nstate {st}\n') + for p in I: + descrip.append(f' ({p.number}) {p}') + + for p in I: + if p.len == p.lr_index + 1: + if p.name == "S'": + # Start symbol. Accept! + st_action['$end'] = 0 + st_actionp['$end'] = p + else: + # We are at the end of a production. Reduce! + laheads = p.lookaheads[st] + for a in laheads: + actlist.append((a, p, f'reduce using rule {p.number} ({p})')) + r = st_action.get(a) + if r is not None: + # Have a shift/reduce or reduce/reduce conflict + if r > 0: + # Need to decide on shift or reduce here + # By default we favor shifting. Need to add + # some precedence rules here. + + # Shift precedence comes from the token + sprec, slevel = Precedence.get(a, ('right', 0)) + + # Reduce precedence comes from rule being reduced (p) + rprec, rlevel = Productions[p.number].prec + + if (slevel < rlevel) or ((slevel == rlevel) and (rprec == 'left')): + # We really need to reduce here. + st_action[a] = -p.number + st_actionp[a] = p + if not slevel and not rlevel: + descrip.append(f' ! 
shift/reduce conflict for {a} resolved as reduce') + self.sr_conflicts.append((st, a, 'reduce')) + Productions[p.number].reduced += 1 + elif (slevel == rlevel) and (rprec == 'nonassoc'): + st_action[a] = None + else: + # Hmmm. Guess we'll keep the shift + if not rlevel: + descrip.append(f' ! shift/reduce conflict for {a} resolved as shift') + self.sr_conflicts.append((st, a, 'shift')) + elif r <= 0: + # Reduce/reduce conflict. In this case, we favor the rule + # that was defined first in the grammar file + oldp = Productions[-r] + pp = Productions[p.number] + if oldp.line > pp.line: + st_action[a] = -p.number + st_actionp[a] = p + chosenp, rejectp = pp, oldp + Productions[p.number].reduced += 1 + Productions[oldp.number].reduced -= 1 + else: + chosenp, rejectp = oldp, pp + self.rr_conflicts.append((st, chosenp, rejectp)) + descrip.append(' ! reduce/reduce conflict for %s resolved using rule %d (%s)' % + (a, st_actionp[a].number, st_actionp[a])) + else: + raise LALRError(f'Unknown conflict in state {st}') + else: + st_action[a] = -p.number + st_actionp[a] = p + Productions[p.number].reduced += 1 + else: + i = p.lr_index + a = p.prod[i+1] # Get symbol right after the "." + if a in self.grammar.Terminals: + g = self.lr0_goto(I, a) + j = self.lr0_cidhash.get(id(g), -1) + if j >= 0: + # We are in a shift state + actlist.append((a, p, f'shift and go to state {j}')) + r = st_action.get(a) + if r is not None: + # Whoa have a shift/reduce or shift/shift conflict + if r > 0: + if r != j: + raise LALRError(f'Shift/shift conflict in state {st}') + elif r <= 0: + # Do a precedence check. + # - if precedence of reduce rule is higher, we reduce. + # - if precedence of reduce is same and left assoc, we reduce. + # - otherwise we shift + rprec, rlevel = Productions[st_actionp[a].number].prec + sprec, slevel = Precedence.get(a, ('right', 0)) + if (slevel > rlevel) or ((slevel == rlevel) and (rprec == 'right')): + # We decide to shift here... highest precedence to shift + Productions[st_actionp[a].number].reduced -= 1 + st_action[a] = j + st_actionp[a] = p + if not rlevel: + descrip.append(f' ! shift/reduce conflict for {a} resolved as shift') + self.sr_conflicts.append((st, a, 'shift')) + elif (slevel == rlevel) and (rprec == 'nonassoc'): + st_action[a] = None + else: + # Hmmm. Guess we'll keep the reduce + if not slevel and not rlevel: + descrip.append(f' ! shift/reduce conflict for {a} resolved as reduce') + self.sr_conflicts.append((st, a, 'reduce')) + + else: + raise LALRError(f'Unknown conflict in state {st}') + else: + st_action[a] = j + st_actionp[a] = p + + # Print the actions associated with each terminal + _actprint = {} + for a, p, m in actlist: + if a in st_action: + if p is st_actionp[a]: + descrip.append(f' {a:<15s} {m}') + _actprint[(a, m)] = 1 + descrip.append('') + + # Construct the goto table for this state + nkeys = {} + for ii in I: + for s in ii.usyms: + if s in self.grammar.Nonterminals: + nkeys[s] = None + for n in nkeys: + g = self.lr0_goto(I, n) + j = self.lr0_cidhash.get(id(g), -1) + if j >= 0: + st_goto[n] = j + descrip.append(f' {n:<30s} shift and go to state {j}') + + action[st] = st_action + actionp[st] = st_actionp + goto[st] = st_goto + self.state_descriptions[st] = '\n'.join(descrip) + + # ---------------------------------------------------------------------- + # Debugging output. Printing the LRTable object will produce a listing + # of all of the states, conflicts, and other details. 
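The resolution policy implemented above (favor shift on an equal-precedence shift/reduce conflict, favor the earlier rule on reduce/reduce) is easiest to see with a deliberately ambiguous grammar. Building the hypothetical parser below records one shift/reduce conflict in `sr_conflicts`, resolved as shift, which makes `PLUS` effectively right-associative; a `precedence` declaration like the earlier sketch would silence it:

```python
class AmbiguousParser(Parser):
    tokens = { NUM, PLUS }   # hypothetical token names

    @_('expr PLUS expr')     # ambiguous: (a+b)+c vs. a+(b+c)
    def expr(self, p):
        return p.expr0 + p.expr1   # duplicated names become expr0/expr1

    @_('NUM')
    def expr(self, p):
        return p.NUM
```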
+ # ---------------------------------------------------------------------- + def __str__(self): + out = [] + for descrip in self.state_descriptions.values(): + out.append(descrip) + + if self.sr_conflicts or self.rr_conflicts: + out.append('\nConflicts:\n') + + for state, tok, resolution in self.sr_conflicts: + out.append(f'shift/reduce conflict for {tok} in state {state} resolved as {resolution}') + + already_reported = set() + for state, rule, rejected in self.rr_conflicts: + if (state, id(rule), id(rejected)) in already_reported: + continue + out.append(f'reduce/reduce conflict in state {state} resolved using rule {rule}') + out.append(f'rejected rule ({rejected}) in state {state}') + already_reported.add((state, id(rule), id(rejected))) + + warned_never = set() + for state, rule, rejected in self.rr_conflicts: + if not rejected.reduced and (rejected not in warned_never): + out.append(f'Rule ({rejected}) is never reduced') + warned_never.add(rejected) + + return '\n'.join(out) + +# Collect grammar rules from a function +def _collect_grammar_rules(func): + grammar = [] + while func: + prodname = func.__name__ + unwrapped = inspect.unwrap(func) + filename = unwrapped.__code__.co_filename + lineno = unwrapped.__code__.co_firstlineno + for rule, lineno in zip(func.rules, range(lineno+len(func.rules)-1, 0, -1)): + syms = rule.split() + ebnf_prod = [] + while ('{' in syms) or ('[' in syms): + for s in syms: + if s == '[': + syms, prod = _replace_ebnf_optional(syms) + ebnf_prod.extend(prod) + break + elif s == '{': + syms, prod = _replace_ebnf_repeat(syms) + ebnf_prod.extend(prod) + break + elif '|' in s: + syms, prod = _replace_ebnf_choice(syms) + ebnf_prod.extend(prod) + break + + if syms[1:2] == [':'] or syms[1:2] == ['::=']: + grammar.append((func, filename, lineno, syms[0], syms[2:])) + else: + grammar.append((func, filename, lineno, prodname, syms)) + grammar.extend(ebnf_prod) + + func = getattr(func, 'next_func', None) + + return grammar + +# Replace EBNF repetition +def _replace_ebnf_repeat(syms): + syms = list(syms) + first = syms.index('{') + end = syms.index('}', first) + + # Look for choices inside + repeated_syms = syms[first+1:end] + if any('|' in sym for sym in repeated_syms): + repeated_syms, prods = _replace_ebnf_choice(repeated_syms) + else: + prods = [] + + symname, moreprods = _generate_repeat_rules(repeated_syms) + syms[first:end+1] = [symname] + return syms, prods + moreprods + +def _replace_ebnf_optional(syms): + syms = list(syms) + first = syms.index('[') + end = syms.index(']', first) + symname, prods = _generate_optional_rules(syms[first+1:end]) + syms[first:end+1] = [symname] + return syms, prods + +def _replace_ebnf_choice(syms): + syms = list(syms) + newprods = [ ] + n = 0 + while n < len(syms): + if '|' in syms[n]: + symname, prods = _generate_choice_rules(syms[n].split('|')) + syms[n] = symname + newprods.extend(prods) + n += 1 + return syms, newprods + +# Generate grammar rules for repeated items +_gencount = 0 + +# Dictionary mapping name aliases generated by EBNF rules. + +_name_aliases = { } + +def _sanitize_symbols(symbols): + for sym in symbols: + if sym.startswith("'"): + yield str(hex(ord(sym[1]))) + elif sym.isidentifier(): + yield sym + else: + yield sym.encode('utf-8').hex() + +def _generate_repeat_rules(symbols): + ''' + Symbols is a list of grammar symbols [ symbols ]. 
This + generates code corresponding to these grammar construction: + + @('repeat : many') + def repeat(self, p): + return p.many + + @('repeat :') + def repeat(self, p): + return [] + + @('many : many symbols') + def many(self, p): + p.many.append(symbols) + return p.many + + @('many : symbols') + def many(self, p): + return [ p.symbols ] + ''' + global _gencount + _gencount += 1 + basename = f'_{_gencount}_' + '_'.join(_sanitize_symbols(symbols)) + name = f'{basename}_repeat' + oname = f'{basename}_items' + iname = f'{basename}_item' + symtext = ' '.join(symbols) + + _name_aliases[name] = symbols + + productions = [ ] + _ = _decorator + + @_(f'{name} : {oname}') + def repeat(self, p): + return getattr(p, oname) + + @_(f'{name} : ') + def repeat2(self, p): + return [] + productions.extend(_collect_grammar_rules(repeat)) + productions.extend(_collect_grammar_rules(repeat2)) + + @_(f'{oname} : {oname} {iname}') + def many(self, p): + items = getattr(p, oname) + items.append(getattr(p, iname)) + return items + + @_(f'{oname} : {iname}') + def many2(self, p): + return [ getattr(p, iname) ] + + productions.extend(_collect_grammar_rules(many)) + productions.extend(_collect_grammar_rules(many2)) + + @_(f'{iname} : {symtext}') + def item(self, p): + return tuple(p) + + productions.extend(_collect_grammar_rules(item)) + return name, productions + +def _generate_optional_rules(symbols): + ''' + Symbols is a list of grammar symbols [ symbols ]. This + generates code corresponding to these grammar construction: + + @('optional : symbols') + def optional(self, p): + return p.symbols + + @('optional :') + def optional(self, p): + return None + ''' + global _gencount + _gencount += 1 + basename = f'_{_gencount}_' + '_'.join(_sanitize_symbols(symbols)) + name = f'{basename}_optional' + symtext = ' '.join(symbols) + + _name_aliases[name] = symbols + + productions = [ ] + _ = _decorator + + no_values = (None,) * len(symbols) + + @_(f'{name} : {symtext}') + def optional(self, p): + return tuple(p) + + @_(f'{name} : ') + def optional2(self, p): + return no_values + + productions.extend(_collect_grammar_rules(optional)) + productions.extend(_collect_grammar_rules(optional2)) + return name, productions + +def _generate_choice_rules(symbols): + ''' + Symbols is a list of grammar symbols such as [ 'PLUS', 'MINUS' ]. + This generates code corresponding to the following construction: + + @('PLUS', 'MINUS') + def choice(self, p): + return p[0] + ''' + global _gencount + _gencount += 1 + basename = f'_{_gencount}_' + '_'.join(_sanitize_symbols(symbols)) + name = f'{basename}_choice' + + _ = _decorator + productions = [ ] + + + def choice(self, p): + return p[0] + choice.__name__ = name + choice = _(*symbols)(choice) + productions.extend(_collect_grammar_rules(choice)) + return name, productions + +class ParserMetaDict(dict): + ''' + Dictionary that allows decorated grammar rule functions to be overloaded + ''' + def __setitem__(self, key, value): + if key in self and callable(value) and hasattr(value, 'rules'): + value.next_func = self[key] + if not hasattr(value.next_func, 'rules'): + raise GrammarError(f'Redefinition of {key}. 
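A matching sketch for the `[ ... ]` optional form handled by `_generate_optional_rules` above (again against upstream `sly`; the grammar names are illustrative). When the bracketed symbols are absent, the generated empty production returns a tuple of `None` values, which is why the rule body can simply test the named value:

```python
from sly import Lexer, Parser

class DeclLexer(Lexer):
    tokens = { NAME, NUMBER, EQ, SEMI }
    ignore = ' '
    NAME = r'[a-zA-Z_]+'
    NUMBER = r'\d+'
    EQ = r'='
    SEMI = r';'

class DeclParser(Parser):
    tokens = DeclLexer.tokens

    @_('NAME [ EQ NUMBER ] SEMI')
    def decl(self, p):
        # p.NUMBER is None when the "= NUMBER" part was omitted
        return (p.NAME, int(p.NUMBER) if p.NUMBER is not None else None)

lexer, parser = DeclLexer(), DeclParser()
print(parser.parse(lexer.tokenize('x = 3;')))  # ('x', 3)
print(parser.parse(lexer.tokenize('y;')))      # ('y', None)
```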
Perhaps an earlier {key} is missing @_') + super().__setitem__(key, value) + + def __getitem__(self, key): + if key not in self and key.isupper() and key[:1] != '_': + return key.upper() + else: + return super().__getitem__(key) + +def _decorator(rule, *extra): + rules = [rule, *extra] + def decorate(func): + func.rules = [ *getattr(func, 'rules', []), *rules[::-1] ] + return func + return decorate + +class ParserMeta(type): + @classmethod + def __prepare__(meta, *args, **kwargs): + d = ParserMetaDict() + d['_'] = _decorator + return d + + def __new__(meta, clsname, bases, attributes): + del attributes['_'] + cls = super().__new__(meta, clsname, bases, attributes) + cls._build(list(attributes.items())) + return cls + +class Parser(metaclass=ParserMeta): + # Automatic tracking of position information + track_positions = True + + # Logging object where debugging/diagnostic messages are sent + log = SlyLogger(sys.stderr) + + # Debugging filename where parsetab.out data can be written + debugfile = None + + @classmethod + def __validate_tokens(cls): + if not hasattr(cls, 'tokens'): + cls.log.error('No token list is defined') + return False + + if not cls.tokens: + cls.log.error('tokens is empty') + return False + + if 'error' in cls.tokens: + cls.log.error("Illegal token name 'error'. Is a reserved word") + return False + + return True + + @classmethod + def __validate_precedence(cls): + if not hasattr(cls, 'precedence'): + cls.__preclist = [] + return True + + preclist = [] + if not isinstance(cls.precedence, (list, tuple)): + cls.log.error('precedence must be a list or tuple') + return False + + for level, p in enumerate(cls.precedence, start=1): + if not isinstance(p, (list, tuple)): + cls.log.error(f'Bad precedence table entry {p!r}. Must be a list or tuple') + return False + + if len(p) < 2: + cls.log.error(f'Malformed precedence entry {p!r}. 
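The `__setitem__` hook above is what lets a parser class define several rule functions under one name: each redefinition is chained onto the previous one through `next_func`, and `_collect_grammar_rules` later walks that chain. A hedged illustration with upstream `sly`:

```python
from sly import Lexer, Parser

class BitLexer(Lexer):
    tokens = { ZERO, ONE }
    ZERO = r'0'
    ONE = r'1'

class BitParser(Parser):
    tokens = BitLexer.tokens

    @_('bits ZERO')
    def bits(self, p):
        return p.bits * 2

    @_('bits ONE')        # same name: chained via next_func, not shadowed
    def bits(self, p):
        return p.bits * 2 + 1

    @_('ZERO', 'ONE')     # one decorator may also carry several rules
    def bits(self, p):
        return int(p[0])

print(BitParser().parse(BitLexer().tokenize('1011')))  # prints 11
```

The companion `__getitem__` trick resolves undefined upper-case names to their own string, which is why bare token names can appear unquoted in the class body (for example inside `precedence`); the lexer side has an analogous hook.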
Must be (assoc, term, ..., term)') + return False + + if not all(isinstance(term, str) for term in p): + cls.log.error('precedence items must be strings') + return False + + assoc = p[0] + preclist.extend((term, assoc, level) for term in p[1:]) + + cls.__preclist = preclist + return True + + @classmethod + def __validate_specification(cls): + ''' + Validate various parts of the grammar specification + ''' + if not cls.__validate_tokens(): + return False + if not cls.__validate_precedence(): + return False + return True + + @classmethod + def __build_grammar(cls, rules): + ''' + Build the grammar from the grammar rules + ''' + grammar_rules = [] + errors = '' + # Check for non-empty symbols + if not rules: + raise YaccError('No grammar rules are defined') + + grammar = Grammar(cls.tokens) + + # Set the precedence level for terminals + for term, assoc, level in cls.__preclist: + try: + grammar.set_precedence(term, assoc, level) + except GrammarError as e: + errors += f'{e}\n' + + for name, func in rules: + try: + parsed_rule = _collect_grammar_rules(func) + for pfunc, rulefile, ruleline, prodname, syms in parsed_rule: + try: + grammar.add_production(prodname, syms, pfunc, rulefile, ruleline) + except GrammarError as e: + errors += f'{e}\n' + except SyntaxError as e: + errors += f'{e}\n' + try: + grammar.set_start(getattr(cls, 'start', None)) + except GrammarError as e: + errors += f'{e}\n' + + undefined_symbols = grammar.undefined_symbols() + for sym, prod in undefined_symbols: + errors += '%s:%d: Symbol %r used, but not defined as a token or a rule\n' % (prod.file, prod.line, sym) + + unused_terminals = grammar.unused_terminals() + if unused_terminals: + unused_str = '{' + ','.join(unused_terminals) + '}' + cls.log.warning(f'Token{"(s)" if len(unused_terminals) >1 else ""} {unused_str} defined, but not used') + + unused_rules = grammar.unused_rules() + for prod in unused_rules: + cls.log.warning('%s:%d: Rule %r defined, but not used', prod.file, prod.line, prod.name) + + if len(unused_terminals) == 1: + cls.log.warning('There is 1 unused token') + if len(unused_terminals) > 1: + cls.log.warning('There are %d unused tokens', len(unused_terminals)) + + if len(unused_rules) == 1: + cls.log.warning('There is 1 unused rule') + if len(unused_rules) > 1: + cls.log.warning('There are %d unused rules', len(unused_rules)) + + unreachable = grammar.find_unreachable() + for u in unreachable: + cls.log.warning('Symbol %r is unreachable', u) + + if len(undefined_symbols) == 0: + infinite = grammar.infinite_cycles() + for inf in infinite: + errors += 'Infinite recursion detected for symbol %r\n' % inf + + unused_prec = grammar.unused_precedence() + for term, assoc in unused_prec: + errors += 'Precedence rule %r defined for unknown symbol %r\n' % (assoc, term) + + cls._grammar = grammar + if errors: + raise YaccError('Unable to build grammar.\n'+errors) + + @classmethod + def __build_lrtables(cls): + ''' + Build the LR Parsing tables from the grammar + ''' + lrtable = LRTable(cls._grammar) + num_sr = len(lrtable.sr_conflicts) + + # Report shift/reduce and reduce/reduce conflicts + if num_sr != getattr(cls, 'expected_shift_reduce', None): + if num_sr == 1: + cls.log.warning('1 shift/reduce conflict') + elif num_sr > 1: + cls.log.warning('%d shift/reduce conflicts', num_sr) + + num_rr = len(lrtable.rr_conflicts) + if num_rr != getattr(cls, 'expected_reduce_reduce', None): + if num_rr == 1: + cls.log.warning('1 reduce/reduce conflict') + elif num_rr > 1: + cls.log.warning('%d reduce/reduce conflicts', num_rr) 
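Two class attributes interact with the conflict reporting above: a `precedence` table resolves shift/reduce conflicts silently (precedence-resolved conflicts are never appended to `sr_conflicts`), while `expected_shift_reduce` / `expected_reduce_reduce` suppress the warnings for a conflict count you have reviewed and accepted. A minimal sketch with upstream `sly`:

```python
from sly import Lexer, Parser

class ExprLexer(Lexer):
    tokens = { NUMBER, PLUS, TIMES }
    ignore = ' '
    NUMBER = r'\d+'
    PLUS = r'\+'
    TIMES = r'\*'

class ExprParser(Parser):
    tokens = ExprLexer.tokens

    # Without this table the ambiguous grammar below logs
    # "4 shift/reduce conflicts"; with it, they are resolved silently.
    # Alternatively, expected_shift_reduce = 4 would accept them as-is.
    precedence = (
        ('left', PLUS),     # lowest binding strength
        ('left', TIMES),    # highest binding strength
    )

    @_('expr PLUS expr', 'expr TIMES expr')
    def expr(self, p):
        return p.expr0 + p.expr1 if p[1] == '+' else p.expr0 * p.expr1

    @_('NUMBER')
    def expr(self, p):
        return int(p.NUMBER)

print(ExprParser().parse(ExprLexer().tokenize('2 + 3 * 4')))  # prints 14
```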
+ + cls._lrtable = lrtable + return True + + @classmethod + def __collect_rules(cls, definitions): + ''' + Collect all of the tagged grammar rules + ''' + rules = [ (name, value) for name, value in definitions + if callable(value) and hasattr(value, 'rules') ] + return rules + + # ---------------------------------------------------------------------- + # Build the LALR(1) tables. definitions is a list of (name, item) tuples + # of all definitions provided in the class, listed in the order in which + # they were defined. This method is triggered by a metaclass. + # ---------------------------------------------------------------------- + @classmethod + def _build(cls, definitions): + if vars(cls).get('_build', False): + return + + # Collect all of the grammar rules from the class definition + rules = cls.__collect_rules(definitions) + + # Validate other parts of the grammar specification + if not cls.__validate_specification(): + raise YaccError('Invalid parser specification') + + # Build the underlying grammar object + cls.__build_grammar(rules) + + # Build the LR tables + if not cls.__build_lrtables(): + raise YaccError('Can\'t build parsing tables') + + if cls.debugfile: + with open(cls.debugfile, 'w') as f: + f.write(str(cls._grammar)) + f.write('\n') + f.write(str(cls._lrtable)) + cls.log.info('Parser debugging for %s written to %s', cls.__qualname__, cls.debugfile) + + # ---------------------------------------------------------------------- + # Parsing Support. This is the parsing runtime that users use to + # ---------------------------------------------------------------------- + def error(self, token, expected_tokens=None): + ''' + Default error handling function. This may be subclassed. + ''' + if token: + lineno = getattr(token, 'lineno', 0) + if lineno: + sys.stderr.write(f'sly: Syntax error at line {lineno}, token={token.type}\n') + else: + sys.stderr.write(f'sly: Syntax error, token={token.type}') + else: + sys.stderr.write('sly: Parse error in input. EOF\n') + + def errok(self): + ''' + Clear the error status + ''' + self.errorok = True + + def restart(self): + ''' + Force the parser to restart from a fresh state. Clears the statestack + ''' + del self.statestack[:] + del self.symstack[:] + sym = YaccSymbol() + sym.type = '$end' + self.symstack.append(sym) + self.statestack.append(0) + self.state = 0 + + def parse(self, tokens): + ''' + Parse the given input tokens. + ''' + lookahead = None # Current lookahead symbol + lookaheadstack = [] # Stack of lookahead symbols + actions = self._lrtable.lr_action # Local reference to action table (to avoid lookup on self.) + goto = self._lrtable.lr_goto # Local reference to goto table (to avoid lookup on self.) + prod = self._grammar.Productions # Local reference to production list (to avoid lookup on self.) 
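Note that this vendored `error()` receives an extra `expected_tokens` argument (upstream SLY passes only the offending token); it carries the token types the current parser state would have accepted, which is the raw material for the "did you mean" suggestions described in the README. A hedged sketch of an override to place on a `Parser` subclass:

```python
def error(self, token, expected_tokens=None):
    """Hypothetical error hook: raise instead of writing to stderr."""
    if token is None:
        raise ValueError('Syntax error: unexpected end of input')
    location = getattr(token, 'lineno', None)
    expected = ', '.join(sorted(expected_tokens)) if expected_tokens else 'n/a'
    raise ValueError(
        f'Syntax error at line {location}: unexpected {token.type} '
        f'({token.value!r}); expected one of: {expected}'
    )
```

Returning a token from `error()` instead of raising resumes parsing with that token as the new lookahead, as the recovery code further down shows.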
+ defaulted_states = self._lrtable.defaulted_states # Local reference to defaulted states + pslice = YaccProduction(None) # Production object passed to grammar rules + errorcount = 0 # Used during error recovery + + # Set up the state and symbol stacks + self.tokens = tokens + self.used_tokens = [] + self.statestack = statestack = [] # Stack of parsing states + self.symstack = symstack = [] # Stack of grammar symbols + pslice._stack = symstack # Associate the stack with the production + self.restart() + + # Set up position tracking + track_positions = self.track_positions + if not hasattr(self, '_line_positions'): + self._line_positions = { } # id: -> lineno + self._index_positions = { } # id: -> (start, end) + + errtoken = None # Err token + while True: + # Get the next symbol on the input. If a lookahead symbol + # is already set, we just use that. Otherwise, we'll pull + # the next token off of the lookaheadstack or from the lexer + if self.state not in defaulted_states: + if not lookahead: + if not lookaheadstack: + lookahead = next(tokens, None) # Get the next token + self.used_tokens.append(lookahead) + else: + lookahead = lookaheadstack.pop() + if not lookahead: + lookahead = YaccSymbol() + lookahead.type = '$end' + + # Check the action table + ltype = lookahead.type + t = actions[self.state].get(ltype) + else: + t = defaulted_states[self.state] + + if t is not None: + if t > 0: + # shift a symbol on the stack + statestack.append(t) + self.state = t + + symstack.append(lookahead) + lookahead = None + + # Decrease error count on successful shift + if errorcount: + errorcount -= 1 + continue + + if t < 0: + # reduce a symbol on the stack, emit a production + self.production = p = prod[-t] + pname = p.name + plen = p.len + pslice._namemap = p.namemap + + # Call the production function + pslice._slice = symstack[-plen:] if plen else [] + + sym = YaccSymbol() + sym.type = pname + value = p.func(self, pslice) + if value is pslice: + value = (pname, *(s.value for s in pslice._slice)) + + sym.value = value + + # Record positions + if track_positions: + if plen: + sym.lineno = symstack[-plen].lineno + sym.index = symstack[-plen].index + sym.end = symstack[-1].end + else: + # A zero-length production (what to put here?) + sym.lineno = None + sym.index = None + sym.end = None + self._line_positions[id(value)] = sym.lineno + self._index_positions[id(value)] = (sym.index, sym.end) + + if plen: + del symstack[-plen:] + del statestack[-plen:] + + symstack.append(sym) + self.state = goto[statestack[-1]][pname] + statestack.append(self.state) + continue + + if t == 0: + n = symstack[-1] + result = getattr(n, 'value', None) + return result + + if t is None: + # We have some kind of parsing error here. To handle + # this, we are going to push the current token onto + # the tokenstack and replace it with an 'error' token. + # If there are any synchronization rules, they may + # catch it. + # + # In addition to pushing the error token, we call call + # the user defined error() function if this is the + # first syntax error. This function is only called if + # errorcount == 0. + if errorcount == 0 or self.errorok: + errorcount = ERROR_COUNT + self.errorok = False + if lookahead.type == '$end': + errtoken = None # End of file! + else: + errtoken = lookahead + + tok = self.error(errtoken, expected_tokens=list(actions[self.state].keys())) + if tok: + # User must have done some kind of panic + # mode recovery on their own. 
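A toy rendition, not from the patch, of the sign convention this main loop dispatches on (positive entries shift, negative entries reduce, zero accepts, a missing entry is a syntax error):

```python
def describe_action(t):
    """Spell out what one entry of the LR action table means."""
    if t is None:
        return 'syntax error: start error recovery'
    if t > 0:
        return f'shift the lookahead and go to state {t}'
    if t < 0:
        return f'reduce by grammar production number {-t}'
    return 'accept: return the value on top of the symbol stack'

for entry in (3, -2, 0, None):
    print(describe_action(entry))
```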
The + # returned token is the next lookahead + lookahead = tok + self.errorok = True + continue + else: + # If at EOF. We just return. Basically dead. + if not errtoken: + return + else: + # Reset the error count. Unsuccessful token shifted + errorcount = ERROR_COUNT + + # case 1: the statestack only has 1 entry on it. If we're in this state, the + # entire parse has been rolled back and we're completely hosed. The token is + # discarded and we just keep going. + + if len(statestack) <= 1 and lookahead.type != '$end': + lookahead = None + self.state = 0 + # Nuke the lookahead stack + del lookaheadstack[:] + continue + + # case 2: the statestack has a couple of entries on it, but we're + # at the end of the file. nuke the top entry and generate an error token + + # Start nuking entries on the stack + if lookahead.type == '$end': + # Whoa. We're really hosed here. Bail out + return + + if lookahead.type != 'error': + sym = symstack[-1] + if sym.type == 'error': + # Hmmm. Error is on top of stack, we'll just nuke input + # symbol and continue + lookahead = None + continue + + # Create the error symbol for the first time and make it the new lookahead symbol + t = YaccSymbol() + t.type = 'error' + + if hasattr(lookahead, 'lineno'): + t.lineno = lookahead.lineno + if hasattr(lookahead, 'index'): + t.index = lookahead.index + if hasattr(lookahead, 'end'): + t.end = lookahead.end + t.value = lookahead + lookaheadstack.append(lookahead) + lookahead = t + else: + sym = symstack.pop() + statestack.pop() + self.state = statestack[-1] + continue + + # Call an error function here + raise RuntimeError('sly: internal parser error!!!\n') + + # Return position tracking information + def line_position(self, value): + return self._line_positions[id(value)] + + def index_position(self, value): + return self._index_positions[id(value)] + diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_base_sql/__init__.py b/tests/test_base_sql/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_base_sql/test_ast.py b/tests/test_base_sql/test_ast.py new file mode 100644 index 0000000..6c2f234 --- /dev/null +++ b/tests/test_base_sql/test_ast.py @@ -0,0 +1,26 @@ +from mindsdb_sql_parser.ast import * + + +class TestAST: + def test_copy(self): + ast = Select( + targets=[Identifier.from_path_str('col1')], + from_table=Identifier.from_path_str('tab'), + where=BinaryOperation( + op='=', + args=( + Identifier.from_path_str('col3'), + Constant(1), + ) + ), + ) + + ast2 = ast.copy() + # same tree + assert ast.to_tree() == ast2.to_tree() + # not same objects + assert not ast.to_tree() is ast2.to_tree() + + # change + ast.where.args[0] = Constant(1) + assert ast.to_tree() != ast2.to_tree() diff --git a/tests/test_base_sql/test_base_lexer.py b/tests/test_base_sql/test_base_lexer.py new file mode 100644 index 0000000..6a2b540 --- /dev/null +++ b/tests/test_base_sql/test_base_lexer.py @@ -0,0 +1,333 @@ +from mindsdb_sql_parser.lexer import MindsDBLexer + +lexer = MindsDBLexer() + + +class TestLexer: + def test_select_basic(self): + + sql = f'SELECT 1' + tokens = list(lexer.tokenize(sql)) + + assert tokens[0].type == 'SELECT' + assert tokens[0].value == 'SELECT' + + assert tokens[1].type == 'INTEGER' + assert tokens[1].value == "1" + + sql = f'select 1' + tokens = list(lexer.tokenize(sql)) + assert tokens[0].type == 'SELECT' + assert tokens[1].type == 'INTEGER' + assert tokens[1].value == "1" + + sql = f'select a' + tokens = 
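The recovery machinery above is what makes the special `error` symbol usable inside grammar rules. A hedged, self-contained example against upstream `sly`: the parser resynchronizes at the next semicolon and keeps going instead of aborting on the first malformed statement:

```python
from sly import Lexer, Parser

class StmtLexer(Lexer):
    tokens = { NAME, SEMI }
    ignore = ' '
    NAME = r'[a-zA-Z_]+'
    SEMI = r';'

class StmtParser(Parser):
    tokens = StmtLexer.tokens

    @_('{ statement }')
    def program(self, p):
        return [s for (s,) in p[0] if s is not None]

    @_('NAME SEMI')
    def statement(self, p):
        return p.NAME

    @_('error SEMI')          # panic mode: skip to the next semicolon
    def statement(self, p):
        return None

    def error(self, token, expected_tokens=None):
        pass                  # silence the default stderr message

# The malformed middle statement is discarded; parsing continues after it.
print(StmtParser().parse(StmtLexer().tokenize('a; b b; c;')))  # ['a', 'c']
```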
list(lexer.tokenize(sql)) + assert tokens[0].type == 'SELECT' + assert tokens[1].type == 'ID' + assert tokens[1].value == 'a' + + def test_select_basic_ignored_symbols(self): + + sql = f'SELECT \t\r\n1' + tokens = list(lexer.tokenize(sql)) + + assert tokens[0].type == 'SELECT' + assert tokens[0].value == 'SELECT' + + assert tokens[1].type == 'INTEGER' + assert tokens[1].value == "1" + + sql = f'select 1' + tokens = list(lexer.tokenize(sql)) + assert tokens[0].type == 'SELECT' + assert tokens[1].type == 'INTEGER' + assert tokens[1].value == "1" + + sql = f'select a' + tokens = list(lexer.tokenize(sql)) + assert tokens[0].type == 'SELECT' + assert tokens[1].type == 'ID' + assert tokens[1].value == 'a' + + def test_select_identifiers(self): + sql = 'SELECT abcd123, 123abcd, __whatisthi123s__, idwith$sign, `spaces in id`, multipleparts__whoa, `multiple_parts with brackets` ' + tokens = list(lexer.tokenize(sql)) + assert tokens[0].type == 'SELECT' + + for i, t in enumerate(tokens[1:]): + if i % 2 != 0: + assert t.type == 'COMMA' + else: + assert t.type == 'ID' + + def test_select_float(self): + for val in [0.0, 1.000, 0.1, 1.0, 99999.9999]: + val = str(val) + sql = f'SELECT {val}' + tokens = list(lexer.tokenize(sql)) + + assert tokens[0].type == 'SELECT' + assert tokens[0].value == 'SELECT' + + assert tokens[1].type == 'FLOAT' + assert tokens[1].value == val + + def test_select_strings(self): + sql = 'SELECT "a", "b", "c"' + tokens = list(lexer.tokenize(sql)) + assert tokens[0].type == 'SELECT' + assert tokens[1].type == 'DQUOTE_STRING' + assert tokens[1].value == '"a"' + assert tokens[2].type == 'COMMA' + assert tokens[3].type == 'DQUOTE_STRING' + assert tokens[3].value == '"b"' + assert tokens[5].type == 'DQUOTE_STRING' + assert tokens[5].value == '"c"' + + sql = "SELECT 'a', 'b', 'c'" + tokens = list(lexer.tokenize(sql)) + assert tokens[0].type == 'SELECT' + assert tokens[1].type == 'QUOTE_STRING' + assert tokens[1].value == "'a'" + assert tokens[2].type == 'COMMA' + assert tokens[3].type == 'QUOTE_STRING' + assert tokens[3].value == "'b'" + assert tokens[5].type == 'QUOTE_STRING' + assert tokens[5].value == "'c'" + + def test_select_strings_nested(self): + sql = "SELECT '\"a\"', \"'b'\" " + tokens = list(lexer.tokenize(sql)) + assert tokens[0].type == 'SELECT' + assert tokens[1].type == 'QUOTE_STRING' + assert tokens[1].value == "'\"a\"'" + assert tokens[2].type == 'COMMA' + assert tokens[3].type == 'DQUOTE_STRING' + assert tokens[3].value == '\"\'b\'\"' + + def test_binary_ops(self): + for op, expected_type in [ + ('+', 'PLUS'), + ('-', 'MINUS'), + ('/', 'DIVIDE'), + ('*', 'STAR'), + ('%', 'MODULO'), + ('=', 'EQUALS'), + ('!=', 'NEQUALS'), + ('<>', 'NEQUALS'), + ('>', 'GREATER'), + ('>=', 'GEQ'), + ('<', 'LESS'), + ('<=', 'LEQ'), + ('AND', 'AND'), + ('OR', 'OR'), + ('IS', 'IS'), + # ('IS NOT', 'ISNOT'), + ('LIKE', 'LIKE'), + ('IN', 'IN'), + ]: + sql = f'SELECT 1 {op} 2' + tokens = list(lexer.tokenize(sql)) + assert tokens[0].type == 'SELECT' + assert tokens[0].value == 'SELECT' + + assert tokens[1].type == 'INTEGER' + assert tokens[1].value == "1" + + assert tokens[2].type == expected_type + assert tokens[2].value == op + + assert tokens[3].type == 'INTEGER' + assert tokens[3].value == "2" + + def test_binary_ops_not(self): + + sql = f'SELECT 1 IS NOT 2' + tokens = list(lexer.tokenize(sql)) + assert tokens[0].type == 'SELECT' + assert tokens[0].value == 'SELECT' + + assert tokens[1].type == 'INTEGER' + assert tokens[1].value == "1" + + assert tokens[2].type == 'IS_NOT' + + assert 
tokens[3].type == 'INTEGER' + assert tokens[3].value == "2" + + # + sql = f'SELECT 1 NOT IN 2' + tokens = list(lexer.tokenize(sql)) + assert tokens[0].type == 'SELECT' + assert tokens[0].value == 'SELECT' + + assert tokens[1].type == 'INTEGER' + assert tokens[1].value == "1" + + assert tokens[2].type == 'NOT_IN' + + assert tokens[3].type == 'INTEGER' + assert tokens[3].value == "2" + + + def test_select_from(self): + sql = f'SELECT column AS other_column FROM db.schema.tab' + tokens = list(lexer.tokenize(sql)) + + assert tokens[0].type == 'SELECT' + assert tokens[0].value == 'SELECT' + + assert tokens[1].type == 'ID' + assert tokens[1].value == 'column' + + assert tokens[2].type == 'AS' + assert tokens[2].value == 'AS' + + assert tokens[3].type == 'ID' + assert tokens[3].value == 'other_column' + + assert tokens[4].type == 'FROM' + assert tokens[4].value == 'FROM' + + assert tokens[5].type == 'ID' + assert tokens[5].value == 'db' + + assert tokens[6].type == 'DOT' + assert tokens[6].value == '.' + + assert tokens[7].type == 'SCHEMA' + + assert tokens[8].type == 'DOT' + assert tokens[8].value == '.' + + assert tokens[9].type == 'ID' + assert tokens[9].value == 'tab' + + def test_select_star(self): + sql = f'SELECT * FROM tab' + tokens = list(lexer.tokenize(sql)) + + assert tokens[0].type == 'SELECT' + assert tokens[0].value == 'SELECT' + + assert tokens[1].type == 'STAR' + assert tokens[1].value == '*' + + assert tokens[2].type == 'FROM' + assert tokens[2].value == 'FROM' + + assert tokens[3].type == 'ID' + assert tokens[3].value == 'tab' + + def test_select_where(self): + sql = f'SELECT column FROM tab WHERE column = "something"' + tokens = list(lexer.tokenize(sql)) + + assert tokens[0].type == 'SELECT' + assert tokens[1].type == 'ID' + assert tokens[2].type == 'FROM' + assert tokens[3].type == 'ID' + assert tokens[4].type == 'WHERE' + assert tokens[4].value == 'WHERE' + assert tokens[5].type == 'ID' + assert tokens[5].value == 'column' + assert tokens[6].type == 'EQUALS' + assert tokens[6].value == '=' + assert tokens[7].type == 'DQUOTE_STRING' + assert tokens[7].value == '"something"' + + def test_select_group_by(self): + sql = f'SELECT column, sum(column2) FROM tab GROUP BY column' + tokens = list(lexer.tokenize(sql)) + + assert tokens[0].type == 'SELECT' + assert tokens[1].type == 'ID' + assert tokens[2].type == 'COMMA' + assert tokens[3].type == 'ID' + assert tokens[4].type == 'LPAREN' + assert tokens[5].type == 'ID' + assert tokens[6].type == 'RPAREN' + assert tokens[7].type == 'FROM' + assert tokens[8].type == 'ID' + assert tokens[9].type == 'GROUP_BY' + assert tokens[9].value == 'GROUP BY' + assert tokens[10].type == 'ID' + assert tokens[10].value == 'column' + + def test_select_order_by(self): + for order_dir in ['ASC', 'DESC']: + sql = f'SELECT column, sum(column2) FROM tab ORDER BY column {order_dir}' + tokens = list(lexer.tokenize(sql)) + + assert tokens[0].type == 'SELECT' + assert tokens[1].type == 'ID' + assert tokens[2].type == 'COMMA' + assert tokens[3].type == 'ID' + assert tokens[4].type == 'LPAREN' + assert tokens[5].type == 'ID' + assert tokens[6].type == 'RPAREN' + assert tokens[7].type == 'FROM' + assert tokens[8].type == 'ID' + assert tokens[9].type == 'ORDER_BY' + assert tokens[9].value == 'ORDER BY' + assert tokens[10].type == 'ID' + assert tokens[10].value == 'column' + assert tokens[11].type == order_dir + assert tokens[11].value == order_dir + + def test_as_ones(self): + sql = "SELECT *, (SELECT 1) AS ones FROM t1" + tokens = list(lexer.tokenize(sql)) + + 
assert tokens[0].type == 'SELECT' + assert tokens[1].type == 'STAR' + assert tokens[2].type == 'COMMA' + assert tokens[3].type == 'LPAREN' + assert tokens[4].type == 'SELECT' + assert tokens[5].type == 'INTEGER' + assert tokens[6].type == 'RPAREN' + assert tokens[7].type == 'AS' + assert tokens[8].type == 'ID' + assert tokens[9].type == 'FROM' + assert tokens[10].type == 'ID' + + sql = "SELECT *, (SELECT 1) AS ones FROM t1".lower() + tokens = list(lexer.tokenize(sql)) + + assert tokens[0].type == 'SELECT' + assert tokens[1].type == 'STAR' + assert tokens[2].type == 'COMMA' + assert tokens[3].type == 'LPAREN' + assert tokens[4].type == 'SELECT' + assert tokens[5].type == 'INTEGER' + assert tokens[6].type == 'RPAREN' + assert tokens[7].type == 'AS' + assert tokens[8].type == 'ID' + assert tokens[9].type == 'FROM' + assert tokens[10].type == 'ID' + + def test_select_parameter(self): + sql = f'SELECT ?' + tokens = list(lexer.tokenize(sql)) + + assert tokens[0].type == 'SELECT' + assert tokens[0].value == 'SELECT' + + assert tokens[1].type == 'PARAMETER' + assert tokens[1].value == '?' + + def test_show_character_set(self): + sql = "show character set where charset = 'utf8mb4'" + tokens = list(lexer.tokenize(sql)) + + assert tokens[0].type == 'SHOW' + assert tokens[1].type == 'CHARACTER' + assert tokens[2].type == 'SET' + assert tokens[3].type == 'WHERE' + assert tokens[4].type == 'CHARSET' + assert tokens[4].value == 'charset' + assert tokens[5].value == '=' + assert tokens[6].type == 'QUOTE_STRING' + assert tokens[6].value == "'utf8mb4'" + diff --git a/tests/test_base_sql/test_base_sql.py b/tests/test_base_sql/test_base_sql.py new file mode 100644 index 0000000..7a3816c --- /dev/null +++ b/tests/test_base_sql/test_base_sql.py @@ -0,0 +1,88 @@ +from textwrap import dedent +from mindsdb_sql_parser import parse_sql + +from mindsdb_sql_parser.ast import * + + +class TestSql: + def test_ending(self): + sql = """INSERT INTO tbl_name VALUES (1, 3) + ; + """ + + parse_sql(sql) + + def test_not_equal(self): + sql = " select * from t1 where a<>1" + + ast = parse_sql(sql) + + expected_ast = Select( + targets=[Star()], + from_table=Identifier('t1'), + where=BinaryOperation( + op='<>', + args=[ + Identifier('a'), + Constant(1) + ] + ) + ) + + assert str(ast).lower() == str(expected_ast).lower() + assert ast.to_tree() == expected_ast.to_tree() + + def test_escaping(self): + expected_ast = Select( + targets=[ + Constant(value="a ' \" b"), + Constant(value="a ' \" b"), + Constant(value="a \\n b"), + Constant(value="a \\\n b"), + Constant(value="a \\\n b"), + Constant(value="a\nb"), + ] + ) + + sql = dedent(''' +select +'a \\' \\" b', -- double quote +"a \\' \\" b", -- single quote +"a \\n b", +"a \\\n b", -- double quote +'a \\\n b', -- single quote +"a +b" + ''') + + ast = parse_sql(sql) + + assert str(ast).lower() == str(expected_ast).lower() + assert ast.to_tree() == expected_ast.to_tree() + + def test_quotes_escaping(self): + sql = "select 'women''s soccer'" + + expected_ast = Select( + targets=[ + Constant(value="women's soccer") + ] + ) + ast = parse_sql(sql) + + assert str(ast).lower() == str(expected_ast).lower() + assert ast.to_tree() == expected_ast.to_tree() + + def test_quotes_identifier(self): + sql = 'select t2."var (k)" from t2' + + expected_ast = Select( + targets=[ + Identifier(parts=['t2', 'var (k)']) + ], + from_table=Identifier('t2') + ) + ast = parse_sql(sql) + + assert str(ast).lower() == str(expected_ast).lower() + assert ast.to_tree() == expected_ast.to_tree() diff --git 
a/tests/test_base_sql/test_create.py b/tests/test_base_sql/test_create.py new file mode 100644 index 0000000..110c789 --- /dev/null +++ b/tests/test_base_sql/test_create.py @@ -0,0 +1,156 @@ +import pytest + +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast import * + + +class TestCreate: + def test_create_from_select(self): + expected_ast = CreateTable( + name=Identifier('int1.model_name'), + is_replace=True, + from_select=Select( + targets=[Identifier('a')], + from_table=Identifier('ddd'), + ) + ) + + # with parens + sql = ''' + create or replace table int1.model_name ( + select a from ddd + ) + ''' + ast = parse_sql(sql) + + assert str(ast).lower() == str(expected_ast).lower() + assert ast.to_tree() == expected_ast.to_tree() + + # without parens + sql = ''' + create or replace table int1.model_name + select a from ddd + ''' + ast = parse_sql(sql) + + assert str(ast).lower() == str(expected_ast).lower() + assert ast.to_tree() == expected_ast.to_tree() + + expected_ast.is_replace = False + + # no replace + sql = ''' + create table int1.model_name + select a from ddd + ''' + ast = parse_sql(sql) + + assert str(ast).lower() == str(expected_ast).lower() + assert ast.to_tree() == expected_ast.to_tree() + + +class TestCreateMindsdb: + + def test_create(self): + + for is_replace in [True, False]: + for if_not_exists in [True, False]: + + expected_ast = CreateTable( + name=Identifier('mydb.Persons'), + is_replace=is_replace, + if_not_exists=if_not_exists, + columns=[ + TableColumn(name='PersonID', type='int'), + TableColumn(name='LastName', type='varchar', length=255), + TableColumn(name='FirstName', type='char', length=10), + TableColumn(name='Info', type='json'), + TableColumn(name='City', type='varchar'), + ] + ) + replace_str = 'OR REPLACE' if is_replace else '' + exist_str = 'IF NOT EXISTS' if if_not_exists else '' + + sql = f''' + CREATE {replace_str} TABLE {exist_str} mydb.Persons( + PersonID int, + LastName varchar(255), + FirstName char(10), + Info json, + City varchar + ) + ''' + ast = parse_sql(sql) + + assert str(ast).lower() == str(expected_ast).lower() + assert ast.to_tree() == expected_ast.to_tree() + + # test with primary keys / defaults + # using serial + + sql = f''' + CREATE TABLE mydb.Persons( + PersonID serial, + active BOOL NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + ''' + ast = parse_sql(sql) + + expected_ast = CreateTable( + name=Identifier('mydb.Persons'), + columns=[ + TableColumn(name='PersonID', type='serial'), + TableColumn(name='active', type='BOOL', nullable=False), + TableColumn(name='created_at', type='TIMESTAMP', default='CURRENT_TIMESTAMP'), + ] + ) + + assert str(ast).lower() == str(expected_ast).lower() + assert ast.to_tree() == expected_ast.to_tree() + + # using primary key column + + sql = f''' + CREATE TABLE mydb.Persons( + PersonID INT PRIMARY KEY, + name TEXT NULL + ) + ''' + ast = parse_sql(sql) + + expected_ast = CreateTable( + name=Identifier('mydb.Persons'), + columns=[ + TableColumn(name='PersonID', type='INT', is_primary_key=True), + TableColumn(name='name', type='TEXT', nullable=True), + ] + ) + + assert str(ast).lower() == str(expected_ast).lower() + assert ast.to_tree() == expected_ast.to_tree() + + # multiple primary keys + + sql = f''' + CREATE TABLE mydb.Persons( + location_id INT, + num INT, + name TEXT, + PRIMARY KEY (location_id, num) + ) + ''' + ast = parse_sql(sql) + + expected_ast = CreateTable( + name=Identifier('mydb.Persons'), + columns=[ + TableColumn(name='location_id', type='INT', 
is_primary_key=True), + TableColumn(name='num', type='INT', is_primary_key=True), + TableColumn(name='name', type='TEXT'), + ] + ) + + assert str(ast).lower() == str(expected_ast).lower() + assert ast.to_tree() == expected_ast.to_tree() + diff --git a/tests/test_base_sql/test_ddl.py b/tests/test_base_sql/test_ddl.py new file mode 100644 index 0000000..f21b945 --- /dev/null +++ b/tests/test_base_sql/test_ddl.py @@ -0,0 +1,66 @@ +import pytest + +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast import * + +class TestDDL: + + def test_drop_database(self): + + sql = "DROP DATABASE IF EXISTS dbname" + + ast = parse_sql(sql) + expected_ast = DropDatabase(name=Identifier('dbname'), if_exists=True) + + assert str(ast).lower() == sql.lower() + assert ast.to_tree() == expected_ast.to_tree() + + sql = "DROP DATABASE dbname" + + ast = parse_sql(sql) + expected_ast = DropDatabase(name=Identifier('dbname'), if_exists=False) + + assert str(ast).lower() == sql.lower() + assert ast.to_tree() == expected_ast.to_tree() + + # DROP SCHEMA is a synonym for DROP DATABASE. + sql = "DROP SCHEMA dbname" + + ast = parse_sql(sql) + expected_ast = DropDatabase(name=Identifier('dbname')) + + assert str(ast).lower() == 'DROP DATABASE dbname'.lower() + assert ast.to_tree() == expected_ast.to_tree() + + def test_drop_view(self): + + sql = "DROP VIEW IF EXISTS vname1, vname2" + + ast = parse_sql(sql) + expected_ast = DropView(names=[Identifier('vname1'), Identifier('vname2')], if_exists=True) + + assert str(ast).lower() == sql.lower() + assert ast.to_tree() == expected_ast.to_tree() + + sql = "DROP VIEW vname" + + ast = parse_sql(sql) + expected_ast = DropView(names=[Identifier('vname')], if_exists=False) + + assert str(ast).lower() == sql.lower() + assert ast.to_tree() == expected_ast.to_tree() + + def test_drop_predictor_table_syntax_ok(self): + sql = "DROP TABLE mindsdb.tbl" + ast = parse_sql(sql) + expected_ast = DropTables(tables=[Identifier('mindsdb.tbl')]) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + sql = "DROP TABLE if exists mindsdb.tbl" + ast = parse_sql(sql) + expected_ast = DropTables(tables=[Identifier('mindsdb.tbl')], if_exists=True) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + diff --git a/tests/test_base_sql/test_delete.py b/tests/test_base_sql/test_delete.py new file mode 100644 index 0000000..c2b7b80 --- /dev/null +++ b/tests/test_base_sql/test_delete.py @@ -0,0 +1,34 @@ +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast import * + + +class TestDelete: + def test_delete(self): + sql = "delete from ds.table1 where field > value" + + ast = parse_sql(sql) + expected_ast = Delete( + table=Identifier('ds.table1'), + where=BinaryOperation( + op='>', + args=( + Identifier('field'), + Identifier('value'), + ) + ), + ) + + assert str(ast).lower() == sql.lower() + assert ast.to_tree() == expected_ast.to_tree() + + def test_delete_no_where(self): + sql = "delete from ds.table1" + + ast = parse_sql(sql) + expected_ast = Delete( + table=Identifier('ds.table1'), + where=None + ) + + assert str(ast).lower() == sql.lower() + assert ast.to_tree() == expected_ast.to_tree() \ No newline at end of file diff --git a/tests/test_base_sql/test_describe.py b/tests/test_base_sql/test_describe.py new file mode 100644 index 0000000..56484c4 --- /dev/null +++ b/tests/test_base_sql/test_describe.py @@ -0,0 +1,68 @@ +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast import * + + 
+class TestDescribe: + def test_describe(self): + sql = "DESCRIBE my_identifier" + ast = parse_sql(sql) + expected_ast = Describe(value=Identifier('my_identifier')) + + assert str(ast).lower() == sql.lower() + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + +class TestDescribeMindsdb: + def test_describe_predictor(self): + sql = "DESCRIBE PREDICTOR my_identifier" + ast = parse_sql(sql) + expected_ast = Describe(type='PREDICTOR', value=Identifier('my_identifier')) + + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + sql = "DESCRIBE MODEL my_identifier" + ast = parse_sql(sql) + + expected_ast = Describe(type='MODEL', value=Identifier('my_identifier')) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + # describe attr + sql = "DESCRIBE MODEL pred.attr" + ast = parse_sql(sql) + + expected_ast = Describe(type='MODEL', value=Identifier(parts=['pred', 'attr'])) + + assert str(ast) == str(expected_ast) + + # version + sql = "DESCRIBE MODEL pred.11" + ast = parse_sql(sql) + + expected_ast = Describe(type='MODEL', value=Identifier(parts=['pred', '11'])) + + assert str(ast) == str(expected_ast) + + # version and attr + sql = "DESCRIBE MODEL pred.11.attr" + ast = parse_sql(sql) + + expected_ast = Describe(type='MODEL', value=Identifier(parts=['pred', '11', 'attr'])) + + assert str(ast) == str(expected_ast) + + # other objects + for obj in ( + "AGENT", "JOB", "SKILL", "CHATBOT", "TRIGGER", "VIEW", + "KNOWLEDGE_BASE", "KNOWLEDGE BASE", "PREDICTOR", "MODEL", + 'database', 'project', 'handler', 'ml_engine'): + + sql = f"DESCRIBE {obj} aaa" + ast = parse_sql(sql) + + obj = obj.replace(' ', '_') + expected_ast = Describe(type=obj, value=Identifier(parts=['aaa'])) + + assert str(ast) == str(expected_ast) diff --git a/tests/test_base_sql/test_insert.py b/tests/test_base_sql/test_insert.py new file mode 100644 index 0000000..fb8c5d9 --- /dev/null +++ b/tests/test_base_sql/test_insert.py @@ -0,0 +1,96 @@ +import pytest + +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast import * + + +class TestInsert: + + def test_insert(self): + sql = "INSERT INTO tbl_name(a, c) VALUES (1, 3), (4, 5)" + + ast = parse_sql(sql) + expected_ast = Insert( + table=Identifier('tbl_name'), + columns=[Identifier('a'), Identifier('c')], + values=[ + [Constant(1), Constant(3)], + [Constant(4), Constant(5)], + ] + ) + + assert str(ast).lower() == sql.lower() + assert ast.to_tree() == expected_ast.to_tree() + + def test_insert_no_columns(self): + sql = "INSERT INTO tbl_name VALUES (1, 3), (4, 5)" + ast = parse_sql(sql) + expected_ast = Insert( + table=Identifier('tbl_name'), + values=[ + [Constant(1), Constant(3)], + [Constant(4), Constant(5)], + ] + ) + + assert str(ast).lower() == str(expected_ast).lower() + assert ast.to_tree() == expected_ast.to_tree() + + def test_insert_from_select(self): + sql = "INSERT INTO tbl_name(a, c) SELECT b, d from table2" + + ast = parse_sql(sql) + expected_ast = Insert( + table=Identifier('tbl_name'), + columns=[Identifier('a'), Identifier('c')], + from_select=Select( + targets=[ + Identifier('b'), + Identifier('d'), + ], + from_table=Identifier('table2') + ) + ) + + assert str(ast).lower() == sql.lower() + assert ast.to_tree() == expected_ast.to_tree() + + def test_insert_from_select_no_columns(self): + sql = "INSERT INTO tbl_name SELECT b, d from table2" + + ast = parse_sql(sql) + expected_ast = Insert( + table=Identifier('tbl_name'), + from_select=Select( + 
targets=[ + Identifier('b'), + Identifier('d'), + ], + from_table=Identifier('table2') + ) + ) + + assert str(ast).lower() == sql.lower() + assert ast.to_tree() == expected_ast.to_tree() + +class TestInsertMDB: + + def test_insert_from_union(self): + from textwrap import dedent + sql = dedent(""" + INSERT INTO tbl_name(a, c) SELECT * from table1 + UNION + SELECT * from table2""")[1:] + + ast = parse_sql(sql) + expected_ast = Insert( + table=Identifier('tbl_name'), + columns=[Identifier('a'), Identifier('c')], + from_select=Union( + left=Select(targets=[Star()], from_table=Identifier('table1')), + right=Select(targets=[Star()], from_table=Identifier('table2')) + ) + ) + + assert str(ast).lower() == sql.lower() + assert ast.to_tree() == expected_ast.to_tree() \ No newline at end of file diff --git a/tests/test_base_sql/test_misc_sql_queries.py b/tests/test_base_sql/test_misc_sql_queries.py new file mode 100644 index 0000000..e7f30a0 --- /dev/null +++ b/tests/test_base_sql/test_misc_sql_queries.py @@ -0,0 +1,259 @@ +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast import * + + +class TestMiscQueries: + def test_set(self): + + sql = "SET names some_name" + + ast = parse_sql(sql) + expected_ast = Set(category="names", value=Identifier('some_name')) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + sql = "set character_set_results = NULL" + + ast = parse_sql(sql) + expected_ast = Set(name=Identifier('character_set_results'), value=NullConstant()) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_start_transaction(self): + sql = "start transaction" + + ast = parse_sql(sql) + expected_ast = StartTransaction() + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_rollback(self): + sql = "rollback" + + ast = parse_sql(sql) + expected_ast = RollbackTransaction() + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_commit(self): + sql = "commit" + + ast = parse_sql(sql) + expected_ast = CommitTransaction() + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_explain(self): + sql = "explain some_table" + + ast = parse_sql(sql) + expected_ast = Explain(target=Identifier('some_table')) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_alter_table_keys(self): + sql = "alter table some_table disable keys" + + ast = parse_sql(sql) + expected_ast = AlterTable(target=Identifier('some_table'), arg='disable keys') + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + sql = "alter table some_table enable keys" + + ast = parse_sql(sql) + expected_ast = AlterTable(target=Identifier('some_table'), arg='enable keys') + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_autocommit(self): + sql = "set autocommit=1" + + ast = parse_sql(sql) + expected_ast = Set( + name=Identifier('autocommit'), + value=Constant(1) + ) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + + +class TestMiscQueriesNoSqlite: + def test_set(self): + + sql = "set var1 = NULL, var2 = 10" + + ast = parse_sql(sql) + expected_ast = Set( + set_list=[ + Set(name=Identifier('var1'), value=NullConstant()), + Set(name=Identifier('var2'), value=Constant(10)), + ] + ) + assert ast.to_tree() == expected_ast.to_tree() + assert 
str(ast) == str(expected_ast) + + + sql = "SET NAMES some_name collate DEFAULT" + + ast = parse_sql(sql) + expected_ast = Set(category="NAMES", + value=Constant('some_name', with_quotes=False), + params={'COLLATE': 'DEFAULT'}) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + sql = "SET names some_name collate 'utf8mb4_general_ci'" + + ast = parse_sql(sql) + expected_ast = Set(category="names", + value=Constant('some_name', with_quotes=False), + params={'COLLATE': Constant('utf8mb4_general_ci')}) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_set_charset(self): + + sql = "SET CHARACTER SET DEFAULT" + + ast = parse_sql(sql) + expected_ast = Set(category='CHARSET', value=Constant('DEFAULT', with_quotes=False)) + + assert ast.to_tree() == expected_ast.to_tree() + + sql = "SET CHARSET DEFAULT" + + ast = parse_sql(sql) + expected_ast = Set(category='CHARSET', value=Constant('DEFAULT', with_quotes=False)) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + sql = "SET CHARSET 'utf8'" + + ast = parse_sql(sql) + expected_ast = Set(category='CHARSET', value=Constant('utf8')) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_set_transaction(self): + + sql = "SET GLOBAL TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ WRITE" + + ast = parse_sql(sql) + expected_ast = Set( + category='TRANSACTION', + params={ + 'isolation level': 'REPEATABLE READ', + 'access_mode': 'READ WRITE', + }, + scope='GLOBAL' + ) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + sql = "SET SESSION TRANSACTION READ ONLY, ISOLATION LEVEL SERIALIZABLE" + + ast = parse_sql(sql) + + expected_ast = Set( + category='TRANSACTION', + params={ + 'isolation level': 'SERIALIZABLE', + 'access_mode': 'READ ONLY', + }, + scope='SESSION' + ) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + sql = "SET TRANSACTION ISOLATION LEVEL READ UNCOMMITTED" + + ast = parse_sql(sql) + + expected_ast = Set( + category='TRANSACTION', + params={ + 'isolation level': 'READ UNCOMMITTED' + }, + ) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + sql = "SET TRANSACTION READ ONLY" + + ast = parse_sql(sql) + + expected_ast = Set( + category='TRANSACTION', + params=dict( + access_mode='READ ONLY', + ) + ) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_begin(self): + sql = "begin" + + ast = parse_sql(sql) + expected_ast = StartTransaction() + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + +class TestMindsdb: + def test_charset(self): + sql = "SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci" + + ast = parse_sql(sql) + expected_ast = Set(category="NAMES", + value=Constant('utf8mb4', with_quotes=False), + params={'COLLATE': Constant('utf8mb4_unicode_ci', with_quotes=False)}) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_set_version(self): + sql = "SET active model_name.1" + + ast = parse_sql(sql) + expected_ast = Set(category='active', value=Identifier(parts=['model_name', '1'])) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + + def test_interval(self): + for value in ('1 day', "'1' day", "'1 day'"): + sql = f""" + select interval {value} + 1 from aaa + where 
'a' > interval "1 min" + """ + + expected_ast = Select( + targets=[ + BinaryOperation(op='+', args=[ + Interval('1 day'), + Constant(1) + ]) + ], + from_table=Identifier('aaa'), + where=BinaryOperation( + op='>', + args=[ + Constant('a'), + Interval('1 min'), + ] + ) + ) + + ast = parse_sql(sql) + + assert str(ast).lower() == str(expected_ast).lower() + assert ast.to_tree() == expected_ast.to_tree() + diff --git a/tests/test_base_sql/test_select_common_table_expression.py b/tests/test_base_sql/test_select_common_table_expression.py new file mode 100644 index 0000000..01c1e16 --- /dev/null +++ b/tests/test_base_sql/test_select_common_table_expression.py @@ -0,0 +1,89 @@ +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast import * +from mindsdb_sql_parser.utils import JoinType + + +class TestCommonTableExpression: + + def test_cte_select_number(self): + sql = f'WITH one AS ( SELECT 1 ) SELECT * FROM one' + ast = parse_sql(sql) + + expected_ast = Select( + cte=[ + CommonTableExpression(name=Identifier('one'), query=Select(targets=[Constant(1)])), + ], + targets=[Star()], + from_table=Identifier('one') + ) + + assert str(ast).lower() == sql.lower() + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_cte_select_named_columns(self): + sql = f'WITH cte( a, b ) AS ( SELECT 1, 2 ) SELECT a, b FROM cte' + ast = parse_sql(sql) + + expected_ast = Select( + cte=[ + CommonTableExpression(name=Identifier('cte'), + columns=[Identifier('a'), Identifier('b')], + query=Select(targets=[Constant(1), Constant(2)])), + ], + targets=[Identifier('a'), Identifier('b')], + from_table=Identifier('cte') + ) + + assert str(ast).lower() == sql.lower() + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_cte_multiple(self): + sql = '''WITH cte_a AS ( SELECT 1 ), cte_b AS ( SELECT 2 ) SELECT * FROM cte_a, cte_b''' + ast = parse_sql(sql) + + expected_ast = Select( + cte=[ + CommonTableExpression(name=Identifier('cte_a'), + query=Select(targets=[Constant(1)])), + CommonTableExpression(name=Identifier('cte_b'), + query=Select(targets=[Constant(2)])), + ], + targets=[Star()], + from_table=Join(left=Identifier('cte_a'), + right=Identifier('cte_b'), + join_type=JoinType.INNER_JOIN, + implicit=True) + ) + + assert str(ast).lower() == sql.lower() + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_cte_nested(self): + sql = '''WITH cte AS ( SELECT 1 ) SELECT * FROM (WITH cte_1 AS ( SELECT 2 ) SELECT * FROM cte_1 JOIN cte) AS subquery''' + ast = parse_sql(sql) + + expected_ast = Select( + cte=[ + CommonTableExpression(name=Identifier('cte'), + query=Select(targets=[Constant(1)])), + ], + targets=[Star()], + from_table=Select( + alias=Identifier('subquery'), + cte=[ + CommonTableExpression(name=Identifier('cte_1'), + query=Select(targets=[Constant(2)])), + ], + targets=[Star()], + from_table=Join(left=Identifier('cte_1'), + right=Identifier('cte'), + join_type=JoinType.JOIN) + ), + ) + + assert str(ast).lower() == sql.lower() + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() diff --git a/tests/test_base_sql/test_select_operations.py b/tests/test_base_sql/test_select_operations.py new file mode 100644 index 0000000..9e08a6e --- /dev/null +++ b/tests/test_base_sql/test_select_operations.py @@ -0,0 +1,637 @@ +import pytest + +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast import * + + +class TestOperations: + def 
test_select_binary_operations(self): + for op in ['+', '-', '/', '*', '%', '=', '!=', '>', '<', '>=', '<=', + 'is', 'IS NOT', 'like', 'in', 'and', 'or', '||']: + sql = f'SELECT column1 {op.upper()} column2 FROM tab' + ast = parse_sql(sql) + + expected_ast = Select( + targets=[BinaryOperation(op=op, + args=( + Identifier.from_path_str('column1'), Identifier.from_path_str('column2') + )), + ], + from_table=Identifier.from_path_str('tab') + ) + + assert str(ast).lower() == sql.lower() + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_operation_converts_to_lowercase(self): + sql = f'SELECT column1 IS column2 FROM tab' + ast = parse_sql(sql) + + expected_ast = Select( + targets=[BinaryOperation(op='is', + args=( + Identifier.from_path_str('column1'), Identifier.from_path_str('column2') + )), + ], + from_table=Identifier.from_path_str('tab') + ) + + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_operator_precedence_sum_mult(self): + sql = f'SELECT column1 + column2 * column3' + ast = parse_sql(sql) + + expected_ast = Select( + targets=[BinaryOperation(op='+', + args=( + Identifier.from_path_str('column1'), + BinaryOperation(op='*', + args=( + Identifier.from_path_str('column2'), Identifier.from_path_str('column3') + )), + + ), + ) + ] + ) + + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + sql = f'SELECT column1 * column2 + column3' + ast = parse_sql(sql) + + expected_ast = Select( + targets=[BinaryOperation(op='+', + args=( + BinaryOperation(op='*', + args=( + Identifier.from_path_str('column1'), Identifier.from_path_str('column2') + )), + Identifier.from_path_str('column3'), + + ) + ) + ] + ) + + assert ast == expected_ast + assert ast.to_tree() == expected_ast.to_tree() + + + def test_operator_precedence_sum_mult_parentheses(self): + sql = f'SELECT (column1 + column2) * column3' + ast = parse_sql(sql) + + expected_ast = Select( + targets=[BinaryOperation(op='*', + args=( + BinaryOperation(op='+', + args=( + Identifier.from_path_str('column1'), Identifier.from_path_str('column2') + ), + parentheses=True), + Identifier.from_path_str('column3'), + + ), + ) + ] + ) + + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_operator_chained_and(self): + sql = f"""SELECT column1 AND column2 AND column3""" + ast = parse_sql(sql) + + expected_ast = Select(targets=[BinaryOperation(op='AND', args=(BinaryOperation(op='and', args=( + Identifier.from_path_str("column1"), + Identifier.from_path_str("column2"))), + Identifier.from_path_str("column3"), + + ))]) + + assert str(ast).lower() == str(expected_ast).lower() + assert ast.to_tree() == expected_ast.to_tree() + + def test_operator_precedence_or_and(self): + sql = f'SELECT column1 OR column2 AND column3' + ast = parse_sql(sql) + + expected_ast = Select( + targets=[BinaryOperation(op='or', + args=(Identifier.from_path_str('column1'), + BinaryOperation(op='and', + args=( + Identifier.from_path_str('column2'), Identifier.from_path_str('column3') + )) + + ) + ) + ] + ) + + assert ast == expected_ast + assert ast.to_tree() == expected_ast.to_tree() + + sql = f'SELECT column1 AND column2 OR column3' + ast = parse_sql(sql) + + expected_ast = Select( + targets=[BinaryOperation(op='or', + args=( + BinaryOperation(op='and', + args=( + Identifier.from_path_str('column1'), Identifier.from_path_str('column2') + )), + Identifier.from_path_str('column3'), + + ) + ) + ] + ) + + 
assert ast == expected_ast + assert ast.to_tree() == expected_ast.to_tree() + + def test_operator_precedence_or_and_parentheses(self): + sql = f'SELECT (column1 OR column2) AND column3' + ast = parse_sql(sql) + + expected_ast = Select( + targets=[BinaryOperation(op='and', + args=( + BinaryOperation(op='or', + args=( + Identifier.from_path_str('column1'), Identifier.from_path_str('column2') + ), + parentheses=True), + Identifier.from_path_str('column3'), + ), + ) + ] + ) + + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_where_and_or_precedence(self): + sql = "SELECT col1 FROM tab WHERE col1 AND col2 OR col3" + ast = parse_sql(sql) + + expected_ast = Select( + targets=[Identifier.from_path_str('col1')], + from_table=Identifier.from_path_str('tab'), + where=BinaryOperation(op='or', + args=( + BinaryOperation(op='and', + args=( + Identifier.from_path_str('col1'), + Identifier.from_path_str('col2'), + )), + Identifier.from_path_str('col3'), + + )) + ) + + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + sql = "SELECT col1 FROM tab WHERE col1 = 1 AND col2 = 1 OR col3 = 1" + ast = parse_sql(sql) + + expected_ast = Select( + targets=[Identifier.from_path_str('col1')], + from_table=Identifier.from_path_str('tab'), + where=BinaryOperation(op='or', + args=( + BinaryOperation(op='and', + args=( + BinaryOperation(op='=', + args=( + Identifier.from_path_str('col1'), + Constant(1), + )), + BinaryOperation(op='=', + args=( + Identifier.from_path_str('col2'), + Constant(1), + )), + )), + BinaryOperation(op='=', + args=( + Identifier.from_path_str('col3'), + Constant(1), + )), + + )) + ) + + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_select_unary_operations(self): + for op in ['-', 'not']: + sql = f'SELECT {op} column FROM tab' + ast = parse_sql(sql) + + assert isinstance(ast, Select) + assert len(ast.targets) == 1 + assert isinstance(ast.targets[0], UnaryOperation) + assert ast.targets[0].op == op + assert len(ast.targets[0].args) == 1 + assert isinstance(ast.targets[0].args[0], Identifier) + assert ast.targets[0].args[0].parts == ['column'] + + assert str(ast).lower() == sql.lower() + + def test_select_function_one_arg(self): + funcs = ['sum', 'min', 'max', 'some_custom_function'] + for func in funcs: + sql = f'SELECT {func}(column) FROM tab' + ast = parse_sql(sql) + + expected_ast = Select( + targets=[Function(op=func, args=(Identifier.from_path_str('column'),))], + from_table=Identifier.from_path_str('tab'), + ) + + assert str(ast).lower() == sql.lower() + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_select_function_two_args(self): + funcs = ['sum', 'min', 'max', 'some_custom_function'] + for func in funcs: + sql = f'SELECT {func}(column1, column2) FROM tab' + ast = parse_sql(sql) + + expected_ast = Select( + targets=[Function(op=func, args=(Identifier.from_path_str('column1'),Identifier.from_path_str('column2')))], + from_table=Identifier.from_path_str('tab'), + ) + + assert str(ast).lower() == sql.lower() + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_select_in_operation(self): + sql = """SELECT * FROM t1 WHERE col1 IN ("a", "b")""" + + ast = parse_sql(sql) + + assert isinstance(ast, Select) + assert ast.where + + expected_where = BinaryOperation(op='IN', + args=[ + Identifier.from_path_str('col1'), + Tuple(items=[Constant('a'), Constant("b")]), + 
])
+
+        assert ast.where.to_tree() == expected_where.to_tree()
+        assert ast.where == expected_where
+
+    def test_unary_is_special_values(self):
+        args = [('NULL', NullConstant()), ('TRUE', Constant(value=True)), ('FALSE', Constant(value=False))]
+        for sql_arg, python_obj in args:
+            sql = f"""SELECT column1 IS {sql_arg}"""
+            ast = parse_sql(sql)
+
+            expected_ast = Select(targets=[BinaryOperation(op='IS', args=(Identifier.from_path_str("column1"), python_obj))], )
+
+            assert str(ast).lower() == sql.lower()
+            assert ast.to_tree() == expected_ast.to_tree()
+
+    def test_unary_is_not_special_values(self):
+        args = [('NULL', NullConstant()), ('TRUE', Constant(value=True)), ('FALSE', Constant(value=False))]
+        for sql_arg, python_obj in args:
+            sql = f"""SELECT column1 IS NOT {sql_arg}"""
+            ast = parse_sql(sql)
+
+            expected_ast = Select(targets=[BinaryOperation(op='IS NOT', args=(Identifier.from_path_str("column1"), python_obj))], )
+
+            assert str(ast).lower() == sql.lower()
+            assert ast.to_tree() == expected_ast.to_tree()
+            assert str(ast) == str(expected_ast)
+
+    def test_not_in(self):
+        sql = f"""SELECT column1 NOT IN column2"""
+        ast = parse_sql(sql)
+
+        expected_ast = Select(targets=[BinaryOperation(op='not in', args=(Identifier.from_path_str("column1"), Identifier.from_path_str("column2")))], )
+
+        assert ast.to_tree() == expected_ast.to_tree()
+        assert str(ast) == str(expected_ast)
+
+    def test_is_null(self):
+        sql = "SELECT col1 FROM t1 WHERE col1 IS NULL"
+        ast = parse_sql(sql)
+
+        expected_ast = Select(targets=[Identifier.from_path_str("col1")], from_table=Identifier.from_path_str('t1'),
+                              where=BinaryOperation('is', args=(Identifier.from_path_str('col1'), NullConstant())))
+
+        assert str(ast).lower() == sql.lower()
+        assert ast.to_tree() == expected_ast.to_tree()
+        assert str(ast) == str(expected_ast)
+
+    def test_is_not_null(self):
+        sql = "SELECT col1 FROM t1 WHERE col1 IS NOT NULL"
+        ast = parse_sql(sql)
+
+        expected_ast = Select(targets=[Identifier.from_path_str("col1")], from_table=Identifier.from_path_str('t1'),
+                              where=BinaryOperation('IS NOT', args=(Identifier.from_path_str('col1'), NullConstant())))
+
+        assert str(ast).lower() == sql.lower()
+        assert ast.to_tree() == expected_ast.to_tree()
+        assert str(ast) == str(expected_ast)
+
+    def test_is_true(self):
+        sql = "SELECT col1 FROM t1 WHERE col1 IS TRUE"
+        ast = parse_sql(sql)
+
+        expected_ast = Select(targets=[Identifier.from_path_str("col1")], from_table=Identifier.from_path_str('t1'),
+                              where=BinaryOperation('is', args=(Identifier.from_path_str('col1'), Constant(True))))
+
+        assert str(ast).lower() == sql.lower()
+        assert ast.to_tree() == expected_ast.to_tree()
+        assert str(ast) == str(expected_ast)
+
+    def test_is_false(self):
+        sql = "SELECT col1 FROM t1 WHERE col1 IS FALSE"
+        ast = parse_sql(sql)
+
+        expected_ast = Select(targets=[Identifier.from_path_str("col1")], from_table=Identifier.from_path_str('t1'),
+                              where=BinaryOperation('is', args=(Identifier.from_path_str('col1'), Constant(False))))
+        assert str(ast).lower() == sql.lower()
+        assert ast.to_tree() == expected_ast.to_tree()
+        assert str(ast) == str(expected_ast)
+
+    def test_between(self):
+        sql = "SELECT col1 FROM t1 WHERE col1 BETWEEN a AND b"
+        ast = parse_sql(sql)
+
+        expected_ast = Select(targets=[Identifier.from_path_str("col1")], from_table=Identifier.from_path_str('t1'),
+                              where=BetweenOperation(args=(Identifier.from_path_str('col1'), Identifier.from_path_str('a'), Identifier.from_path_str('b'))))
+
+        assert str(ast).lower() == sql.lower()
+        assert ast.to_tree() == expected_ast.to_tree()
+        assert str(ast) == str(expected_ast)
+
+    def test_between_with_and(self):
+        sql = "SELECT col1 FROM t1 WHERE col2 > 1 AND col1 BETWEEN a AND b"
+        ast = parse_sql(sql)
+
+        expected_ast = Select(targets=[Identifier.from_path_str("col1")],
+                              from_table=Identifier.from_path_str('t1'),
+                              where=BinaryOperation('and', args=[
+                                  BinaryOperation('>', args=[
+                                      Identifier('col2'),
+                                      Constant(1),
+                                  ]),
+                                  BetweenOperation(args=(
+                                      Identifier.from_path_str('col1'), Identifier.from_path_str('a'),
+                                      Identifier.from_path_str('b'))),
+                              ])
+        )
+
+        assert ast.to_tree() == expected_ast.to_tree()
+        assert str(ast) == str(expected_ast)
+
+
+    def test_select_status(self):
+        sql = 'select status from mindsdb.predictors'
+        ast = parse_sql(sql)
+        expected_ast = Select(targets=[Identifier.from_path_str("status")],
+                              from_table=Identifier.from_path_str('mindsdb.predictors')
+        )
+        assert ast.to_tree() == expected_ast.to_tree()
+        # assert str(ast).lower() == sql.lower()
+        assert str(ast) == str(expected_ast)
+
+    def test_select_from_engines(self):
+        sql = 'select * from engines'
+        ast = parse_sql(sql)
+        expected_ast = Select(targets=[Star()],
+                              from_table=Identifier.from_path_str('engines')
+        )
+        assert ast.to_tree() == expected_ast.to_tree()
+        assert str(ast) == str(expected_ast)
+
+    def test_select_from_view_kw(self):
+        for table in ['view.t', 'views.t']:
+            sql = f'select * from {table}'
+
+            ast = parse_sql(sql)
+            expected_ast = Select(targets=[Star()],
+                                  from_table=Identifier.from_path_str(table)
+            )
+            assert ast.to_tree() == expected_ast.to_tree()
+            assert str(ast) == str(expected_ast)
+
+    def test_complex_precedence(self):
+        sql = '''
+            SELECT * from tb
+            WHERE
+               not a=2+1
+               and
+               b=c
+               or
+               d between e and f
+               and
+               g
+        '''
+        ast = parse_sql(sql)
+        expected_ast = Select(
+            targets=[Star()],
+            from_table=Identifier.from_path_str('tb'),
+            where=BinaryOperation(op='or', args=(
+                BinaryOperation(op='and', args=(
+                    UnaryOperation(op='not', args=[
+                        BinaryOperation(op='=', args=(
+                            Identifier(parts=['a']),
+                            BinaryOperation(op='+', args=(
+                                Constant(value=2),
+                                Constant(value=1)
+                            ))
+                        ))
+                    ]),
+                    BinaryOperation(op='=', args=(
+                        Identifier(parts=['b']),
+                        Identifier(parts=['c'])
+                    ))
+                )),
+                BinaryOperation(op='and', args=(
+                    BetweenOperation(args=(
+                        Identifier(parts=['d']),
+                        Identifier(parts=['e']),
+                        Identifier(parts=['f'])
+                    )),
+                    Identifier(parts=['g'])
+                ))
+            )),
+        )
+
+        assert ast.to_tree() == expected_ast.to_tree()
+        assert str(ast) == str(expected_ast)
+
+    def test_select_databases(self):
+        sql = f'SELECT name FROM information_schema.databases'
+        ast = parse_sql(sql)
+
+        expected_ast = Select(
+            targets=[Identifier('name')],
+            from_table=Identifier('information_schema.databases'),
+        )
+
+        assert str(ast) == str(expected_ast)
+        assert ast.to_tree() == expected_ast.to_tree()
+
+
+# these queries don't work in sqlite
+
+class TestOperationsNoSqlite:
+
+    def test_select_function_no_args(self):
+        sql = f'SELECT database() FROM tab'
+        ast = parse_sql(sql)
+
+        expected_ast = Select(
+            targets=[Function(op='database', args=[])],
+            from_table=Identifier.from_path_str('tab'),
+        )
+
+        assert str(ast).lower() == sql.lower()
+        assert str(ast) == str(expected_ast)
+        assert ast.to_tree() == expected_ast.to_tree()
+
+    def test_select_functions(self):
+        sqls = [
+            "SELECT connection_id()",
+            "SELECT database()",
+            "SELECT current_user()",
+            "SELECT version()",
+            "SELECT user()",
+        ]
+
+        for sql in sqls:
+            ast = parse_sql(sql)
+            assert isinstance(ast, Select)
+            assert len(ast.targets) == 1
+            assert isinstance(ast.targets[0], Function)
+
+    def test_select_dquote_alias(self):
+        sql = """
+            select
+                a as "database"
+            from information_schema.tables "database"
+        """
+        ast = parse_sql(sql)
+
+        expected_ast = Select(
+            targets=[Identifier('a', alias=Identifier('database'))],
+            from_table=Identifier(parts=['information_schema', 'tables'], alias=Identifier('database')),
+        )
+
+        assert str(ast) == str(expected_ast)
+        assert ast.to_tree() == expected_ast.to_tree()
+
+    def test_function_with_from(self):
+
+        sql = 'SELECT extract(MONTH FROM dateordered)'
+        ast = parse_sql(sql)
+
+        expected_ast = Select(
+            targets=[Function(
+                op='extract',
+                args=[Identifier('MONTH')],
+                from_arg=Identifier('dateordered')
+            )],
+        )
+
+        assert str(ast).lower() == sql.lower()
+        assert str(ast) == str(expected_ast)
+        assert ast.to_tree() == expected_ast.to_tree()
+
+class TestOperationsMindsdb:
+    def test_select_binary_operations(self):
+        for op in ['not like',]:
+            sql = f'SELECT column1 {op.upper()} column2 FROM tab'
+            ast = parse_sql(sql)
+
+            expected_ast = Select(
+                targets=[BinaryOperation(op=op,
+                                         args=(
+                                             Identifier.from_path_str('column1'), Identifier.from_path_str('column2')
+                                         )),
+                         ],
+                from_table=Identifier.from_path_str('tab')
+            )
+
+            assert str(ast).lower() == sql.lower()
+            assert str(ast) == str(expected_ast)
+            assert ast.to_tree() == expected_ast.to_tree()
+
+    def test_function_with_namespace(self):
+
+        sql = 'SELECT engine.myfunc(1, 2)'
+        ast = parse_sql(sql)
+
+        expected_ast = Select(
+            targets=[Function(
+                op='myfunc',
+                args=[Constant(1), Constant(2)],
+                namespace='engine'
+            )],
+        )
+
+        assert str(ast).lower() == sql.lower()
+        assert str(ast) == str(expected_ast)
+        assert ast.to_tree() == expected_ast.to_tree()
+
+    def test_exists(self):
+
+        for exist in [True, False]:
+            prefix = ''
+            cls = Exists
+            if not exist:
+                prefix = 'not'
+                cls = NotExists
+
+            sql = f'''
+                select * from db.orders
+                where
+                    orderdate < '1993-05-01'
+                    and {prefix} exists (
+                        select * from db.item
+                        where l_orderkey = o_orderkey
+                    )
+                group by max(orderpriority)
+            '''
+            ast = parse_sql(sql)
+
+            expected_ast = Select(
+                targets=[Star()],
+                from_table=Identifier('db.orders'),
+                where=BinaryOperation(op='and', args=[
+                    BinaryOperation(op='<', args=[Identifier('orderdate'), Constant('1993-05-01')]),
+                    cls(Select(
+                        targets=[Star()],
+                        from_table=Identifier('db.item'),
+                        where=BinaryOperation(op='=', args=[
+                            Identifier('l_orderkey'), Identifier('o_orderkey')
+                        ])
+                    ))
+                ]),
+                group_by=[Function(op='max', args=[Identifier('orderpriority')])],
+            )
+
+            assert str(ast) == str(expected_ast)
diff --git a/tests/test_base_sql/test_select_structure.py b/tests/test_base_sql/test_select_structure.py
new file mode 100644
index 0000000..810dff6
--- /dev/null
+++ b/tests/test_base_sql/test_select_structure.py
@@ -0,0 +1,1203 @@
+import itertools
+import pytest
+from mindsdb_sql_parser import parse_sql
+from mindsdb_sql_parser.ast import *
+from mindsdb_sql_parser.exceptions import ParsingException
+from mindsdb_sql_parser.utils import JoinType
+
+
+class TestSelectStructure:
+    def test_no_select(self):
+        query = ""
+        with pytest.raises(ParsingException):
+            parse_sql(query)
+
+    def test_select_number(self):
+        for value in [1, 1.0]:
+            sql = f'SELECT {value}'
+            ast = parse_sql(sql)
+
+            assert isinstance(ast, Select)
+            assert len(ast.targets) == 1
+            assert
isinstance(ast.targets[0], Constant) + assert ast.targets[0].value == value + assert str(ast).lower() == sql.lower() + + def test_select_string(self): + sql = f"SELECT 'string'" + ast = parse_sql(sql) + + assert isinstance(ast, Select) + assert len(ast.targets) == 1 + assert isinstance(ast.targets[0], Constant) + assert ast.targets[0].value == 'string' + assert str(ast) == sql + + def test_select_identifier(self): + sql = f'SELECT column' + ast = parse_sql(sql) + + assert isinstance(ast, Select) + assert len(ast.targets) == 1 + assert isinstance(ast.targets[0], Identifier) + assert str(ast.targets[0]) == 'column' + assert str(ast).lower() == sql.lower() + + def test_select_identifier_with_dashes(self): + sql = f'SELECT `column-with-dashes`' + ast = parse_sql(sql) + + assert isinstance(ast, Select) + assert len(ast.targets) == 1 + assert isinstance(ast.targets[0], Identifier) + assert ast.targets[0].parts == ['column-with-dashes'] + assert str(ast.targets[0]) == '`column-with-dashes`' + assert str(ast).lower() == sql.lower() + + def test_select_identifier_alias(self): + sql_queries = ['SELECT column AS column_alias', + 'SELECT column column_alias'] + for sql in sql_queries: + ast = parse_sql(sql) + + assert isinstance(ast, Select) + assert len(ast.targets) == 1 + assert isinstance(ast.targets[0], Identifier) + assert ast.targets[0].parts == ['column'] + assert ast.targets[0].alias.parts[0] == 'column_alias' + assert str(ast).lower().replace('as ', '') == sql.lower().replace('as ', '') + + + + def test_select_identifier_alias_complex(self): + sql = f'SELECT column AS `column alias spaces`' + ast = parse_sql(sql) + + assert isinstance(ast, Select) + assert len(ast.targets) == 1 + assert isinstance(ast.targets[0], Identifier) + assert ast.targets[0].parts == ['column'] + assert ast.targets[0].alias.parts[0] == 'column alias spaces' + assert str(ast).lower() == sql.lower() + + def test_select_multiple_identifiers(self): + sql = f'SELECT column1, column2' + ast = parse_sql(sql) + + assert isinstance(ast, Select) + assert len(ast.targets) == 2 + assert isinstance(ast.targets[0], Identifier) + assert ast.targets[0].parts[0] == 'column1' + assert isinstance(ast.targets[1], Identifier) + assert ast.targets[1].parts[0] == 'column2' + assert str(ast).lower() == sql.lower() + + def test_select_from_table(self): + sql = f'SELECT column FROM tab' + ast = parse_sql(sql) + + assert isinstance(ast, Select) + assert len(ast.targets) == 1 + assert isinstance(ast.targets[0], Identifier) + assert ast.targets[0].parts[0] == 'column' + + assert isinstance(ast.from_table, Identifier) + assert ast.from_table.parts[0] == 'tab' + + assert str(ast).lower() == sql.lower() + + def test_select_from_table_long(self): + query = "SELECT 1 FROM integration.database.schema.tab" + expected_ast = Select( + targets=[Constant(1)], + from_table=Identifier(parts=['integration', 'database', 'schema', 'tab']) + ) + ast = parse_sql(query) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_select_distinct(self): + sql = """SELECT DISTINCT column1 FROM t1""" + assert str(parse_sql(sql)) == sql + assert parse_sql(sql).distinct + + def test_select_multiple_from_table(self): + sql = f'SELECT column1, column2, 1 AS renamed_constant FROM tab' + ast = parse_sql(sql) + + assert isinstance(ast, Select) + assert len(ast.targets) == 3 + assert isinstance(ast.targets[0], Identifier) + assert ast.targets[0].parts[0] == 'column1' + assert ast.targets[1].parts[0] == 'column2' + assert 
ast.targets[2].value == 1 + assert ast.targets[2].alias.parts[0] == 'renamed_constant' + + assert isinstance(ast.from_table, Identifier) + assert ast.from_table.parts[0] == 'tab' + + assert str(ast).lower() == sql.lower() + + def test_select_from_elaborate(self): + query = """SELECT *, column1, column1 AS aliased, column1 + column2 FROM t1""" + + assert str(parse_sql(query)) == query + assert str(parse_sql(query)) == str(Select(targets=[Star(), + Identifier(parts=["column1"]), + Identifier(parts=["column1"], alias=Identifier('aliased')), + BinaryOperation(op="+", + args=(Identifier(parts=['column1']), + Identifier(parts=['column2'])) + ) + ], + from_table=Identifier(parts=['t1']))) + + def test_select_from_aliased(self): + sql_queries = ["SELECT * FROM t1 AS table_alias", "SELECT * FROM t1 table_alias"] + expected_ast = Select(targets=[Star()], + from_table=Identifier(parts=['t1'], alias=Identifier('table_alias'))) + for query in sql_queries: + assert parse_sql(query) == expected_ast + + def test_from_table_raises_duplicate(self): + sql = f'SELECT column FROM tab FROM tab' + with pytest.raises(ParsingException): + ast = parse_sql(sql) + + def test_select_where(self): + sql = f'SELECT column FROM tab WHERE column != 1' + ast = parse_sql(sql) + assert isinstance(ast, Select) + assert len(ast.targets) == 1 + assert isinstance(ast.targets[0], Identifier) + assert ast.targets[0].parts[0] == 'column' + + assert isinstance(ast.from_table, Identifier) + assert ast.from_table.parts[0] == 'tab' + + assert isinstance(ast.where, BinaryOperation) + assert ast.where.op == '!=' + + assert str(ast).lower() == sql.lower() + + def test_select_where_constants(self): + sql = f'SELECT column FROM pred WHERE 1 = 0' + ast = parse_sql(sql) + expected_ast = Select(targets=[Identifier('column')], + from_table=Identifier('pred'), + where=BinaryOperation(op="=", + args=[Constant(1), Constant(0)])) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_select_from_where_elaborate(self): + query = """SELECT column1, column2 FROM t1 WHERE column1 = 1""" + + assert str(parse_sql(query)) == query + + assert str(parse_sql(query)) == str(Select(targets=[Identifier(parts=["column1"]), Identifier(parts=["column2"])], + from_table=Identifier(parts=['t1']), + where=BinaryOperation(op="=", + args=(Identifier(parts=['column1']), Constant(1)) + ))) + + query = """SELECT column1, column2 FROM t1 WHERE column1 = \'1\'""" + + assert str(parse_sql(query)) == query + + assert str(parse_sql(query)) == str(Select(targets=[Identifier(parts=["column1"]), Identifier(parts=["column2"])], + from_table=Identifier(parts=['t1']), + where=BinaryOperation(op="=", + args=(Identifier(parts=['column1']), Constant("1")) + ))) + + def test_select_from_where_elaborate_lowercase(self): + sql = """select column1, column2 from t1 where column1 = 1""" + assert str(parse_sql(sql)) == str(Select(targets=[Identifier(parts=["column1"]), Identifier(parts=["column2"])], + from_table=Identifier(parts=['t1']), + where=BinaryOperation(op="=", + args=(Identifier(parts=['column1']), Constant(1)) + ))) + + + def test_where_raises_nofrom(self): + sql = f'SELECT column WHERE column != 1' + with pytest.raises(ParsingException): + ast = parse_sql(sql) + + def test_where_raises_duplicate(self): + sql = f'SELECT column FROM tab WHERE column != 1 WHERE column > 1' + with pytest.raises(ParsingException): + ast = parse_sql(sql) + + def test_where_raises_as(self): + sql = f'SELECT column FROM tab WHERE column != 1 AS somealias' + 
with pytest.raises(ParsingException): + ast = parse_sql(sql) + + def test_select_where_and(self): + sql = f'SELECT column FROM tab WHERE column != 1 and column > 10' + ast = parse_sql(sql) + + assert isinstance(ast, Select) + assert len(ast.targets) == 1 + assert isinstance(ast.targets[0], Identifier) + assert ast.targets[0].parts[0] == 'column' + + assert isinstance(ast.from_table, Identifier) + assert ast.from_table.parts[0] == 'tab' + + assert isinstance(ast.where, BinaryOperation) + assert ast.where.op == 'and' + + assert isinstance(ast.where.args[0], BinaryOperation) + assert ast.where.args[0].op == '!=' + assert isinstance(ast.where.args[1], BinaryOperation) + assert ast.where.args[1].op == '>' + + def test_select_where_must_be_an_op(self): + sql = f'SELECT column FROM tab WHERE column' + + with pytest.raises(ParsingException) as excinfo: + ast = parse_sql(sql) + + assert "WHERE must contain an operation that evaluates to a boolean" in str(excinfo.value) + + def test_select_group_by(self): + sql = f'SELECT column FROM tab WHERE column != 1 GROUP BY column1' + ast = parse_sql(sql) + assert str(ast).lower() == sql.lower() + + sql = f'SELECT column FROM tab WHERE column != 1 GROUP BY column1, column2' + ast = parse_sql(sql) + + assert isinstance(ast, Select) + assert len(ast.targets) == 1 + assert isinstance(ast.targets[0], Identifier) + assert ast.targets[0].parts[0] == 'column' + + assert isinstance(ast.from_table, Identifier) + assert ast.from_table.parts[0] == 'tab' + + assert isinstance(ast.where, BinaryOperation) + assert ast.where.op == '!=' + + assert isinstance(ast.group_by, list) + assert isinstance(ast.group_by[0], Identifier) + assert ast.group_by[0].parts[0] == 'column1' + assert isinstance(ast.group_by[1], Identifier) + assert ast.group_by[1].parts[0] == 'column2' + + assert str(ast).lower() == sql.lower() + + def test_select_group_by_elaborate(self): + query = """SELECT column1, column2, sum(column3) AS total FROM t1 GROUP BY column1, column2""" + + assert str(parse_sql(query)) == query + + assert str(parse_sql(query)) == str(Select(targets=[Identifier(parts=["column1"]), + Identifier(parts=["column2"]), + Function(op="sum", + args=[Identifier(parts=["column3"])], + alias=Identifier('total'))], + from_table=Identifier(parts=['t1']), + group_by=[Identifier(parts=["column1"]), Identifier(parts=["column2"])])) + + def test_group_by_raises_duplicate(self): + sql = f'SELECT column FROM tab GROUP BY col GROUP BY col' + with pytest.raises(ParsingException): + ast = parse_sql(sql) + + def test_select_having(self): + sql = f'SELECT column FROM tab WHERE column != 1 GROUP BY column1' + ast = parse_sql(sql) + assert str(ast).lower() == sql.lower() + + sql = f'SELECT column FROM tab WHERE column != 1 GROUP BY column1, column2 HAVING column1 > 10' + ast = parse_sql(sql) + + assert isinstance(ast, Select) + + assert isinstance(ast.having, BinaryOperation) + assert isinstance(ast.having.args[0], Identifier) + assert ast.having.args[0].parts[0] == 'column1' + assert ast.having.args[1].value == 10 + + assert str(ast).lower() == sql.lower() + + def test_select_group_by_having_elaborate(self): + sql = """SELECT column1 FROM t1 GROUP BY column1 HAVING column1 != 1""" + assert str(parse_sql(sql)) == sql + + def test_select_order_by_elaborate(self): + sql = """SELECT * FROM t1 ORDER BY column1 ASC, column2, column3 DESC NULLS FIRST""" + ast = parse_sql(sql) + expected_ast = Select(targets=[Star()], + from_table=Identifier(parts=['t1']), + order_by=[ + OrderBy(Identifier(parts=['column1']), 
direction='ASC'), + OrderBy(Identifier(parts=['column2'])), + OrderBy(Identifier(parts=['column3']), direction='DESC', + nulls='NULLS FIRST')], + ) + + assert str(ast).lower() == sql.lower() + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_select_aliases_order_by(self): + sql = "select max(name) as `max(name)` from tbl order by `max(name)`" + + ast = parse_sql(sql) + + expected_ast = Select(targets=[Function('max', args=[Identifier('name')], alias=Identifier('max(name)'))], + from_table=Identifier('tbl'), + order_by=[OrderBy(Identifier('max(name)'))]) + + assert ast.to_tree() == expected_ast.to_tree() + + def test_select_limit_offset_elaborate(self): + sql = """SELECT * FROM t1 LIMIT 1 OFFSET 2""" + ast = parse_sql(sql) + expected_ast = Select(targets=[Star()], + from_table=Identifier(parts=['t1']), + limit=Constant(1), + offset=Constant(2)) + + assert str(ast).lower() == sql.lower() + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_select_limit_two_arguments(self): + sql = """SELECT * FROM t1 LIMIT 2, 1""" + ast = parse_sql(sql) + expected_ast = Select(targets=[Star()], + from_table=Identifier(parts=['t1']), + limit=Constant(1), + offset=Constant(2)) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_select_limit_two_arguments_and_offset_error(self): + sql = """SELECT * FROM t1 LIMIT 2, 1 OFFSET 2""" + with pytest.raises(ParsingException): + parse_sql(sql) + + def test_having_raises_duplicate(self): + sql = f'SELECT column FROM tab GROUP BY col HAVING col > 1 HAVING col > 1' + with pytest.raises(ParsingException): + ast = parse_sql(sql) + + def test_select_order_by(self): + sql = f'SELECT column1 FROM tab ORDER BY column2' + ast = parse_sql(sql) + assert str(ast).lower() == sql.lower() + + assert len(ast.order_by) == 1 + assert isinstance(ast.order_by[0], OrderBy) + assert isinstance(ast.order_by[0].field, Identifier) + assert ast.order_by[0].field.parts[0] == 'column2' + assert ast.order_by[0].direction == 'default' + + sql = f'SELECT column1 FROM tab ORDER BY column2, column3 ASC, column4 DESC' + ast = parse_sql(sql) + assert str(ast).lower() == sql.lower() + + assert len(ast.order_by) == 3 + + assert isinstance(ast.order_by[0], OrderBy) + assert isinstance(ast.order_by[0].field, Identifier) + assert ast.order_by[0].field.parts[0] == 'column2' + assert ast.order_by[0].direction == 'default' + + assert isinstance(ast.order_by[1], OrderBy) + assert isinstance(ast.order_by[1].field, Identifier) + assert ast.order_by[1].field.parts[0] == 'column3' + assert ast.order_by[1].direction == 'ASC' + + assert isinstance(ast.order_by[2], OrderBy) + assert isinstance(ast.order_by[2].field, Identifier) + assert ast.order_by[2].field.parts[0] == 'column4' + assert ast.order_by[2].direction == 'DESC' + + def test_order_by_raises_duplicate(self): + sql = f'SELECT column FROM tab ORDER BY col1 ORDER BY col1' + with pytest.raises(ParsingException): + ast = parse_sql(sql) + + def test_select_limit_offset(self): + sql = f'SELECT column FROM tab LIMIT 5 OFFSET 3' + ast = parse_sql(sql) + assert str(ast).lower() == sql.lower() + + assert ast.limit == Constant(value=5) + assert ast.offset == Constant(value=3) + + def test_select_limit_offset_raises_nonint(self): + sql = f'SELECT column FROM tab OFFSET 3.0' + with pytest.raises(ParsingException): + ast = parse_sql(sql) + + sql = "SELECT column FROM tab LIMIT \'string\'" + with pytest.raises(ParsingException): 
+ ast = parse_sql(sql) + + def test_select_limit_offset_raises_wrong_order(self): + sql = f'SELECT column FROM tab OFFSET 3 LIMIT 5 ' + with pytest.raises(ParsingException): + ast = parse_sql(sql) + + def test_limit_raises_duplicate(self): + sql = f'SELECT column FROM tab LIMIT 1 LIMIT 1' + with pytest.raises(ParsingException): + ast = parse_sql(sql) + + def test_offset_raises_duplicate(self): + sql = f'SELECT column FROM tab OFFSET 1 OFFSET 1' + with pytest.raises(ParsingException): + ast = parse_sql(sql) + + def test_limit_raises_before_order_by(self): + sql = f'SELECT column FROM tab LIMIT 1 ORDER BY column ASC' + with pytest.raises(ParsingException): + ast = parse_sql(sql) + + def test_offset_raises_before_order_by(self): + sql = f'SELECT column FROM tab OFFSET 1 ORDER BY column ASC' + with pytest.raises(ParsingException): + ast = parse_sql(sql) + + def test_select_order(self): + components = ['FROM tab', + 'WHERE column = 1', + 'GROUP BY column', + 'HAVING column != 2', + 'ORDER BY column ASC', + 'LIMIT 1', + 'OFFSET 1'] + + good_sql = 'SELECT column ' + '\n'.join(components) + ast = parse_sql(good_sql) + assert ast + + for perm in itertools.permutations(components): + bad_sql = 'SELECT column ' + '\n'.join(perm) + if bad_sql == good_sql: + continue + + with pytest.raises(ParsingException) as excinfo: + ast = parse_sql(bad_sql) + + def test_select_from_inner_join(self): + sql = """SELECT * FROM t1 INNER JOIN t2 ON t1.x1 = t2.x2 and t1.x2 = t2.x2""" + + expected_ast = Select(targets=[Star()], + from_table=Join(join_type=JoinType.INNER_JOIN, + left=Identifier(parts=['t1']), + right=Identifier(parts=['t2']), + condition= + BinaryOperation(op='and', + args=[ + BinaryOperation(op='=', + args=( + Identifier( + parts=['t1','x1']), + Identifier( + parts=['t2','x2']))), + BinaryOperation(op='=', + args=( + Identifier( + parts=['t1','x2']), + Identifier( + parts=['t2','x2']))), + ]) + + )) + ast = parse_sql(sql) + + assert ast == expected_ast + + def test_select_from_implicit_join(self): + sql = """SELECT * FROM t1, t2""" + + expected_ast = Select(targets=[Star()], + from_table=Join(left=Identifier(parts=['t1']), + right=Identifier(parts=['t2']), + join_type=JoinType.INNER_JOIN, + implicit=True, + condition=None)) + ast = parse_sql(sql) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_select_from_different_join_types(self): + join_types = ['JOIN', 'INNER JOIN', 'LEFT JOIN', 'RIGHT JOIN', 'FULL JOIN'] + for join in join_types: + sql = f"""SELECT * FROM t1 {join} t2 ON t1.x1 = t2.x2""" + expected_ast = Select(targets=[Star()], + from_table=Join(join_type=join, + left=Identifier(parts=['t1']), + right=Identifier(parts=['t2']), + condition= + BinaryOperation(op='=', + args=( + Identifier( + parts=['t1','x1']), + Identifier( + parts=['t2','x2']))), + + )) + + ast = parse_sql(sql) + assert ast == expected_ast + + def test_select_from_subquery(self): + sql = f"""SELECT * FROM (SELECT column1 FROM t1) AS sub""" + expected_ast = Select(targets=[Star()], + from_table=Select(targets=[Identifier(parts=['column1'])], + from_table=Identifier(parts=['t1']), + alias=Identifier('sub'), + parentheses=True)) + ast = parse_sql(sql) + assert str(ast).lower() == sql.lower() + assert ast.to_tree() == expected_ast.to_tree() + assert ast == expected_ast + + sql = f"""SELECT * FROM (SELECT column1 FROM t1)""" + expected_ast = Select(targets=[Star()], + from_table=Select(targets=[Identifier(parts=['column1'])], + from_table=Identifier(parts=['t1']), + parentheses=True)) 
+ ast = parse_sql(sql) + assert str(ast).lower() == sql.lower() + assert ast == expected_ast + + def test_select_subquery_target(self): + sql = f"""SELECT *, (SELECT 1) FROM t1""" + ast = parse_sql(sql) + expected_ast = Select(targets=[Star(), Select(targets=[Constant(1)], parentheses=True)], + from_table=Identifier(parts=['t1'])) + assert str(ast).lower() == sql.lower() + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + sql = f"""SELECT *, (SELECT 1) AS ones FROM t1""" + ast = parse_sql(sql) + expected_ast = Select(targets=[Star(), Select(targets=[Constant(1)], alias=Identifier('ones'), parentheses=True)], + from_table=Identifier(parts=['t1'])) + assert str(ast).lower() == sql.lower() + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_select_subquery_where(self): + sql = f"""SELECT * FROM tab1 WHERE column1 in (SELECT column2 FROM t2)""" + ast = parse_sql(sql) + expected_ast = Select(targets=[Star()], + from_table=Identifier(parts=['tab1']), + where=BinaryOperation(op='in', + args=( + Identifier(parts=['column1']), + Select(targets=[Identifier(parts=['column2'])], + from_table=Identifier(parts=['t2']), + parentheses=True) + ))) + assert str(ast).lower() == sql.lower() + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_type_cast(self): + sql = f"""SELECT CAST(4 AS int64) AS result""" + ast = parse_sql(sql) + expected_ast = Select(targets=[TypeCast(type_name='int64', arg=Constant(4), alias=Identifier('result'))]) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + sql = f"""SELECT CAST(column1 AS float) AS result""" + ast = parse_sql(sql) + expected_ast = Select(targets=[TypeCast(type_name='float', arg=Identifier(parts=['column1']), alias=Identifier('result'))]) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + sql = f"""SELECT CAST((column1 + column2) AS float) AS result""" + ast = parse_sql(sql) + expected_ast = Select(targets=[TypeCast(type_name='float', arg=BinaryOperation(op='+', parentheses=True, args=[ + Identifier(parts=['column1']), Identifier(parts=['column2'])]), alias=Identifier('result'))]) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + sql = f"""SELECT CAST(a AS CHAR(10))""" + ast = parse_sql(sql) + expected_ast = Select(targets=[ + TypeCast(type_name='CHAR', arg=Identifier('a'), precision=[10]) + ]) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + sql = f"""SELECT CAST(a AS DECIMAL(10, 1))""" + ast = parse_sql(sql) + expected_ast = Select(targets=[ + TypeCast(type_name='DECIMAL', arg=Identifier('a'), precision=[10, 1]) + ]) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_in_tuple(self): + sql = "SELECT col FROM tab WHERE col in (1, 2)" + ast = parse_sql(sql) + expected_ast = Select(targets=[Identifier(parts=['col'])], + from_table=Identifier(parts=['tab']), + where=BinaryOperation(op='in', + args=( + Identifier(parts=['col']), + Tuple(items=[Constant(1), Constant(2)]) + ))) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_count_distinct(self): + sql = "SELECT COUNT(DISTINCT survived) AS uniq_survived FROM titanic" + ast = parse_sql(sql) + + expected_ast = Select( + targets=[Function(op='COUNT', distinct=True, + args=(Identifier(parts=['survived']),), 
alias=Identifier('uniq_survived'))], + from_table=Identifier(parts=['titanic']) + ) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_where_not_order(self): + sql = "SELECT col1 FROM tab WHERE NOT col1 = \'FAMILY\'" + ast = parse_sql(sql) + + expected_ast = Select(targets=[Identifier(parts=['col1'])], + from_table=Identifier(parts=['tab']), + where=UnaryOperation(op='NOT', + args=( + BinaryOperation(op='=', + args=(Identifier(parts=['col1']), Constant('FAMILY'))), + ) + ) + ) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_backticks(self): + sql = "SELECT `name`, `status` FROM `mindsdb`.`wow stuff predictors`.`even-dashes-work`.`nice`" + ast = parse_sql(sql) + + expected_ast = Select(targets=[Identifier(parts=['name']), Identifier(parts=['status'])], + from_table=Identifier(parts=['mindsdb', 'wow stuff predictors', 'even-dashes-work', 'nice']), + ) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_partial_backticks(self): + sql = "SELECT `integration`.`some table`.column" + ast = parse_sql(sql) + + expected_ast = Select(targets=[Identifier(parts=['integration', 'some table', 'column']),],) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_backticks_in_str(self): + sql = "SELECT `my column name` FROM tab WHERE `other column name` = 'bla bla ``` bla'" + ast = parse_sql(sql) + + expected_ast = Select(targets=[Identifier(parts=['my column name'])], + from_table=Identifier(parts=['tab']), + where=BinaryOperation(op='=', args=( + Identifier(parts=['other column name']), + Constant('bla bla ``` bla') + ) + )) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_select_parameter(self): + sql = "SELECT ? = ? FROM ?" 
+ ast = parse_sql(sql) + + expected_ast = Select(targets=[BinaryOperation(op='=', args=(Parameter('?'), Parameter('?')))], + from_table=Parameter('?'), + ) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_select_from_tables(self): + sql = "SELECT * FROM tables" + ast = parse_sql(sql) + + expected_ast = Select(targets=[Star()], + from_table=Identifier('tables')) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_tricky_tables_case(self): + sql = "SELECT TABLES.table_name AS Tables_in_mindsdb FROM TABLES WHERE TABLES.table_schema = 'MINDSDB' AND TABLES.table_type = 'BASE TABLE'" + ast = parse_sql(sql) + + expected_ast = Select(targets=[Identifier('TABLES.table_name', alias=Identifier('Tables_in_mindsdb'))], + from_table=Identifier('TABLES'), + where=BinaryOperation('and', args=[ + BinaryOperation('=', args=[Identifier('TABLES.table_schema'), Constant('MINDSDB')]), + BinaryOperation('=', args=[Identifier('TABLES.table_type'), Constant('BASE TABLE')]), + ])) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_double_aliased_table(self): + sql = "select * from table1 zzzzz alias1" + + with pytest.raises(ParsingException): + parse_sql(sql) + + def test_window_function(self): + query = "select SUM(col0) OVER (PARTITION BY col1 order by col2) as al from table1 " + expected_ast = Select( + targets=[ + WindowFunction( + function=Function(op='sum', args=[Identifier('col0')]), + partition=[Identifier('col1')], + order_by=[OrderBy(field=Identifier('col2'))], + alias=Identifier('al') + ) + ], + from_table=Identifier('table1') + ) + ast = parse_sql(query) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + # no partition + query = "select SUM(col0) OVER (order by col2) from table1 " + expected_ast = Select( + targets=[ + WindowFunction( + function=Function(op='sum', args=[Identifier('col0')]), + order_by=[OrderBy(field=Identifier('col2'))], + ) + ], + from_table=Identifier('table1') + ) + ast = parse_sql(query) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + # no order by + query = "select SUM(col0) OVER (PARTITION BY col1) from table1 " + expected_ast = Select( + targets=[ + WindowFunction( + function=Function(op='sum', args=[Identifier('col0')]), + partition=[Identifier('col1')], + ) + ], + from_table=Identifier('table1') + ) + ast = parse_sql(query) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + # just over() + query = "select SUM(col0) OVER () from table1 " + expected_ast = Select( + targets=[ + WindowFunction( + function=Function(op='sum', args=[Identifier('col0')]), + ) + ], + from_table=Identifier('table1') + ) + ast = parse_sql(query) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_is_not_precedence(self): + query = "select * from t1 where a is not null and b = c" + expected_ast = Select( + targets=[Star()], + from_table=Identifier('t1'), + where=BinaryOperation(op='and', args=[ + BinaryOperation(op='is not', args=[ + Identifier('a'), + NullConstant() + ]), + BinaryOperation(op='=', args=[ + Identifier('b'), + Identifier('c') + ]), + ]) + ) + ast = parse_sql(query) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + +class TestSelectStructureNoSqlite: + def test_select_from_plugins(self): + query = "select * from 
information_schema.plugins" + expected_ast = Select( + targets=[Star()], + from_table=Identifier(parts=['information_schema', 'plugins']) + ) + ast = parse_sql(query) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + query = "select * from plugins" + expected_ast = Select( + targets=[Star()], + from_table=Identifier(parts=['plugins']) + ) + ast = parse_sql(query) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + + def test_type_convert(self): + sql = f"""SELECT CONVERT(column1, float)""" + ast = parse_sql(sql) + expected_ast = Select(targets=[TypeCast(type_name='float', arg=Identifier(parts=['column1']))]) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + sql = f"""SELECT CONVERT((column1 + column2) USING float)""" + ast = parse_sql(sql) + expected_ast = Select(targets=[TypeCast(type_name='float', arg=BinaryOperation(op='+', parentheses=True, args=[ + Identifier(parts=['column1']), Identifier(parts=['column2'])]))]) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_select_for_update(self): + sql = f'SELECT * FROM tab for update' + ast = parse_sql(sql) + + expected_ast = Select( + targets=[Star()], + from_table=Identifier(parts=['tab']), + mode='FOR UPDATE' + ) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_keywords(self): + sql = f'SELECT COLLATION_NAME AS Collation, CHARACTER_SET_NAME AS Charset,\ + ID AS Id, IS_COMPILED AS Compiled, PLUGINS, MASTER, STATUS, ONLY\ + FROM INFORMATION_SCHEMA.COLLATIONS' + ast = parse_sql(sql) + + expected_ast = Select( + targets=[ + Identifier('COLLATION_NAME', alias=Identifier('Collation')), + Identifier('CHARACTER_SET_NAME', alias=Identifier('Charset')), + Identifier('ID', alias=Identifier('Id')), + Identifier('IS_COMPILED', alias=Identifier('Compiled')), + Identifier('PLUGINS'), + Identifier('MASTER'), + Identifier('STATUS'), + Identifier('ONLY'), + ], + from_table=Identifier('INFORMATION_SCHEMA.COLLATIONS') + ) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_table_star(self): + sql = f'select *, t.* From table1 ' + ast = parse_sql(sql) + + expected_ast = Select( + targets=[ + Star(), + Identifier(parts=['t', Star()]), + ], + from_table=Identifier('table1') + ) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_select_from_select(self): + sql = f'select * from (select a from tab1) as sub ' + ast = parse_sql(sql) + + expected_ast = Select( + targets=[ + Star(), + ], + from_table=Select( + alias=Identifier('sub'), + parentheses=True, + targets=[ + Identifier('a') + ], + from_table=Identifier('tab1') + ) + ) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_select_function_star(self): + sql = f'select count(*) from tab1' + ast = parse_sql(sql) + + expected_ast = Select( + targets=[ + Function(op='count', args=[ + Star() + ]) + ], + from_table=Identifier('tab1') + ) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + +class TestMindsdb: + + def test_case(self): + sql = f'''SELECT + CASE + WHEN R.DELETE_RULE = 'CASCADE' THEN 0 + WHEN R.DELETE_RULE = 'SET NULL' THEN 2 + ELSE 3 + END AS DELETE_RULE, + sum( + CASE + WHEN 1 = 1 THEN 1 + END + ) + FROM INFORMATION_SCHEMA.COLLATIONS''' + ast = parse_sql(sql) + + expected_ast = Select( + 
targets=[
+            Case(
+                rules=[
+                    [
+                        BinaryOperation(op='=', args=[
+                            Identifier('R.DELETE_RULE'),
+                            Constant('CASCADE')
+                        ]),
+                        Constant(0)
+                    ],
+                    [
+                        BinaryOperation(op='=', args=[
+                            Identifier('R.DELETE_RULE'),
+                            Constant('SET NULL')
+                        ]),
+                        Constant(2)
+                    ]
+                ],
+                default=Constant(3),
+                alias=Identifier('DELETE_RULE')
+            ),
+            Function(
+                op='sum',
+                args=[
+                    Case(
+                        rules=[
+                            [
+                                BinaryOperation(op='=', args=[Constant(1), Constant(1)]),
+                                Constant(1)
+                            ],
+                        ],
+                    )
+                ]
+            )
+        ],
+        from_table=Identifier('INFORMATION_SCHEMA.COLLATIONS')
+        )
+
+        assert ast.to_tree() == expected_ast.to_tree()
+        assert str(ast) == str(expected_ast)
+
+    def test_case_simple_form(self):
+        sql = f'''SELECT
+            CASE R.DELETE_RULE
+                WHEN 'CASCADE' THEN 0
+                WHEN 'SET NULL' THEN 2
+                ELSE 3
+            END AS DELETE_RULE
+          FROM COLLATIONS'''
+        ast = parse_sql(sql)
+
+        expected_ast = Select(
+            targets=[
+                Case(
+                    arg=Identifier('R.DELETE_RULE'),
+                    rules=[
+                        [
+                            Constant('CASCADE'),
+                            Constant(0)
+                        ],
+                        [
+                            Constant('SET NULL'),
+                            Constant(2)
+                        ]
+                    ],
+                    default=Constant(3),
+                    alias=Identifier('DELETE_RULE')
+                )
+            ],
+            from_table=Identifier('COLLATIONS')
+        )
+
+        assert ast.to_tree() == expected_ast.to_tree()
+        assert str(ast) == str(expected_ast)
+
+    def test_select_left(self):
+        sql = f'select left(a, 1) from tab1'
+        ast = parse_sql(sql)
+
+        expected_ast = Select(
+            targets=[
+                Function(op='left', args=[
+                    Identifier('a'),
+                    Constant(1)
+                ])
+            ],
+            from_table=Identifier('tab1')
+        )
+
+        assert ast.to_tree() == expected_ast.to_tree()
+        assert str(ast) == str(expected_ast)
+
+    def test_not_in_precedence(self):
+        query = "select * from t1 where a = 1 and b not in (1, 2)"
+        expected_ast = Select(
+            targets=[Star()],
+            from_table=Identifier('t1'),
+            where=BinaryOperation(op='and', args=[
+                BinaryOperation(op='=', args=[
+                    Identifier('a'),
+                    Constant(1)
+                ]),
+                BinaryOperation(op='not in', args=[
+                    Identifier('b'),
+                    Tuple([Constant(1), Constant(2)])
+                ]),
+            ])
+        )
+        ast = parse_sql(query)
+        assert str(ast) == str(expected_ast)
+        assert ast.to_tree() == expected_ast.to_tree()
+
+    def test_select_from_subquery_with_columns(self):
+        sql = f"""
+            SELECT * FROM
+              (SELECT col1, col2 b
+               FROM t1)
+            AS sub (c1, c2)
+        """
+        expected_ast = Select(
+            targets=[Star()],
+            from_table=Select(
+                targets=[
+                    Identifier(parts=['col1'], alias=Identifier('c1')),
+                    Identifier(parts=['col2'], alias=Identifier('c2'))
+                ],
+                from_table=Identifier(parts=['t1']),
+                alias=Identifier('sub'),
+                parentheses=True
+            )
+        )
+        ast = parse_sql(sql)
+        assert str(ast) == str(expected_ast)
+
+    def test_substring(self):
+        expected_ast = Select(
+            targets=[Function(
+                op='substring',
+                args=[Identifier('phone'), Constant(1), Constant(2)]
+            )],
+        )
+
+        for sql in (
+            'SELECT substring(phone from 1 for 2)',
+            'SELECT substring(phone, 1, 2)'
+        ):
+
+            ast = parse_sql(sql)
+            assert str(ast) == str(expected_ast)
+
+    def test_alternative_casting(self):
+
+        # char
+        expected_ast = Select(targets=[
+            TypeCast(type_name='CHAR', arg=Constant('1998')),
+            TypeCast(type_name='CHAR', arg=Identifier('col1'), alias=Identifier('col2'))
+        ])
+
+        sql = f"SELECT '1998'::CHAR, col1::CHAR col2"
+        ast = parse_sql(sql)
+        assert str(ast) == str(expected_ast)
+
+        # date
+        expected_ast = Select(
+            targets=[
+                TypeCast(type_name='DATE', arg=Constant('1998-12-01')),
+                BinaryOperation(op='+', args=[
+                    Identifier('col0'),
+                    TypeCast(type_name='DATE', arg=Identifier('col1'), alias=Identifier('col2')),
+                ])
+            ],
+            from_table=Identifier('t1'),
+            where=BinaryOperation(op='>', args=[
+                Identifier('col0'),
+                TypeCast(type_name='DATE', arg=Identifier('col1')),
+            ])
+        )
+
+        sql = f"SELECT '1998-12-01'::DATE, col0 + col1::DATE col2 from t1 where col0 > col1::DATE"
+        ast = parse_sql(sql)
+        assert str(ast) == str(expected_ast)
+
+        # date literal
+        expected_ast = Select(targets=[
+            TypeCast(type_name='DATE', arg=Constant('1998-12-01')),
+        ])
+
+        sql = f"SELECT DATE '1998-12-01'"
+        ast = parse_sql(sql)
+        assert str(ast) == str(expected_ast)
+
+    def test_table_double_quote(self):
+        expected_ast = Select(
+            targets=[Identifier('account_id')],
+            from_table=Identifier(parts=['order'])
+        )
+
+        sql = 'select account_id from "order"'
+
+        ast = parse_sql(sql)
+        assert str(ast) == str(expected_ast)
+
+    def test_window_function_mindsdb(self):
+
+        # modifier
+        query = "select SUM(col0) OVER (partition by col1 order by col2 rows between unbounded preceding and current row) from table1 "
+        expected_ast = Select(
+            targets=[
+                WindowFunction(
+                    function=Function(op='sum', args=[Identifier('col0')]),
+                    partition=[Identifier('col1')],
+                    order_by=[OrderBy(field=Identifier('col2'))],
+                    modifier='rows BETWEEN unbounded preceding AND current row'
+                )
+            ],
+            from_table=Identifier('table1')
+        )
+        ast = parse_sql(query)
+        assert str(ast) == str(expected_ast)
+        assert ast.to_tree() == expected_ast.to_tree()
+
diff --git a/tests/test_base_sql/test_show.py b/tests/test_base_sql/test_show.py
new file mode 100644
index 0000000..877d914
--- /dev/null
+++ b/tests/test_base_sql/test_show.py
@@ -0,0 +1,416 @@
+import pytest
+
+from mindsdb_sql_parser import parse_sql, ParsingException
+from mindsdb_sql_parser.ast import *
+
+
+class TestShow:
+    def test_show_category(self):
+        categories = ['SCHEMAS',
+                      'DATABASES',
+                      'TABLES',
+                      'VARIABLES',
+                      'PLUGINS',
+                      'SESSION VARIABLES',
+                      'SESSION STATUS',
+                      'GLOBAL VARIABLES',
+                      'PROCEDURE STATUS',
+                      'FUNCTION STATUS',
+                      'CREATE TABLE',
+                      'WARNINGS',
+                      'ENGINES',
+                      'CHARSET',
+                      'CHARACTER SET',
+                      'COLLATION',
+                      'TABLE STATUS',
+                      'STATUS']
+        for cat in categories:
+            sql = f"SHOW {cat}"
+            ast = parse_sql(sql)
+            expected_ast = Show(category=cat)
+
+            assert str(ast).lower() == sql.lower()
+            assert str(ast) == str(expected_ast)
+            assert ast.to_tree() == expected_ast.to_tree()
+
+    def test_show_unknown_condition_error(self):
+        sql = "SHOW databases WITH"
+        with pytest.raises(ParsingException):
+            parse_sql(sql)
+
+    def test_show_tables_from_db(self):
+        sql = "SHOW tables from db"
+        ast = parse_sql(sql)
+        expected_ast = Show(category='tables', from_table=Identifier('db'))
+
+        assert str(ast).lower() == sql.lower()
+        assert str(ast) == str(expected_ast)
+        assert ast.to_tree() == expected_ast.to_tree()
+
+    def test_show_function_status(self):
+        sql = "show function status where Db = 'MINDSDB' AND Name LIKE '%'"
+        ast = parse_sql(sql)
+        expected_ast = Show(category='function status',
+                            where=BinaryOperation('and', args=[
+                                BinaryOperation('=', args=[Identifier('Db'), Constant('MINDSDB')]),
+                                BinaryOperation('like', args=[Identifier('Name'), Constant('%')])
+                            ]),
+        )
+
+        assert str(ast) == str(expected_ast)
+        assert ast.to_tree() == expected_ast.to_tree()
+
+
+    def test_show_character_set(self):
+        sql = "show character set where charset = 'utf8mb4'"
+        ast = parse_sql(sql)
+        expected_ast = Show(category='character set',
+                            where=BinaryOperation('=', args=[Identifier('charset'), Constant('utf8mb4')]),
+        )
+
+        # assert str(ast).lower() == sql.lower()
+        assert str(ast) == str(expected_ast)
+        assert ast.to_tree() == expected_ast.to_tree()
+
+    def test_from_where(self):
+        sql = "SHOW FULL TABLES FROM ttt LIKE
'zzz' WHERE xxx" + ast = parse_sql(sql) + expected_ast = Show( + category='TABLES', + modes=['FULL'], + from_table=Identifier('ttt'), + like='zzz', + where=Identifier('xxx'), + ) + + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_full_columns(self): + sql = "SHOW FULL COLUMNS FROM `concrete` FROM `files`" + ast = parse_sql(sql) + expected_ast = Show( + category='COLUMNS', + modes=['FULL'], + from_table=Identifier('files.concrete') + ) + + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + + +class TestShowNoSqlite: + def test_category(self): + categories = [ + 'BINARY LOGS', + 'MASTER LOGS', + 'PROCESSLIST', + 'STORAGE ENGINES', + 'PRIVILEGES', + 'MASTER STATUS', + 'PROFILES', + 'REPLICAS', + ] + + for cat in categories: + sql = f"SHOW {cat}" + ast = parse_sql(sql) + expected_ast = Show(category=cat) + + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + @pytest.mark.parametrize('cat', [ + 'CHARACTER SET', + 'CHARSET', + 'COLLATION', + 'DATABASES', + 'SCHEMAS', + 'FUNCTION STATUS', + 'PROCEDURE STATUS', + 'GLOBAL STATUS', + 'SESSION STATUS', + 'STATUS', + 'GLOBAL VARIABLES', + 'SESSION VARIABLES', + ]) + def test_common_like_where(self, cat): + + sql = f"SHOW {cat} like 'pattern' where a=1" + ast = parse_sql(sql) + expected_ast = Show( + category=cat, + like='pattern', + where=BinaryOperation(op='=', args=[ + Identifier('a'), + Constant(1) + ]) + ) + + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_common_like_where_from_in(self): + categories = [ + 'TABLE STATUS', + 'OPEN TABLES', + 'TRIGGERS', + ] + + for cat in categories: + sql = f"SHOW {cat} from tab1 in tab2 like 'pattern' where a=1" + ast = parse_sql(sql) + expected_ast = Show( + category=cat, + like='pattern', + from_table=Identifier('tab1'), + in_table=Identifier('tab2'), + where=BinaryOperation(op='=', args=[ + Identifier('a'), + Constant(1) + ]) + ) + + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_common_like_where_from_in_modes(self): + categories = [ + 'TABLES', + ] + modes = [ + ['EXTENDED'], + ['FULL'], + ['EXTENDED', 'FULL'], + ] + + for cat in categories: + for mode in modes: + + sql = f"SHOW {' '.join(mode)} {cat} from tab1 in tab2 like 'pattern' where a=1" + ast = parse_sql(sql) + expected_ast = Show( + category=cat, + like='pattern', + from_table=Identifier('tab1'), + in_table=Identifier('tab2'), + modes=mode, + where=BinaryOperation(op='=', args=[ + Identifier('a'), + Constant(1) + ]) + ) + + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_common_like_double_where_from_in_modes(self): + categories = [ + 'COLUMNS', + 'FIELDS', + 'INDEX', + 'INDEXES', + 'KEYS', + ] + modes = [ + ['EXTENDED'], + ['FULL'], + ['EXTENDED', 'FULL'], + ] + for cat in categories: + for mode in modes: + + sql = f"SHOW {' '.join(mode)} {cat} from tab1 from db1 in tab2 in db2 like 'pattern' where a=1" + ast = parse_sql(sql) + expected_ast = Show( + category=cat, + like='pattern', + from_table=Identifier('db1.tab1'), + in_table=Identifier('db2.tab2'), + modes=mode, + where=BinaryOperation(op='=', args=[ + Identifier('a'), + Constant(1) + ]) + ) + + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_custom(self): + + for arg in ['FUNCTION', 'PROCEDURE']: + sql = f"SHOW {arg} CODE obj_name" + ast = parse_sql(sql) + 
expected_ast = Show( + category=f"{arg} CODE", + name='obj_name', + ) + + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + for arg in ['SLAVE', 'REPLICA']: + sql = f"SHOW {arg} STATUS FOR CHANNEL channel" + ast = parse_sql(sql) + expected_ast = Show( + category=f"REPLICA STATUS", + name='channel', + ) + + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + # without channel + sql = f"SHOW {arg} STATUS" + ast = parse_sql(sql) + expected_ast = Show( + category=f"REPLICA STATUS", + ) + + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + +class TestShowAdapted: + + def test_show_database_adapted(self): + statement = Select( + targets=[Identifier(parts=["schema_name"], alias=Identifier('Database'))], + from_table=Identifier(parts=['information_schema', 'SCHEMATA']) + ) + sql = statement.get_string() + + statement2 = parse_sql(sql) + + assert statement2.to_tree() == statement.to_tree() + + +class TestMindsdb: + + def test_show_engine(self): + for arg in ['STATUS', 'MUTEX']: + sql = f"SHOW ENGINE engine_name {arg}" + ast = parse_sql(sql) + expected_ast = Show( + category='ENGINE engine_name', + name=arg, + ) + + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_show(self): + sql = ''' + show full databases + ''' + statement = parse_sql(sql) + statement2 = Show( + category='databases', + modes=['full'] + ) + + assert statement2.to_tree() == statement.to_tree() + + # --- show models --- + sql = ''' + show models + ''' + statement = parse_sql(sql) + statement2 = Show( + category='models' + ) + assert statement2.to_tree() == statement.to_tree() + + sql = ''' + show models FROM db_name + ''' + statement = parse_sql(sql) + statement2 = Show( + category='models', + from_table=Identifier('db_name') + ) + assert statement2.to_tree() == statement.to_tree() + + sql = ''' + show models LIKE 'pattern' + ''' + statement = parse_sql(sql) + statement2 = Show( + category='models', + like='pattern', + ) + assert statement2.to_tree() == statement.to_tree() + + sql = ''' + show models FROM db_name LIKE 'pattern' WHERE a=1 + ''' + statement = parse_sql(sql) + statement2 = Show( + category='models', + from_table=Identifier('db_name'), + like='pattern', + where=BinaryOperation(op='=', args=[ + Identifier('a'), + Constant(1) + ]) + ) + assert statement2.to_tree() == statement.to_tree() + + sql = ''' + show predictors + ''' + statement = parse_sql(sql) + statement2 = Show( + category='predictors' + ) + assert statement2.to_tree() == statement.to_tree() + + # --- ml_engines --- + sql = ''' + show ML_ENGINES + ''' + statement = parse_sql(sql) + statement2 = Show( + category='ML_ENGINES', + ) + assert statement2.to_tree() == statement.to_tree() + + sql = ''' + show ML_ENGINES LIKE 'pattern' + ''' + statement = parse_sql(sql) + statement2 = Show( + category='ML_ENGINES', + like='pattern', + ) + assert statement2.to_tree() == statement.to_tree() + + sql = ''' + show ML_ENGINES LIKE 'pattern' WHERE a=1 + ''' + statement = parse_sql(sql) + statement2 = Show( + category='ML_ENGINES', + like='pattern', + where=BinaryOperation(op='=', args=[ + Identifier('a'), + Constant(1) + ]) + ) + assert statement2.to_tree() == statement.to_tree() + + # --- handlers --- + sql = ''' + show Handlers + ''' + statement = parse_sql(sql) + statement2 = Show( + category='HANDLERS', + ) + assert statement2.to_tree() == statement.to_tree() + + + diff --git 
a/tests/test_base_sql/test_union.py b/tests/test_base_sql/test_union.py new file mode 100644 index 0000000..8810d2b --- /dev/null +++ b/tests/test_base_sql/test_union.py @@ -0,0 +1,77 @@ +import pytest +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast import * +from mindsdb_sql_parser.exceptions import ParsingException + + +class TestUnion: + def test_single_select_error(self): + sql = "SELECT col FROM tab UNION" + with pytest.raises(ParsingException): + parse_sql(sql) + + def test_union_base(self): + for keyword, cls in {'union': Union, 'intersect': Intersect, 'except': Except}.items(): + sql = f"""SELECT col1 FROM tab1 + {keyword} + SELECT col1 FROM tab2""" + + ast = parse_sql(sql) + expected_ast = cls(unique=True, + left=Select(targets=[Identifier('col1')], + from_table=Identifier(parts=['tab1']), + ), + right=Select(targets=[Identifier('col1')], + from_table=Identifier(parts=['tab2']), + ), + ) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_union_all(self): + for keyword, cls in {'union': Union, 'intersect': Intersect, 'except': Except}.items(): + sql = f"""SELECT col1 FROM tab1 + {keyword} ALL + SELECT col1 FROM tab2""" + + ast = parse_sql(sql) + expected_ast = cls(unique=False, + left=Select(targets=[Identifier('col1')], + from_table=Identifier(parts=['tab1']), + ), + right=Select(targets=[Identifier('col1')], + from_table=Identifier(parts=['tab2']), + ), + ) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_union_alias(self): + sql = """SELECT * FROM ( + SELECT col1 FROM tab1 + UNION + SELECT col1 FROM tab2 + UNION + SELECT col1 FROM tab3 + ) AS alias""" + + ast = parse_sql(sql) + expected_ast = Select(targets=[Star()], + from_table=Union( + unique=True, + alias=Identifier('alias'), + left=Union( + unique=True, + left=Select( + targets=[Identifier('col1')], + from_table=Identifier(parts=['tab1']),), + right=Select(targets=[Identifier('col1')], + from_table=Identifier(parts=['tab2']),), + ), + right=Select(targets=[Identifier('col1')], + from_table=Identifier(parts=['tab3']),), + ) + ) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + diff --git a/tests/test_base_sql/test_update.py b/tests/test_base_sql/test_update.py new file mode 100644 index 0000000..ad618cf --- /dev/null +++ b/tests/test_base_sql/test_update.py @@ -0,0 +1,85 @@ +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast import * + + +class TestUpdate: + + def test_update_simple(self): + sql = "update tbl_name set a=b, c='a', d=2, e=f.g" + + expected_ast = Update( + table=Identifier('tbl_name'), + update_columns={ + 'a': Identifier('b'), + 'c': Constant('a'), + 'd': Constant(2), + 'e': Identifier(parts=['f', 'g']), + }, + ) + + ast = parse_sql(sql) + + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + sql += ' where a=b or c>1' + + expected_ast.where = BinaryOperation(op='or', args=[ + BinaryOperation(op='=', args=[ + Identifier('a'), + Identifier('b') + ]), + BinaryOperation(op='>', args=[ + Identifier('c'), + Constant(1) + ]) + ]) + + ast = parse_sql(sql) + + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + +class TestUpdateFromSelect: + + def test_update_simple(self): + sql = """ + update + table2 + set + predicted = df.result + from + ( + select result, prod_id from table1 + USING aaa = "bbb" + ) as df + where + table2.prod_id = df.prod_id + """ + + 
expected_ast = Update( + table=Identifier('table2'), + update_columns={ + 'predicted': Identifier('df.result') + }, + from_select=Select( + targets=[Identifier('result'), Identifier('prod_id')], + from_table=Identifier('table1'), + using={'aaa': 'bbb'} + ), + from_select_alias=Identifier('df'), + where=BinaryOperation(op='=', args=[ + Identifier('table2.prod_id'), + Identifier('df.prod_id') + ]) + ) + + ast = parse_sql(sql) + + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + + + diff --git a/tests/test_base_sql/test_use.py b/tests/test_base_sql/test_use.py new file mode 100644 index 0000000..facf5ab --- /dev/null +++ b/tests/test_base_sql/test_use.py @@ -0,0 +1,14 @@ +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast import * + + +class TestUse: + def test_use(self): + sql = "USE my_integration" + ast = parse_sql(sql) + expected_ast = Use(value=Identifier('my_integration')) + + assert str(ast).lower() == sql.lower() + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + diff --git a/tests/test_mindsdb/__init__.py b/tests/test_mindsdb/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_mindsdb/test_agents.py b/tests/test_mindsdb/test_agents.py new file mode 100644 index 0000000..4ad4511 --- /dev/null +++ b/tests/test_mindsdb/test_agents.py @@ -0,0 +1,64 @@ +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast.mindsdb import * +from mindsdb_sql_parser.ast import * + + +class TestAgents: + def test_create_agent(self): + sql = ''' + create agent if not exists my_agent + using + model = 'my_model', + skills = ['skill1', 'skill2'] + ''' + ast = parse_sql(sql) + expected_ast = CreateAgent( + name=Identifier('my_agent'), + model='my_model', + params={'skills': ['skill1', 'skill2']}, + if_not_exists=True + ) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + # Parse again after rendering to catch problems with rendering. + ast = parse_sql(str(ast)) + assert str(ast) == str(expected_ast) + + def test_update_agent(self): + sql = ''' + update agent my_agent + set + model = 'new_model', + skills = ['new_skill1', 'new_skill2'] + ''' + ast = parse_sql(sql) + expected_params = { + 'model': 'new_model', + 'skills': ['new_skill1', 'new_skill2'] + } + expected_ast = UpdateAgent( + name=Identifier('my_agent'), + updated_params=expected_params + ) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + # Parse again after rendering to catch problems with rendering. + ast = parse_sql(str(ast)) + assert str(ast) == str(expected_ast) + + def test_drop_agent(self): + sql = ''' + drop agent if exists my_agent + ''' + ast = parse_sql(sql) + expected_ast = DropAgent( + name=Identifier('my_agent'), if_exists=True + ) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + # Parse again after rendering to catch problems with rendering. 
+ ast = parse_sql(str(ast)) + assert str(ast) == str(expected_ast) diff --git a/tests/test_mindsdb/test_chatbots.py b/tests/test_mindsdb/test_chatbots.py new file mode 100644 index 0000000..eba0978 --- /dev/null +++ b/tests/test_mindsdb/test_chatbots.py @@ -0,0 +1,80 @@ +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast.mindsdb import * +from mindsdb_sql_parser.ast import * + + +class TestChatbots: + def test_create_chatbot(self): + sql = ''' + create chatbot mybot + using + model = 'chat_model', + database ='my_rocket_chat', + agent = 'my_agent' + ''' + ast = parse_sql(sql) + expected_ast = CreateChatBot( + name=Identifier('mybot'), + database=Identifier('my_rocket_chat'), + model=Identifier('chat_model'), + agent=Identifier('my_agent') + ) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_create_chatbot_with_params(self): + sql = ''' + create chatbot mybot + using + model = 'chat_model', + database ='my_rocket_chat', + key = 'value' + ''' + ast = parse_sql(sql) + expected_ast = CreateChatBot( + name=Identifier('mybot'), + database=Identifier('my_rocket_chat'), + model=Identifier('chat_model'), + agent=None, + params={'key': 'value'} + ) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_update_chatbot(self): + sql = ''' + update chatbot mybot + set + name = 'new_name', + model = 'new_model', + database = 'new_database', + chat_engine = 'new_chat_engine', + is_running = true, + new_param = 'new_value' + ''' + ast = parse_sql(sql) + expected_params = { + 'name': 'new_name', + 'model': 'new_model', + 'database': 'new_database', + 'chat_engine': 'new_chat_engine', + 'is_running': True, + 'new_param': 'new_value' + } + expected_ast = UpdateChatBot( + name=Identifier('mybot'), + updated_params=expected_params + ) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_drop_chatbot(self): + sql = ''' + drop chatbot mybot + ''' + ast = parse_sql(sql) + expected_ast = DropChatBot( + name=Identifier('mybot'), + ) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() diff --git a/tests/test_mindsdb/test_create_integration.py b/tests/test_mindsdb/test_create_integration.py new file mode 100644 index 0000000..753c1ef --- /dev/null +++ b/tests/test_mindsdb/test_create_integration.py @@ -0,0 +1,132 @@ +import pytest + +from mindsdb_sql_parser import parse_sql, ParsingException +from mindsdb_sql_parser.ast import Identifier +from mindsdb_sql_parser.ast.mindsdb import * +from mindsdb_sql_parser.lexer import MindsDBLexer + + +class TestCreateDatabase: + def test_create_database_lexer(self): + sql = "CREATE DATABASE IF NOT EXISTS db WITH ENGINE = 'mysql', PARAMETERS = {\"user\": \"admin\", \"password\": \"admin\"}" + tokens = list(MindsDBLexer().tokenize(sql)) + assert tokens[0].type == 'CREATE' + assert tokens[1].type == 'DATABASE' + assert tokens[2].type == 'IF' + assert tokens[3].type == 'NOT_EXISTS' + assert tokens[4].type == 'ID' + assert tokens[5].type == 'WITH' + assert tokens[6].type == 'ENGINE' + assert tokens[7].type == 'EQUALS' + assert tokens[8].type == 'QUOTE_STRING' + assert tokens[9].type == 'COMMA' + assert tokens[10].type == 'PARAMETERS' + assert tokens[11].type == 'EQUALS' + # next tokens come separately, not just single JSON + # assert tokens[10].type == 'JSON' + + def test_create_database_ok(self): + sql = "CREATE DATABASE db" + ast = parse_sql(sql) + expected_ast = 
CreateDatabase(name=Identifier('db'), engine=None, parameters=None) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + sql = "CREATE DATABASE db ENGINE 'eng'" + ast = parse_sql(sql) + expected_ast = CreateDatabase(name=Identifier('db'), engine='eng', parameters=None) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + # variants with or without ',' and '=' + for with_ in ('WITH', ''): + engines = [ + "ENGINE 'mysql'", + "ENGINE = 'mysql'", + "ENGINE 'mysql',", + "ENGINE = 'mysql',", + ] + for engine in engines: + for equal in ('=', ''): + sql = """ + CREATE DATABASE db + %(with)s %(engine)s + PARAMETERS %(equal)s {"user": "admin", "password": "admin123_.,';:!@#$%%^&*(){}[]", "host": "127.0.0.1"} + """ % {'equal': equal, 'engine': engine, 'with': with_} + + ast = parse_sql(sql) + expected_ast = CreateDatabase(name=Identifier('db'), + engine='mysql', + parameters=dict(user='admin', password="admin123_.,';:!@#$%^&*(){}[]", host='127.0.0.1')) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + sql = """ + CREATE or REPLACE DATABASE db + /* + multiline comment + */ + WITH ENGINE='mysql' + PARAMETERS = {"user": "admin", "password": "admin123_.,';:!@#$%^&*(){}[]", "host": "127.0.0.1"} + """ + + ast = parse_sql(sql) + expected_ast = CreateDatabase(name=Identifier('db'), + engine='mysql', + is_replace=True, + parameters=dict(user='admin', password="admin123_.,';:!@#$%^&*(){}[]", + host='127.0.0.1')) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + # test with if not exists + sql = """ + CREATE DATABASE IF NOT EXISTS db + WITH ENGINE='mysql' + """ + ast = parse_sql(sql) + expected_ast = CreateDatabase(name=Identifier('db'), + engine='mysql', + if_not_exists=True, + parameters=None) + assert str(ast) == str(expected_ast) + + def test_create_database_invalid_json(self): + sql = "CREATE DATABASE db WITH ENGINE = 'mysql', PARAMETERS = 'wow'" + with pytest.raises(ParsingException): + ast = parse_sql(sql) + + def test_create_project(self): + + sql = "create PROJECT dbname" + ast = parse_sql(sql) + + expected_ast = CreateDatabase(name=Identifier('dbname'), engine=None, parameters=None) + + assert str(ast).lower() == str(expected_ast).lower() + assert ast.to_tree() == expected_ast.to_tree() + + # test with if not exists + sql = """ + CREATE PROJECT IF NOT EXISTS db + """ + ast = parse_sql(sql) + expected_ast = CreateDatabase(name=Identifier('db'), + engine=None, + if_not_exists=True, + parameters=None) + + assert str(ast).lower() == str(expected_ast).lower() + assert ast.to_tree() == expected_ast.to_tree() + + + def test_create_database_using(self): + + sql = "CREATE DATABASE db using ENGINE = 'mysql', PARAMETERS = {'A': 1}" + ast = parse_sql(sql) + + expected_ast = CreateDatabase(name=Identifier('db'), engine='mysql', parameters={'A': 1}) + + assert str(ast).lower() == str(expected_ast).lower() + assert ast.to_tree() == expected_ast.to_tree() + diff --git a/tests/test_mindsdb/test_create_predictor.py b/tests/test_mindsdb/test_create_predictor.py new file mode 100644 index 0000000..09c413e --- /dev/null +++ b/tests/test_mindsdb/test_create_predictor.py @@ -0,0 +1,240 @@ +import pytest + +from mindsdb_sql_parser import parse_sql, ParsingException +from mindsdb_sql_parser.ast.mindsdb import * +from mindsdb_sql_parser.ast import * +from mindsdb_sql_parser.utils import to_single_line + + +class TestCreatePredictor: + def 
test_create_predictor_full(self): + sql = """CREATE predictor pred + FROM integration_name + (selct * FROM not some actually ( {'t': [1,2.1,[], {}, False, true, null]} ) not sql (name)) + PREDICT f1 as f1_alias, f2 + ORDER BY f_order_1 ASC, f_order_2, f_order_3 DESC + GROUP BY f_group_1, f_group_2 + WINDOW 100 + HORIZON 7 + USING + a=null, b=true, c=false, + x.`part 2`.part3=1, + y= "a", + z=0.7, + j={'t': [1,2.1,[], {}, False, true, null]}, + q=Filter(x=null, y=true, z=false, a='c', b=2, j={"ar": [1], 'j': {"d": "d"}}) + """ + ast = parse_sql(sql) + expected_ast = CreatePredictor( + name=Identifier('pred'), + integration_name=Identifier('integration_name'), + query_str="selct * FROM not some actually ( {'t': [1,2.1,[], {}, False, true, null]} ) not sql (name)", + targets=[Identifier('f1', alias=Identifier('f1_alias')), + Identifier('f2')], + order_by=[OrderBy(Identifier('f_order_1'), direction='ASC'), + OrderBy(Identifier('f_order_2'), direction='default'), + OrderBy(Identifier('f_order_3'), direction='DESC'), + ], + group_by=[Identifier('f_group_1'), Identifier('f_group_2')], + window=100, + horizon=7, + using={ + 'a': None, 'b': True, 'c': False, + 'x.part 2.part3': 1, + 'y': "a", + 'z': 0.7, + 'j': {'t': [1,2.1,[], {}, False, True, None]}, + 'q': Object(type='Filter', params={ + 'x': None, 'y': True, 'z': False, + 'a': 'c', + 'b': 2, + 'j': {"ar": [1], 'j': {"d": "d"}} + }) + }, + ) + assert to_single_line(str(ast)) == to_single_line(str(expected_ast)) + assert ast.to_tree() == expected_ast.to_tree() + + # convert to string and parse again + ast2 = parse_sql(str(ast)) + assert ast.to_tree() == ast2.to_tree() + + def test_create_predictor_minimal(self): + sql = """CREATE MODEL IF NOT EXISTS pred + FROM integration_name + (select * FROM table_name) + PREDICT f1 as f1_alias, f2 + """ + ast = parse_sql(sql) + expected_ast = CreatePredictor( + name=Identifier('pred'), + if_not_exists=True, + integration_name=Identifier('integration_name'), + query_str="select * FROM table_name", + targets=[Identifier('f1', alias=Identifier('f1_alias')), + Identifier('f2')], + ) + assert str(ast).lower() == to_single_line(sql.lower()) + assert to_single_line(str(ast)) == to_single_line(str(expected_ast)) + assert ast.to_tree() == expected_ast.to_tree() + + def test_create_predictor_no_with(self): + sql = """CREATE PREDICTOR pred + FROM integration_name + (select * FROM table_name) + PREDICT f1 as f1_alias, f2 + """ + ast = parse_sql(sql) + expected_ast = CreatePredictor( + name=Identifier('pred'), + integration_name=Identifier('integration_name'), + query_str="select * FROM table_name", + targets=[Identifier('f1', alias=Identifier('f1_alias')), + Identifier('f2')], + ) + assert ast.to_tree() == expected_ast.to_tree() + + # test create model + sql = """CREATE model pred + FROM integration_name + (select * FROM table_name) + PREDICT f1 as f1_alias, f2 + """ + ast = parse_sql(sql) + + assert ast.to_tree() == expected_ast.to_tree() + + # test or replace + sql = """CREATE or replace model pred + FROM integration_name + (select * FROM table_name) + PREDICT f1 as f1_alias, f2 + """ + ast = parse_sql(sql) + expected_ast.is_replace = True + + assert ast.to_tree() == expected_ast.to_tree() + + def test_create_predictor_quotes(self): + sql = """CREATE PREDICTOR xxx + FROM `yyy` + (SELECT * FROM zzz) + PREDICT sss + """ + ast = parse_sql(sql) + expected_ast = CreatePredictor( + name=Identifier('xxx'), + integration_name=Identifier('yyy'), + query_str="SELECT * FROM zzz", + targets=[Identifier('sss')], + ) + assert 
to_single_line(str(ast)) == to_single_line(str(expected_ast)) + assert ast.to_tree() == expected_ast.to_tree() + + # # or replace + # sql = """CREATE or REPLACE PREDICTOR xxx + # FROM `yyy` + # (SELECT * FROM zzz) + # AS x + # PREDICT sss + # """ + # ast = parse_sql(sql) + # expected_ast.is_replace = True + # assert to_single_line(str(ast)) == to_single_line(str(expected_ast)) + # assert ast.to_tree() == expected_ast.to_tree() + + def test_create_predictor_invalid_json(self): + sql = """CREATE PREDICTOR pred + FROM integration_name + (select * FROM table) + AS ds_name + PREDICT f1 as f1_alias, f2 + ORDER BY f_order_1 ASC, f_order_2, f_order_3 DESC + GROUP BY f_group_1, f_group_2 + WINDOW 100 + HORIZON 7 + USING 'not_really_json'""" + + with pytest.raises(ParsingException): + parse_sql(sql) + + def test_create_predictor_empty_fields(self): + sql = """CREATE PREDICTOR xxx + PREDICT sss + """ + ast = parse_sql(sql) + expected_ast = CreatePredictor( + name=Identifier('xxx'), + integration_name=None, + query_str=None, + targets=[Identifier('sss')], + ) + assert to_single_line(str(ast)) == to_single_line(str(expected_ast)) + assert ast.to_tree() == expected_ast.to_tree() + + def test_create_anomaly_detection_model(self): + for predict_clause in ["", " PREDICT alert "]: + create_clause = """CREATE ANOMALY DETECTION MODEL alert_model """ + rest_clause = """ + FROM integration_name (select * FROM table) + USING + confidence=0.5 + """ + sql = create_clause + predict_clause + rest_clause + ast = parse_sql(sql) + + expected_ast = CreateAnomalyDetectionModel( + name=Identifier('alert_model'), + task=Identifier('AnomalyDetection'), + integration_name=Identifier('integration_name'), + query_str='select * FROM table', + targets=[Identifier('alert')] if predict_clause else None, + using={ + 'confidence': 0.5 + } + ) + + assert to_single_line(str(ast)) == to_single_line(str(expected_ast)) + assert ast.to_tree() == expected_ast.to_tree() + + def test_optional_db(self): + sql = "CREATE MODEL xxx from (select 1) PREDICT sss" + ast = parse_sql(sql) + expected_ast = CreatePredictor( + name=Identifier('xxx'), + query_str='select 1', + targets=[Identifier('sss')], + ) + assert to_single_line(str(ast)) == to_single_line(str(expected_ast)) + assert ast.to_tree() == expected_ast.to_tree() + + # retrain + sql = "RETRAIN MODEL xxx from (select 1)" + ast = parse_sql(sql) + expected_ast = RetrainPredictor( + name=Identifier('xxx'), + query_str='select 1', + ) + assert to_single_line(str(ast)) == to_single_line(str(expected_ast)) + assert ast.to_tree() == expected_ast.to_tree() + + sql = "RETRAIN xxx from (select 1)" + ast = parse_sql(sql) + assert to_single_line(str(ast)) == to_single_line(str(expected_ast)) + assert ast.to_tree() == expected_ast.to_tree() + + # finetune + sql = "FINETUNE MODEL xxx from (select 1)" + ast = parse_sql(sql) + expected_ast = FinetunePredictor( + name=Identifier('xxx'), + query_str='select 1', + ) + assert to_single_line(str(ast)) == to_single_line(str(expected_ast)) + assert ast.to_tree() == expected_ast.to_tree() + + sql = "FINETUNE xxx from (select 1)" + ast = parse_sql(sql) + assert to_single_line(str(ast)) == to_single_line(str(expected_ast)) + assert ast.to_tree() == expected_ast.to_tree() + diff --git a/tests/test_mindsdb/test_create_view.py b/tests/test_mindsdb/test_create_view.py new file mode 100644 index 0000000..b05d26f --- /dev/null +++ b/tests/test_mindsdb/test_create_view.py @@ -0,0 +1,54 @@ +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast.mindsdb import * 
+from mindsdb_sql_parser.ast import * +from mindsdb_sql_parser.lexer import MindsDBLexer + +class TestCreateView: + def test_create_view_lexer(self): + sql = "CREATE VIEW my_view FROM integration AS ( SELECT * FROM pred )" + tokens = list(MindsDBLexer().tokenize(sql)) + assert tokens[0].value == 'CREATE' + assert tokens[0].type == 'CREATE' + + assert tokens[1].value == 'VIEW' + assert tokens[1].type == 'VIEW' + + def test_create_view_full(self): + sql = "CREATE VIEW IF NOT EXISTS my_view FROM integr AS ( SELECT * FROM pred )" + ast = parse_sql(sql) + expected_ast = CreateView(name=Identifier('my_view'), + if_not_exists=True, + from_table=Identifier('integr'), + query_str="SELECT * FROM pred") + + assert str(ast).lower() == sql.lower() + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_create_view_nofrom(self): + sql = "CREATE VIEW my_view ( SELECT * FROM pred )" + ast = parse_sql(sql) + expected_ast = CreateView(name=Identifier('my_view'), + query_str="SELECT * FROM pred") + + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + # def test_create_dataset_full(self): + # sql = "CREATE DATASET my_view FROM integr AS ( SELECT * FROM pred )" + # ast = parse_sql(sql) + # expected_ast = CreateView(name='my_view', + # from_table=Identifier('integr'), + # query_str="SELECT * FROM pred") + # + # assert str(ast) == str(expected_ast) + # assert ast.to_tree() == expected_ast.to_tree() + + # def test_create_dataset_nofrom(self): + # sql = "CREATE DATASET my_view ( SELECT * FROM pred )" + # ast = parse_sql(sql) + # expected_ast = CreateView(name='my_view', + # query_str="SELECT * FROM pred") + + # assert str(ast) == str(expected_ast) + # assert ast.to_tree() == expected_ast.to_tree() diff --git a/tests/test_mindsdb/test_drop_dataset.py b/tests/test_mindsdb/test_drop_dataset.py new file mode 100644 index 0000000..ad0c605 --- /dev/null +++ b/tests/test_mindsdb/test_drop_dataset.py @@ -0,0 +1,13 @@ +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast.mindsdb import * +from mindsdb_sql_parser.ast import * + + +class TestDropDataset: + def test_drop_dataset(self): + sql = "DROP DATASET IF EXISTS dsname" + ast = parse_sql(sql) + expected_ast = DropDataset(name=Identifier('dsname'), if_exists=True) + assert str(ast).lower() == sql.lower() + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() diff --git a/tests/test_mindsdb/test_drop_datasource.py b/tests/test_mindsdb/test_drop_datasource.py new file mode 100644 index 0000000..23dc665 --- /dev/null +++ b/tests/test_mindsdb/test_drop_datasource.py @@ -0,0 +1,24 @@ +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast.mindsdb import * +from mindsdb_sql_parser.ast import * + + +class TestDropDatasource: + def test_drop_datasource(self): + sql = "DROP DATASOURCE IF EXISTS dsname" + ast = parse_sql(sql) + expected_ast = DropDatasource(name=Identifier('dsname'), if_exists=True) + assert str(ast).lower() == sql.lower() + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_drop_project(self): + + sql = "DROP PROJECT dbname" + ast = parse_sql(sql) + + expected_ast = DropDatabase(name=Identifier('dbname'), if_exists=False) + + assert str(ast).lower() == str(expected_ast).lower() + assert ast.to_tree() == expected_ast.to_tree() + diff --git a/tests/test_mindsdb/test_drop_predictor.py b/tests/test_mindsdb/test_drop_predictor.py new file mode 100644 index 0000000..62f6058 
--- /dev/null +++ b/tests/test_mindsdb/test_drop_predictor.py @@ -0,0 +1,40 @@ +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast.mindsdb import * +from mindsdb_sql_parser.ast import * +from mindsdb_sql_parser.lexer import MindsDBLexer + +class TestDropPredictor: + def test_drop_predictor_lexer(self): + sql = "DROP PREDICTOR mindsdb.pred" + tokens = list(MindsDBLexer().tokenize(sql)) + assert tokens[0].type == 'DROP' + assert tokens[1].type == 'PREDICTOR' + assert tokens[2].type == 'ID' + assert tokens[2].value == 'mindsdb' + assert tokens[3].type == 'DOT' + assert tokens[4].type == 'ID' + assert tokens[4].value == 'pred' + + def test_drop_predictor_ok(self): + sql = "DROP PREDICTOR mindsdb.pred" + ast = parse_sql(sql) + expected_ast = DropPredictor(name=Identifier('mindsdb.pred')) + assert str(ast).lower() == sql.lower() + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_drop_model(self): + sql = "DROP model mindsdb.pred" + ast = parse_sql(sql) + expected_ast = DropPredictor(name=Identifier('mindsdb.pred')) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_drop_predictor_if_exists(self): + sql = "DROP PREDICTOR IF EXISTS mindsdb.pred" + ast = parse_sql(sql) + expected_ast = DropPredictor( + name=Identifier('mindsdb.pred'), + if_exists=True) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() diff --git a/tests/test_mindsdb/test_evaluate.py b/tests/test_mindsdb/test_evaluate.py new file mode 100644 index 0000000..93953ee --- /dev/null +++ b/tests/test_mindsdb/test_evaluate.py @@ -0,0 +1,39 @@ +import pytest + +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast import * +from mindsdb_sql_parser.ast.mindsdb.evaluate import Evaluate +from mindsdb_sql_parser.lexer import MindsDBLexer + + +class TestEvaluate: + def test_evaluate_lexer(self): + sql = "EVALUATE balanced_accuracy_score FROM (SELECT true, pred FROM table_1)" + tokens = list(MindsDBLexer().tokenize(sql)) + assert tokens[0].type == 'EVALUATE' + assert tokens[1].type == 'ID' + assert tokens[1].value == 'balanced_accuracy_score' + + def test_evaluate_full_1(self): + sql = "EVALUATE balanced_accuracy_score FROM (SELECT ground_truth, pred FROM table_1) USING adjusted=1, param2=2;" # noqa + ast = parse_sql(sql) + expected_ast = Evaluate( + name=Identifier('balanced_accuracy_score'), + query_str="SELECT ground_truth, pred FROM table_1", + using={'adjusted': 1, 'param2': 2}, + ) + assert ' '.join(str(ast).split()).lower() == sql.lower() + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_evaluate_full_2(self): + query_str = """SELECT t.rental_price as ground_truth, m.rental_price as prediction FROM example_db.demo_data.home_rentals as t JOIN mindsdb.home_rentals_model as m limit 100""" # noqa + sql = f"""EVALUATE r2_score FROM ({query_str});""" + ast = parse_sql(sql) + expected_ast = Evaluate( + name=Identifier('r2_score'), + query_str=query_str, + ) + assert ' '.join(str(ast).split()).lower() == sql.lower() + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast).lower() == str(expected_ast).lower() diff --git a/tests/test_mindsdb/test_finetune_predictor.py b/tests/test_mindsdb/test_finetune_predictor.py new file mode 100644 index 0000000..445cfe6 --- /dev/null +++ b/tests/test_mindsdb/test_finetune_predictor.py @@ -0,0 +1,36 @@ +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast import * 
+from mindsdb_sql_parser.ast.mindsdb.finetune_predictor import FinetunePredictor +from mindsdb_sql_parser.lexer import MindsDBLexer + + +class TestFinetunePredictor: + def test_finetune_predictor_lexer(self): + sql = "FINETUNE mindsdb.pred FROM integration_name (SELECT * FROM table_1) USING a=1" + tokens = list(MindsDBLexer().tokenize(sql)) + assert tokens[0].type == 'FINETUNE' + assert tokens[1].type == 'ID' + assert tokens[1].value == 'mindsdb' + assert tokens[2].type == 'DOT' + assert tokens[3].type == 'ID' + assert tokens[3].value == 'pred' + + def test_finetune_predictor_full(self): + sql = "FINETUNE mindsdb.pred FROM integration_name (SELECT * FROM table_1) USING a=1, b=null" + ast = parse_sql(sql) + expected_ast = FinetunePredictor( + name=Identifier('mindsdb.pred'), + integration_name=Identifier('integration_name'), + query_str="SELECT * FROM table_1", + using={'a': 1, 'b': None}, + ) + assert ' '.join(str(ast).split()).lower() == sql.lower() + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + # with MODEL + sql = "FINETUNE MODEL mindsdb.pred FROM integration_name (SELECT * FROM table_1) USING a=1, b=null" + ast = parse_sql(sql) + + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() diff --git a/tests/test_mindsdb/test_jobs.py b/tests/test_mindsdb/test_jobs.py new file mode 100644 index 0000000..b8acf7a --- /dev/null +++ b/tests/test_mindsdb/test_jobs.py @@ -0,0 +1,112 @@ +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast.mindsdb import * +from mindsdb_sql_parser.ast import * + + +class TestJobs: + def test_create_job(self): + sql = ''' + create job proj2.j1 ( + select * from pg.tbl1 where b>{{PREVIOUS_START_DATE}} + ) + start now + end '2024-01-01' + every hour + ''' + ast = parse_sql(sql) + expected_ast = CreateJob( + name=Identifier('proj2.j1'), + query_str="select * from pg.tbl1 where b>{{PREVIOUS_START_DATE}}", + start_str="now", + end_str="2024-01-01", + repeat_str="hour" + ) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + # 2 + + sql = ''' + create job j1 as ( + retrain p1; retrain p2 + ) + every '2 hours' + ''' + ast = parse_sql(sql) + expected_ast = CreateJob( + name=Identifier('j1'), + query_str="retrain p1; retrain p2", + repeat_str="2 hours" + ) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + # 3 + + sql = ''' + create job j1 ( + retrain p1; retrain p2 + ) + ''' + ast = parse_sql(sql) + expected_ast = CreateJob( + name=Identifier('j1'), + query_str="retrain p1; retrain p2", + ) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + # 4 + + sql = ''' + create job j1 ( retrain p1 ) + every 2 hours + if (select a from table) + ''' + ast = parse_sql(sql) + expected_ast = CreateJob( + name=Identifier('j1'), + query_str="retrain p1", + if_query_str="select a from table", + repeat_str="2 hours" + ) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_create_job_minimal_with_if_not_exists(self): + sql = ''' + create job if not exists proj2.j1 ( + select * from pg.tbl1 where b>{{PREVIOUS_START_DATE}} + ) + ''' + ast = parse_sql(sql) + expected_ast = CreateJob( + name=Identifier('proj2.j1'), + query_str="select * from pg.tbl1 where b>{{PREVIOUS_START_DATE}}", + if_not_exists=True + ) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_drop_job(self): + sql = ''' + drop 
job proj1.j1 + ''' + ast = parse_sql(sql) + expected_ast = DropJob( + name=Identifier('proj1.j1'), + ) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + # test with if exists + sql = ''' + drop job if exists proj1.j1 + ''' + ast = parse_sql(sql) + expected_ast = DropJob( + name=Identifier('proj1.j1'), + if_exists=True + ) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() \ No newline at end of file diff --git a/tests/test_mindsdb/test_knowledgebase.py b/tests/test_mindsdb/test_knowledgebase.py new file mode 100644 index 0000000..e2387cb --- /dev/null +++ b/tests/test_mindsdb/test_knowledgebase.py @@ -0,0 +1,372 @@ +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast.mindsdb.knowledge_base import ( + CreateKnowledgeBase, + DropKnowledgeBase, +) +from mindsdb_sql_parser.ast import ( + Select, + Identifier, + Join, + Show, + BinaryOperation, + Constant, + Star, + Delete, + Insert, + OrderBy, +) + + +class TestKB: + + def test_create_knowledge_base(self): + # create without select + sql = """ + CREATE KNOWLEDGE_BASE my_knowledge_base + USING + MODEL=mindsdb.my_embedding_model, + STORAGE = my_vector_database.some_table + """ + ast = parse_sql(sql) + expected_ast = CreateKnowledgeBase( + name=Identifier("my_knowledge_base"), + if_not_exists=False, + model=Identifier(parts=["mindsdb", "my_embedding_model"]), + storage=Identifier(parts=["my_vector_database", "some_table"]), + from_select=None, + params={}, + ) + assert ast == expected_ast + + # using the alias KNOWLEDGE BASE without underscore shall also work + sql = """ + CREATE KNOWLEDGE BASE my_knowledge_base + USING + MODEL=mindsdb.my_embedding_model, + STORAGE = my_vector_database.some_table + """ + ast = parse_sql(sql) + assert ast == expected_ast + + # the order of MODEL and STORAGE should not matter + sql = """ + CREATE KNOWLEDGE_BASE my_knowledge_base + USING + STORAGE = my_vector_database.some_table, + MODEL = mindsdb.my_embedding_model + """ + ast = parse_sql(sql) + assert ast == expected_ast + + # create from a query + sql = """ + CREATE KNOWLEDGE_BASE my_knowledge_base + FROM ( + SELECT id, content, embeddings, metadata + FROM my_table + JOIN my_embedding_model + ) + USING + MODEL = mindsdb.my_embedding_model, + STORAGE = my_vector_database.some_table + """ + ast = parse_sql(sql) + expected_ast = CreateKnowledgeBase( + name=Identifier("my_knowledge_base"), + if_not_exists=False, + model=Identifier(parts=["mindsdb", "my_embedding_model"]), + storage=Identifier(parts=["my_vector_database", "some_table"]), + from_select=Select( + targets=[ + Identifier("id"), + Identifier("content"), + Identifier("embeddings"), + Identifier("metadata"), + ], + from_table=Join( + left=Identifier("my_table"), + right=Identifier("my_embedding_model"), + join_type="JOIN", + ), + ), + params={}, + ) + + assert ast == expected_ast + + # create without MODEL + sql = """ + CREATE KNOWLEDGE_BASE my_knowledge_base + USING + STORAGE = my_vector_database.some_table + """ + + expected_ast = CreateKnowledgeBase( + name=Identifier("my_knowledge_base"), + if_not_exists=False, + model=None, + storage=Identifier(parts=["my_vector_database", "some_table"]), + from_select=None, + params={}, + ) + + ast = parse_sql(sql) + + assert ast == expected_ast + + # create without STORAGE + sql = """ + CREATE KNOWLEDGE_BASE my_knowledge_base + USING + MODEL = mindsdb.my_embedding_model + """ + + expected_ast = CreateKnowledgeBase( + name=Identifier("my_knowledge_base"), + 
if_not_exists=False, + model=Identifier(parts=["mindsdb", "my_embedding_model"]), + from_select=None, + params={}, + ) + + ast = parse_sql(sql) + + assert ast == expected_ast + + # create if not exists + sql = """ + CREATE KNOWLEDGE_BASE IF NOT EXISTS my_knowledge_base + USING + MODEL = mindsdb.my_embedding_model, + STORAGE = my_vector_database.some_table + """ + ast = parse_sql(sql) + expected_ast = CreateKnowledgeBase( + name=Identifier("my_knowledge_base"), + if_not_exists=True, + model=Identifier(parts=["mindsdb", "my_embedding_model"]), + storage=Identifier(parts=["my_vector_database", "some_table"]), + from_select=None, + params={}, + ) + assert ast == expected_ast + + # create without USING ie no storage or model + + sql = """ + CREATE KNOWLEDGE_BASE my_knowledge_base; + """ + ast = parse_sql(sql) + expected_ast = CreateKnowledgeBase( + name=Identifier("my_knowledge_base"), + if_not_exists=False, + model=None, + storage=None, + from_select=None, + params={}, + ) + assert ast == expected_ast + + # create with params + sql = """ + CREATE KNOWLEDGE_BASE my_knowledge_base + USING + MODEL = mindsdb.my_embedding_model, + STORAGE = my_vector_database.some_table, + some_param = 'some value', + other_param = 'other value' + """ + ast = parse_sql(sql) + expected_ast = CreateKnowledgeBase( + name=Identifier("my_knowledge_base"), + if_not_exists=False, + model=Identifier(parts=["mindsdb", "my_embedding_model"]), + storage=Identifier(parts=["my_vector_database", "some_table"]), + from_select=None, + params={"some_param": "some value", "other_param": "other value"}, + ) + assert ast == expected_ast + + def test_drop_knowledge_base(self): + # drop if exists + sql = """ + DROP KNOWLEDGE_BASE IF EXISTS my_knowledge_base + """ + ast = parse_sql(sql) + expected_ast = DropKnowledgeBase( + name=Identifier("my_knowledge_base"), if_exists=True + ) + assert ast == expected_ast + + # drop without if exists + sql = """ + DROP KNOWLEDGE_BASE my_knowledge_base + """ + ast = parse_sql(sql) + + expected_ast = DropKnowledgeBase( + name=Identifier("my_knowledge_base"), if_exists=False + ) + assert ast == expected_ast + + + def test_show_knowledge_base(self): + sql = """ + SHOW KNOWLEDGE_BASES + """ + ast = parse_sql(sql) + expected_ast = Show( + category="KNOWLEDGE_BASES", + ) + assert ast == expected_ast + + # without underscore shall also work + sql = """ + SHOW KNOWLEDGE BASES + """ + ast = parse_sql(sql) + expected_ast = Show( + category="KNOWLEDGE BASES", + ) + assert ast == expected_ast + + def test_select_from_knowledge_base(self): + # this is no different from a regular select + sql = """ + SELECT * FROM my_knowledge_base + WHERE + query = 'some text in natural query' + AND + metadata.some_column = 'some value' + ORDER BY + distances DESC + LIMIT 10 + """ + ast = parse_sql(sql) + + expected_ast = Select( + targets=[Star()], + from_table=Identifier("my_knowledge_base"), + where=BinaryOperation( + op="AND", + args=[ + BinaryOperation( + op="=", + args=[Identifier("query"), Constant("some text in natural query")], + ), + BinaryOperation( + op="=", + args=[Identifier("metadata.some_column"), Constant("some value")], + ), + ], + ), + order_by=[OrderBy(field=Identifier("distances"), direction="DESC")], + limit=Constant(10), + ) + assert ast == expected_ast + + + def test_delete_from_knowledge_base(self): + # this is no different from a regular delete + sql = """ + DELETE FROM my_knowledge_base + WHERE + id = 'some id' + AND + metadata.some_column = 'some value' + """ + ast = parse_sql(sql) + expected_ast = 
Delete( + table=Identifier("my_knowledge_base"), + where=BinaryOperation( + op="AND", + args=[ + BinaryOperation(op="=", args=[Identifier("id"), Constant("some id")]), + BinaryOperation( + op="=", + args=[Identifier("metadata.some_column"), Constant("some value")], + ), + ], + ), + ) + assert ast == expected_ast + + def test_insert_into_knowledge_base(self): + # this is no different from a regular insert + sql = """ + INSERT INTO my_knowledge_base ( + id, content, embeddings, metadata + ) + VALUES ( + 'some id', + 'some text', + '[1,2,3,4,5]', + '{"some_column": "some value"}' + ), + ( + 'some other id', + 'some other text', + '[1,2,3,4,5]', + '{"some_column": "some value"}' + ) + """ + ast = parse_sql(sql) + expected_ast = Insert( + table=Identifier("my_knowledge_base"), + columns=[ + Identifier("id"), + Identifier("content"), + Identifier("embeddings"), + Identifier("metadata"), + ], + values=[ + [ + Constant("some id"), + Constant("some text"), + Constant("[1,2,3,4,5]"), + Constant('{"some_column": "some value"}'), + ], + [ + Constant("some other id"), + Constant("some other text"), + Constant("[1,2,3,4,5]"), + Constant('{"some_column": "some value"}'), + ], + ], + ) + assert ast == expected_ast + + # insert from a select + sql = """ + INSERT INTO my_knowledge_base ( + id, content, embeddings, metadata + ) + SELECT id, content, embeddings, metadata + FROM my_table + WHERE + metadata.some_column = 'some value' + """ + ast = parse_sql(sql) + expected_ast = Insert( + table=Identifier("my_knowledge_base"), + columns=[ + Identifier("id"), + Identifier("content"), + Identifier("embeddings"), + Identifier("metadata"), + ], + from_select=Select( + targets=[ + Identifier("id"), + Identifier("content"), + Identifier("embeddings"), + Identifier("metadata"), + ], + from_table=Identifier("my_table"), + where=BinaryOperation( + op="=", + args=[Identifier("metadata.some_column"), Constant("some value")], + ), + ), + ) + assert ast == expected_ast diff --git a/tests/test_mindsdb/test_ml_engine.py b/tests/test_mindsdb/test_ml_engine.py new file mode 100644 index 0000000..08c88fc --- /dev/null +++ b/tests/test_mindsdb/test_ml_engine.py @@ -0,0 +1,73 @@ +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast.mindsdb import * +from mindsdb_sql_parser.ast import * +from mindsdb_sql_parser.utils import to_single_line + + +class TestCreateMLEngine: + def test_create_ml_engine(self): + sql = """ + CREATE ML_ENGINE name FROM ml_handler_name USING a=2, f=3 + """ + ast = parse_sql(sql) + expected_ast = CreateMLEngine( + name=Identifier('name'), + handler='ml_handler_name', + params={ + 'a': 2, + 'f': 3 + } + ) + assert to_single_line(str(ast)) == to_single_line(str(expected_ast)) + assert ast.to_tree() == expected_ast.to_tree() + + sql = """ + CREATE ML_ENGINE name FROM ml_handler_name + """ + ast = parse_sql(sql) + expected_ast = CreateMLEngine( + name=Identifier('name'), + handler='ml_handler_name', + params=None + ) + assert to_single_line(str(ast)) == to_single_line(str(expected_ast)) + assert ast.to_tree() == expected_ast.to_tree() + + # test if not exists + sql = """ + CREATE ML_ENGINE IF NOT EXISTS name FROM ml_handler_name + """ + ast = parse_sql(sql) + expected_ast = CreateMLEngine( + name=Identifier('name'), + handler='ml_handler_name', + params=None, + if_not_exists=True + ) + assert to_single_line(str(ast)) == to_single_line(str(expected_ast)) + assert ast.to_tree() == expected_ast.to_tree() + + +class TestDropMLEngine: + def test_drop_ml_engine(self): + sql = """ + DROP 
ML_ENGINE name + """ + ast = parse_sql(sql) + expected_ast = DropMLEngine( + name=Identifier('name'), + ) + assert to_single_line(str(ast)) == to_single_line(str(expected_ast)) + assert ast.to_tree() == expected_ast.to_tree() + + # test if exists + sql = """ + DROP ML_ENGINE IF EXISTS name + """ + ast = parse_sql(sql) + expected_ast = DropMLEngine( + name=Identifier('name'), + if_exists=True + ) + assert to_single_line(str(ast)) == to_single_line(str(expected_ast)) + assert ast.to_tree() == expected_ast.to_tree() diff --git a/tests/test_mindsdb/test_retrain_predictor.py b/tests/test_mindsdb/test_retrain_predictor.py new file mode 100644 index 0000000..0c9089e --- /dev/null +++ b/tests/test_mindsdb/test_retrain_predictor.py @@ -0,0 +1,60 @@ +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast import * +from mindsdb_sql_parser.ast.mindsdb.retrain_predictor import RetrainPredictor +from mindsdb_sql_parser.lexer import MindsDBLexer + + +class TestRetrainPredictor: + def test_retrain_predictor_lexer(self): + sql = "RETRAIN mindsdb.pred" + tokens = list(MindsDBLexer().tokenize(sql)) + assert tokens[0].type == 'RETRAIN' + assert tokens[1].type == 'ID' + assert tokens[1].value == 'mindsdb' + assert tokens[2].type == 'DOT' + assert tokens[2].value == '.' + assert tokens[3].type == 'ID' + assert tokens[3].value == 'pred' + + def test_retrain_predictor_ok(self): + sql = "RETRAIN mindsdb.pred" + ast = parse_sql(sql) + expected_ast = RetrainPredictor(name=Identifier('mindsdb.pred')) + assert str(ast).lower() == sql.lower() + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + # with model + sql = "RETRAIN MODEL mindsdb.pred" + ast = parse_sql(sql) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_retrain_predictor_full(self): + sql = """Retrain pred + FROM integration_name + (selct * FROM aaa) + PREDICT f1 + ORDER BY f_order_1 ASC, f_order_2 + GROUP BY f_group_1 + WINDOW 100 + HORIZON 7 + USING + a=null, + b=1 + """ + ast = parse_sql(sql) + expected_ast = RetrainPredictor( + name=Identifier('pred'), + integration_name=Identifier('integration_name'), + query_str="selct * FROM aaa", + targets=[Identifier('f1')], + order_by=[OrderBy(Identifier('f_order_1'), direction='ASC'), + OrderBy(Identifier('f_order_2'), direction='default')], + group_by=[Identifier('f_group_1')], + window=100, + horizon=7, + using={'a': None, 'b': 1}, + ) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() diff --git a/tests/test_mindsdb/test_selects.py b/tests/test_mindsdb/test_selects.py new file mode 100644 index 0000000..52e464d --- /dev/null +++ b/tests/test_mindsdb/test_selects.py @@ -0,0 +1,205 @@ +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast import * +from mindsdb_sql_parser.utils import JoinType + + +class TestSpecificSelects: + def test_select_from_predictors(self): + sql = "SELECT * FROM predictors WHERE name = 'pred_name'" + ast = parse_sql(sql) + expected_ast = Select( + targets=[Star()], + from_table=Identifier('predictors'), + where=BinaryOperation('=', args=[Identifier('name'), Constant('pred_name')]) + ) + + # assert str(ast).lower() == sql.lower() + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_select_predict_column(self): + sql = "SELECT predict FROM mindsdb.predictors" + ast = parse_sql(sql) + expected_ast = Select( + targets=[Identifier('predict')], + from_table=Identifier('mindsdb.predictors'), 
+ ) + + # assert str(ast).lower() == sql.lower() + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_select_status_column(self): + sql = "SELECT status FROM mindsdb.predictors" + ast = parse_sql(sql) + expected_ast = Select( + targets=[Identifier('status')], + from_table=Identifier('mindsdb.predictors'), + ) + + # assert str(ast).lower() == sql.lower() + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_native_query(self): + sql = """ + SELECT status + FROM int1 (select q from p from r) + group by 1 + limit 1 + """ + ast = parse_sql(sql) + expected_ast = Select( + targets=[Identifier('status')], + from_table=NativeQuery( + integration=Identifier('int1'), + query='select q from p from r' + ), + limit=Constant(1), + group_by=[Constant(1)] + ) + + # assert str(ast).lower() == sql.lower() + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_select_using(self): + sql = """ + SELECT status FROM tbl1 + group by 1 + using p1=1, p2='2' + """ + ast = parse_sql(sql) + expected_ast = Select( + targets=[Identifier('status')], + from_table=Identifier('tbl1'), + group_by=[Constant(1)], + using={ + 'p1': 1, + 'p2': '2' + } + ) + + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + + def test_join_using(self): + sql = """ + SELECT status FROM tbl1 + join pred1 + using p1=1, p2='2' + """ + ast = parse_sql(sql) + expected_ast = Select( + targets=[Identifier('status')], + from_table=Join( + left=Identifier('tbl1'), + right=Identifier('pred1'), + join_type=JoinType.JOIN + ), + using={ + 'p1': 1, + 'p2': '2' + } + ) + + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_select_limit_negative(self): + sql = """SELECT * FROM t1 LIMIT -1""" + + ast = parse_sql(sql) + expected_ast = Select(targets=[Star()], + from_table=Identifier(parts=['t1']), + limit=Constant(-1)) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_last(self): + sql = """SELECT * FROM t1 t where t.id>last and t.x > coalence(last, 0)""" + + ast = parse_sql(sql) + expected_ast = Select( + targets=[Star()], + from_table=Identifier(parts=['t1'], alias=Identifier('t')), + where=BinaryOperation(op='and', args=[ + BinaryOperation( + op='>', + args=[ + Identifier(parts=['t', 'id']), + Last() + ] + ), + BinaryOperation( + op='>', + args=[ + Identifier(parts=['t', 'x']), + Function(op='coalence', args=[ + Last(), + Constant(0) + ]) + ] + ), + ]) + ) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + + sql = """SELECT last(a) FROM t1""" + + ast = parse_sql(sql) + expected_ast = Select( + targets=[Function( + op='last', + args=[Identifier('a')] + )], + from_table=Identifier(parts=['t1']), + ) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_json(self): + sql = """SELECT col->1->'c' from TAB1""" + + ast = parse_sql(sql) + expected_ast = Select( + targets=[BinaryOperation( + op='->', + args=[ + BinaryOperation( + op='->', + args=[ + Identifier('col'), + Constant(1) + ] + ), + Constant('c') + ] + )], + from_table=Identifier(parts=['TAB1']), + ) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + + def test_match(self): + sql = "SELECT a~b, a!~c from TAB1" + + ast = parse_sql(sql) + expected_ast = Select( + targets=[ + BinaryOperation(op='~', 
args=[Identifier('a'), Identifier('b')]), + BinaryOperation(op='!~', args=[Identifier('a'), Identifier('c')]), + ], + from_table=Identifier(parts=['TAB1']), + ) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast) == str(expected_ast) + diff --git a/tests/test_mindsdb/test_show_mindsdb.py b/tests/test_mindsdb/test_show_mindsdb.py new file mode 100644 index 0000000..cd46bb2 --- /dev/null +++ b/tests/test_mindsdb/test_show_mindsdb.py @@ -0,0 +1,31 @@ +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast import * + + +class TestShowMindsdb: + def test_show_keyword(self): + for keyword in ['STREAMS', + 'PREDICTORS', + 'INTEGRATIONS', + 'DATASOURCES', + 'PUBLICATIONS', + 'DATASETS', + 'ALL']: + sql = f"SHOW {keyword}" + ast = parse_sql(sql) + expected_ast = Show(category=keyword) + + assert str(ast).lower() == sql.lower() + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_show_tables_arg(self): + for keyword in ['VIEWS', 'TABLES']: + sql = f"SHOW {keyword} from integration_name" + ast = parse_sql(sql) + expected_ast = Show(category=keyword, from_table=Identifier('integration_name')) + + assert str(ast).lower() == sql.lower() + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + diff --git a/tests/test_mindsdb/test_skills.py b/tests/test_mindsdb/test_skills.py new file mode 100644 index 0000000..c82e7b2 --- /dev/null +++ b/tests/test_mindsdb/test_skills.py @@ -0,0 +1,62 @@ +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast.mindsdb import * +from mindsdb_sql_parser.ast import * + + +class TestSkills: + def test_create_skill(self): + sql = ''' + create skill if not exists my_skill + using + type = 'knowledge_base', + source ='my_knowledge_base' + ''' + ast = parse_sql(sql) + expected_ast = CreateSkill( + name=Identifier('my_skill'), + type='knowledge_base', + params={'source': 'my_knowledge_base'}, + if_not_exists=True + ) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + # Parse again after rendering to catch problems with rendering. + ast = parse_sql(str(ast)) + assert str(ast) == str(expected_ast) + + def test_update_skill(self): + sql = ''' + update skill my_skill + set + source = 'new_source' + ''' + ast = parse_sql(sql) + expected_params = { + 'source': 'new_source' + } + expected_ast = UpdateSkill( + name=Identifier('my_skill'), + updated_params=expected_params + ) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + # Parse again after rendering to catch problems with rendering. + ast = parse_sql(str(ast)) + assert str(ast) == str(expected_ast) + + def test_drop_skill(self): + sql = ''' + drop skill if exists my_skill + ''' + ast = parse_sql(sql) + expected_ast = DropSkill( + name=Identifier('my_skill'), if_exists=True + ) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + # Parse again after rendering to catch problems with rendering. 
+ ast = parse_sql(str(ast)) + assert str(ast) == str(expected_ast) diff --git a/tests/test_mindsdb/test_timeseries.py b/tests/test_mindsdb/test_timeseries.py new file mode 100644 index 0000000..986dab6 --- /dev/null +++ b/tests/test_mindsdb/test_timeseries.py @@ -0,0 +1,21 @@ +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast import * +from mindsdb_sql_parser.ast.mindsdb.latest import Latest +from mindsdb_sql_parser.utils import JoinType + + +class TestTimeSeries: + def test_latest_in_where(self): + sql = "SELECT time, price FROM crypto INNER JOIN pred WHERE time > LATEST" + ast = parse_sql(sql) + expected_ast = Select( + targets=[Identifier('time'), Identifier('price')], + from_table=Join(left=Identifier('crypto'), + right=Identifier('pred'), + join_type=JoinType.INNER_JOIN), + where=BinaryOperation('>', args=[Identifier('time'), Latest()]), + ) + + assert str(ast).lower() == sql.lower() + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() diff --git a/tests/test_mindsdb/test_triggers.py b/tests/test_mindsdb/test_triggers.py new file mode 100644 index 0000000..f24a52c --- /dev/null +++ b/tests/test_mindsdb/test_triggers.py @@ -0,0 +1,68 @@ +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast.mindsdb import * +from mindsdb_sql_parser.ast import * + + +class TestTriggers: + def test_create_trigger(self): + sql = ''' + create trigger proj2.tname on db1.tbl1 + ( + retrain p1 + ) + ''' + ast = parse_sql(sql) + expected_ast = CreateTrigger( + name=Identifier('proj2.tname'), + table=Identifier('db1.tbl1'), + query_str="retrain p1", + ) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_create_trigger_cols(self): + sql = ''' + create trigger proj2.tname on db1.tbl1 + columns aaa + ( + retrain p1 + ) + ''' + ast = parse_sql(sql) + expected_ast = CreateTrigger( + name=Identifier('proj2.tname'), + table=Identifier('db1.tbl1'), + columns=[Identifier('aaa')], + query_str="retrain p1", + ) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + # 2 columns + sql = ''' + create trigger proj2.tname on db1.tbl1 + columns aaa, bbb + ( + retrain p1 + ) + ''' + ast = parse_sql(sql) + expected_ast = CreateTrigger( + name=Identifier('proj2.tname'), + table=Identifier('db1.tbl1'), + columns=[Identifier('aaa'), Identifier('bbb')], + query_str="retrain p1", + ) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_drop_trigger(self): + sql = ''' + drop trigger proj1.tname + ''' + ast = parse_sql(sql) + expected_ast = DropTrigger( + name=Identifier('proj1.tname'), + ) + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() diff --git a/tests/test_mindsdb/test_variables.py b/tests/test_mindsdb/test_variables.py new file mode 100644 index 0000000..ae68a7d --- /dev/null +++ b/tests/test_mindsdb/test_variables.py @@ -0,0 +1,39 @@ +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast import * +from mindsdb_sql_parser.ast import Variable + +class TestMDBParser: + def test_select_variable(self): + sql = 'SELECT @version' + ast = parse_sql(sql) + expected_ast = Select(targets=[Variable('version')]) + assert str(ast).lower() == sql.lower() + assert str(ast) == str(expected_ast) + + sql = 'SELECT @@version' + ast = parse_sql(sql) + expected_ast = Select(targets=[Variable('version', is_system_var=True)]) + assert str(ast).lower() == sql.lower() + assert str(ast) == 
str(expected_ast) + + sql = "set autocommit=1, global sql_mode=concat(@@sql_mode, ',STRICT_TRANS_TABLES'), NAMES utf8mb4 COLLATE utf8mb4_unicode_ci" + ast = parse_sql(sql) + expected_ast = Set( + set_list=[ + Set(name=Identifier('autocommit'), value=Constant(1)), + Set(name=Identifier('sql_mode'), + scope='global', + value=Function(op='concat', args=[ + Variable('sql_mode', is_system_var=True), + Constant(',STRICT_TRANS_TABLES') + ]) + ), + Set(category="NAMES", + value=Constant('utf8mb4', with_quotes=False), + params={'COLLATE': Constant('utf8mb4_unicode_ci', with_quotes=False)}) + ] + ) + + assert str(ast).lower() == sql.lower() + assert str(ast) == str(expected_ast) + diff --git a/tests/test_mysql/__init__.py b/tests/test_mysql/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_mysql/test_mysql_lexer.py b/tests/test_mysql/test_mysql_lexer.py new file mode 100644 index 0000000..4f0c43a --- /dev/null +++ b/tests/test_mysql/test_mysql_lexer.py @@ -0,0 +1,23 @@ +from mindsdb_sql_parser.lexer import MindsDBLexer + + +class TestMySQLLexer: + def test_select_variable(self): + sql = 'SELECT @version' + tokens = list(MindsDBLexer().tokenize(sql)) + + assert tokens[0].type == 'SELECT' + assert tokens[0].value == 'SELECT' + + assert tokens[1].type == 'VARIABLE' + assert tokens[1].value == 'version' + + def test_select_system_variable(self): + sql = 'SELECT @@version' + tokens = list(MindsDBLexer().tokenize(sql)) + + assert tokens[0].type == 'SELECT' + assert tokens[0].value == 'SELECT' + + assert tokens[1].type == 'SYSTEM_VARIABLE' + assert tokens[1].value == 'version' diff --git a/tests/test_mysql/test_mysql_parser.py b/tests/test_mysql/test_mysql_parser.py new file mode 100644 index 0000000..e79ed0f --- /dev/null +++ b/tests/test_mysql/test_mysql_parser.py @@ -0,0 +1,66 @@ +from mindsdb_sql_parser import parse_sql +from mindsdb_sql_parser.ast import Select, Identifier, BinaryOperation, Star +from mindsdb_sql_parser.ast import Variable +from mindsdb_sql_parser.parser import Show + + +class TestMySQLParser: + def test_select_variable(self): + sql = 'SELECT @version' + ast = parse_sql(sql) + expected_ast = Select(targets=[Variable('version')]) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast).lower() == sql.lower() + assert str(ast) == str(expected_ast) + + sql = 'SELECT @@version' + ast = parse_sql(sql) + expected_ast = Select(targets=[Variable('version', is_system_var=True)]) + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast).lower() == sql.lower() + assert str(ast) == str(expected_ast) + + def test_select_variable_complex(self): + sql = """SELECT * FROM tab1 WHERE column1 in (SELECT column2 + @variable FROM t2)""" + ast = parse_sql(sql) + expected_ast = Select(targets=[Star()], + from_table=Identifier('tab1'), + where=BinaryOperation(op='in', + args=( + Identifier('column1'), + Select(targets=[BinaryOperation(op='+', + args=[Identifier('column2'), + Variable('variable')]) + ], + from_table=Identifier('t2'), + parentheses=True) + ) + )) + + assert ast.to_tree() == expected_ast.to_tree() + assert str(ast).lower() == sql.lower() + assert str(ast) == str(expected_ast) + + def test_show_index(self): + sql = "SHOW INDEX FROM predictors" + ast = parse_sql(sql) + expected_ast = Show( + category='INDEX', + from_table=Identifier('predictors') + ) + + assert str(ast).lower() == sql.lower() + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() + + def test_show_index_from_db(self): + sql = "SHOW INDEX FROM 
predictors FROM db" + ast = parse_sql(sql) + expected_ast = Show( + category='INDEX', + from_table=Identifier('db.predictors'), + ) + + # assert str(ast).lower() == sql.lower() + assert str(ast) == str(expected_ast) + assert ast.to_tree() == expected_ast.to_tree() diff --git a/tests/test_standard_render.py b/tests/test_standard_render.py new file mode 100644 index 0000000..bdce825 --- /dev/null +++ b/tests/test_standard_render.py @@ -0,0 +1,83 @@ +import inspect +import pkgutil +import sys +import os +import importlib.util + +from mindsdb_sql_parser import parse_sql, Parameter, Select + + +def load_all_modules_from_dir(dir_names): + for importer, package_name, _ in pkgutil.iter_modules(dir_names): + full_package_name = package_name + if full_package_name not in sys.modules: + spec = importer.find_spec(package_name) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + yield module + + +def check_module(module): + if module.__name__ in ('test_mysql_lexer', 'test_base_lexer'): + # skip + return + + for class_name, klass in inspect.getmembers(module, predicate=inspect.isclass): + if not class_name.startswith('Test'): + continue + + tests = klass() + for test_name, test_method in inspect.getmembers(tests, predicate=inspect.ismethod): + if not test_name.startswith('test_') or test_name.endswith('_error'): + # skip tests that expect an error + continue + sig = inspect.signature(test_method) + args = [] + # add dialect + if 'dialect' in sig.parameters: + args.append('mindsdb') + if 'cat' in sig.parameters: + # skip it + continue + + test_method(*args) + + +def parse_sql2(sql): + + query = parse_sql(sql) + + if isinstance(query, Select) and isinstance(query.from_table, Parameter): + # skip queries with params + return query + + # render + sql2 = query.to_string() + + # parse again + query2 = parse_sql(sql2) + + # compare results of the first and second parsing + assert str(query) == str(query2) + + # return to the test, which compares it with expected_ast + return query2 + + +def test_standard_render(): + + base_dir = os.path.dirname(__file__) + dir_names = [ + os.path.join(base_dir, folder) + for folder in os.listdir(base_dir) + if folder.startswith('test_') + ] + + for module in load_all_modules_from_dir(dir_names): + + # inject function + module.parse_sql = parse_sql2 + + check_module(module) + 
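The round-trip check in tests/test_standard_render.py above is the invariant most of these tests lean on: parse a query, render it with `to_string()`, parse the rendered SQL again, and the two ASTs must match. A minimal standalone sketch of that invariant, using an illustrative query rather than one taken from the suite:

```python
from mindsdb_sql_parser import parse_sql

# parse an example query into an AST
query = parse_sql("SELECT col1 FROM tab1 WHERE col2 = 1")

# render the AST back to SQL; the rendered string is not guaranteed
# to match the input byte-for-byte, so the comparison is done on the
# re-parsed result rather than on the raw SQL text
sql2 = query.to_string()
query2 = parse_sql(sql2)

assert str(query) == str(query2)
assert query.to_tree() == query2.to_tree()
```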